python3: fixed various code issues...
super-admin
r4973:5e52ba1a default

The requested changes are too big and content was truncated.

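The substantive Python 3 fix in this changeset is the rename of itertools.izip_longest to itertools.zip_longest; izip_longest does not exist on Python 3. A minimal sketch of the pattern both call sites rely on, pairing argument names with their defaults right-aligned (the variable values here are illustrative, not taken from the diff):

    import itertools

    arglist = ['request', 'apiuser', 'pattern']   # positional args of a view
    defaults = ['*']                              # defaults align to the tail
    func_kwargs = dict(itertools.zip_longest(
        reversed(arglist), reversed(defaults), fillvalue='<RequiredType>'))
    # {'pattern': '*', 'apiuser': '<RequiredType>', 'request': '<RequiredType>'}
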
@@ -1,578 +1,578 b''
# -*- coding: utf-8 -*-

# Copyright (C) 2011-2020 RhodeCode GmbH
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License, version 3
# (only), as published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# This program is dual-licensed. If you wish to learn more about the
# RhodeCode Enterprise Edition, including its added features, Support services,
# and proprietary license terms, please see https://rhodecode.com/licenses/

import itertools
import logging
import sys
import types
import fnmatch

import decorator
import venusian
from collections import OrderedDict

from pyramid.exceptions import ConfigurationError
from pyramid.renderers import render
from pyramid.response import Response
from pyramid.httpexceptions import HTTPNotFound

from rhodecode.api.exc import (
    JSONRPCBaseError, JSONRPCError, JSONRPCForbidden, JSONRPCValidationError)
from rhodecode.apps._base import TemplateArgs
from rhodecode.lib.auth import AuthUser
from rhodecode.lib.base import get_ip_addr, attach_context_attributes
from rhodecode.lib.exc_tracking import store_exception
from rhodecode.lib.ext_json import json
from rhodecode.lib.utils2 import safe_str
from rhodecode.lib.plugins.utils import get_plugin_settings
from rhodecode.model.db import User, UserApiKeys

log = logging.getLogger(__name__)

DEFAULT_RENDERER = 'jsonrpc_renderer'
DEFAULT_URL = '/_admin/apiv2'


def find_methods(jsonrpc_methods, pattern):
    matches = OrderedDict()
    if not isinstance(pattern, (list, tuple)):
        pattern = [pattern]

    for single_pattern in pattern:
        for method_name, method in jsonrpc_methods.items():
            if fnmatch.fnmatch(method_name, single_pattern):
                matches[method_name] = method
    return matches

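# Hedged usage sketch (editorial illustration, not part of this module):
# find_methods() glob-matches registered method names and accepts either a
# single pattern or a list of patterns, e.g.
#
#     >>> methods = OrderedDict([('comment_commit', 'f1'), ('get_ip', 'f2')])
#     >>> list(find_methods(methods, '*comment*'))
#     ['comment_commit']
#     >>> list(find_methods(methods, ['*comment*', 'get_*']))
#     ['comment_commit', 'get_ip']
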
class ExtJsonRenderer(object):
    """
    Custom renderer that makes use of our ext_json lib
    """

    def __init__(self, serializer=json.dumps, **kw):
        """ Any keyword arguments will be passed to the ``serializer``
        function."""
        self.serializer = serializer
        self.kw = kw

    def __call__(self, info):
        """ Returns a plain JSON-encoded string with content-type
        ``application/json``. The content-type may be overridden by
        setting ``request.response.content_type``."""

        def _render(value, system):
            request = system.get('request')
            if request is not None:
                response = request.response
                ct = response.content_type
                if ct == response.default_content_type:
                    response.content_type = 'application/json'

            return self.serializer(value, **self.kw)

        return _render


def jsonrpc_response(request, result):
    rpc_id = getattr(request, 'rpc_id', None)
    response = request.response

    # store content_type before render is called
    ct = response.content_type

    ret_value = ''
    if rpc_id:
        ret_value = {
            'id': rpc_id,
            'result': result,
            'error': None,
        }

        # fetch deprecation warnings, and store them inside the results
        deprecation = getattr(request, 'rpc_deprecation', None)
        if deprecation:
            ret_value['DEPRECATION_WARNING'] = deprecation

    raw_body = render(DEFAULT_RENDERER, ret_value, request=request)
    response.body = safe_str(raw_body, response.charset)

    if ct == response.default_content_type:
        response.content_type = 'application/json'

    return response


def jsonrpc_error(request, message, retid=None, code=None, headers=None):
    """
    Generate a Response object with a JSON-RPC error body

    :param code:
    :param retid:
    :param message:
    """
    err_dict = {'id': retid, 'result': None, 'error': message}
    body = render(DEFAULT_RENDERER, err_dict, request=request).encode('utf-8')

    return Response(
        body=body,
        status=code,
        content_type='application/json',
        headerlist=headers
    )


def exception_view(exc, request):
    rpc_id = getattr(request, 'rpc_id', None)

    if isinstance(exc, JSONRPCError):
        fault_message = safe_str(exc.message)
        log.debug('json-rpc error rpc_id:%s "%s"', rpc_id, fault_message)
    elif isinstance(exc, JSONRPCValidationError):
        colander_exc = exc.colander_exception
        # TODO(marcink): think maybe of nicer way to serialize errors ?
        fault_message = colander_exc.asdict()
        log.debug('json-rpc colander error rpc_id:%s "%s"', rpc_id, fault_message)
    elif isinstance(exc, JSONRPCForbidden):
        fault_message = 'Access was denied to this resource.'
        log.warning('json-rpc forbidden call rpc_id:%s "%s"', rpc_id, fault_message)
    elif isinstance(exc, HTTPNotFound):
        method = request.rpc_method
        log.debug('json-rpc method `%s` not found in list of '
                  'api calls: %s, rpc_id:%s',
                  method, request.registry.jsonrpc_methods.keys(), rpc_id)

        similar = 'none'
        try:
            similar_paterns = ['*{}*'.format(x) for x in method.split('_')]
            similar_found = find_methods(
                request.registry.jsonrpc_methods, similar_paterns)
            similar = ', '.join(similar_found.keys()) or similar
        except Exception:
            # make the whole above block safe
            pass

        fault_message = "No such method: {}. Similar methods: {}".format(
            method, similar)
    else:
        fault_message = 'undefined error'
        exc_info = exc.exc_info()
        store_exception(id(exc_info), exc_info, prefix='rhodecode-api')

    statsd = request.registry.statsd
    if statsd:
        exc_type = "{}.{}".format(exc.__class__.__module__, exc.__class__.__name__)
        statsd.incr('rhodecode_exception_total',
                    tags=["exc_source:api", "type:{}".format(exc_type)])

    return jsonrpc_error(request, fault_message, rpc_id)


def request_view(request):
    """
    Main request handling method. It handles all the logic to call a specific
    exposed method.
    """
    # cython compatible inspect
    from rhodecode.config.patches import inspect_getargspec
    inspect = inspect_getargspec()

    # check if we can find this session using api_key, get_by_auth_token
    # search not expired tokens only
    try:
        api_user = User.get_by_auth_token(request.rpc_api_key)

        if api_user is None:
            return jsonrpc_error(
                request, retid=request.rpc_id, message='Invalid API KEY')

        if not api_user.active:
            return jsonrpc_error(
                request, retid=request.rpc_id,
                message='Request from this user not allowed')

        # check if we are allowed to use this IP
        auth_u = AuthUser(
            api_user.user_id, request.rpc_api_key, ip_addr=request.rpc_ip_addr)
        if not auth_u.ip_allowed:
            return jsonrpc_error(
                request, retid=request.rpc_id,
                message='Request from IP:%s not allowed' % (
                    request.rpc_ip_addr,))
        else:
            log.info('Access for IP:%s allowed', request.rpc_ip_addr)

        # register our auth-user
        request.rpc_user = auth_u
        request.environ['rc_auth_user_id'] = auth_u.user_id

        # now check if token is valid for API
        auth_token = request.rpc_api_key
        token_match = api_user.authenticate_by_token(
            auth_token, roles=[UserApiKeys.ROLE_API])
        invalid_token = not token_match

        log.debug('Checking if API KEY is valid with proper role')
        if invalid_token:
            return jsonrpc_error(
                request, retid=request.rpc_id,
                message='API KEY invalid or, has bad role for an API call')

    except Exception:
        log.exception('Error on API AUTH')
        return jsonrpc_error(
            request, retid=request.rpc_id, message='Invalid API KEY')

    method = request.rpc_method
    func = request.registry.jsonrpc_methods[method]

    # now that we have a method, add request._req_params to
    # self.kargs and dispatch control to WGIController
    argspec = inspect.getargspec(func)
    arglist = argspec[0]
    defaults = map(type, argspec[3] or [])
    default_empty = types.NotImplementedType

    # kw arguments required by this method
-    func_kwargs = dict(itertools.izip_longest(
+    func_kwargs = dict(itertools.zip_longest(
        reversed(arglist), reversed(defaults), fillvalue=default_empty))

    # This attribute will need to be first param of a method that uses
    # api_key, which is translated to instance of user at that name
    user_var = 'apiuser'
    request_var = 'request'

    for arg in [user_var, request_var]:
        if arg not in arglist:
            return jsonrpc_error(
                request,
                retid=request.rpc_id,
                message='This method [%s] does not support '
                        'required parameter `%s`' % (func.__name__, arg))

    # get our arglist and check if we provided them as args
    for arg, default in func_kwargs.items():
        if arg in [user_var, request_var]:
            # user_var and request_var are pre-hardcoded parameters and we
            # don't need to do any translation
            continue

        # skip the required param check if its default value is
        # NotImplementedType (default_empty)
        if default == default_empty and arg not in request.rpc_params:
            return jsonrpc_error(
                request,
                retid=request.rpc_id,
                message=('Missing non optional `%s` arg in JSON DATA' % arg)
            )

    # sanitize extra passed arguments
    for k in request.rpc_params.keys()[:]:
        if k not in func_kwargs:
            del request.rpc_params[k]

    call_params = request.rpc_params
    call_params.update({
        'request': request,
        'apiuser': auth_u
    })

    # register some common functions for usage
    attach_context_attributes(TemplateArgs(), request, request.rpc_user.user_id)

    statsd = request.registry.statsd

    try:
        ret_value = func(**call_params)
        resp = jsonrpc_response(request, ret_value)
        if statsd:
            statsd.incr('rhodecode_api_call_success_total')
        return resp
    except JSONRPCBaseError:
        raise
    except Exception:
        log.exception('Unhandled exception occurred on api call: %s', func)
        exc_info = sys.exc_info()
        exc_id, exc_type_name = store_exception(
            id(exc_info), exc_info, prefix='rhodecode-api')
        error_headers = [('RhodeCode-Exception-Id', str(exc_id)),
                         ('RhodeCode-Exception-Type', str(exc_type_name))]
        err_resp = jsonrpc_error(
            request, retid=request.rpc_id, message='Internal server error',
            headers=error_headers)
        if statsd:
            statsd.incr('rhodecode_api_call_fail_total')
        return err_resp


def setup_request(request):
    """
    Parse a JSON-RPC request body. It's used inside the predicates method
    to validate and bootstrap requests for usage in rpc calls.

    We need to raise JSONRPCError here if we want to return some errors back
    to the user.
    """

    log.debug('Executing setup request: %r', request)
    request.rpc_ip_addr = get_ip_addr(request.environ)
    # TODO(marcink): deprecate GET at some point
    if request.method not in ['POST', 'GET']:
        log.debug('unsupported request method "%s"', request.method)
        raise JSONRPCError(
            'unsupported request method "%s". Please use POST' % request.method)

    if 'CONTENT_LENGTH' not in request.environ:
        log.debug("No Content-Length")
        raise JSONRPCError("Empty body, No Content-Length in request")

    else:
        length = request.environ['CONTENT_LENGTH']
        log.debug('Content-Length: %s', length)

        if length == 0:
            log.debug("Content-Length is 0")
            raise JSONRPCError("Content-Length is 0")

    raw_body = request.body
    log.debug("Loading JSON body now")
    try:
        json_body = json.loads(raw_body)
    except ValueError as e:
        # catch JSON errors here
        raise JSONRPCError("JSON parse error ERR:%s RAW:%r" % (e, raw_body))

    request.rpc_id = json_body.get('id')
    request.rpc_method = json_body.get('method')

    # check required base parameters
    try:
        api_key = json_body.get('api_key')
        if not api_key:
            api_key = json_body.get('auth_token')

        if not api_key:
            raise KeyError('api_key or auth_token')

        # TODO(marcink): support passing in token in request header

        request.rpc_api_key = api_key
        request.rpc_id = json_body['id']
        request.rpc_method = json_body['method']
        request.rpc_params = json_body['args'] \
            if isinstance(json_body['args'], dict) else {}

        log.debug('method: %s, params: %.10240r', request.rpc_method, request.rpc_params)
    except KeyError as e:
        raise JSONRPCError('Incorrect JSON data. Missing %s' % e)

    log.debug('setup complete, now handling method:%s rpcid:%s',
              request.rpc_method, request.rpc_id, )


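# Hedged request-body sketch (editorial illustration, values are
# placeholders): setup_request() above expects a JSON document of roughly
# this shape, where 'auth_token' (or the legacy 'api_key') is required and
# 'args' must be a dict, otherwise it is replaced with {}:
#
#     {
#         "id": 1,
#         "auth_token": "<token>",
#         "method": "get_server_info",
#         "args": {}
#     }
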
class RoutePredicate(object):
    def __init__(self, val, config):
        self.val = val

    def text(self):
        return 'jsonrpc route = %s' % self.val

    phash = text

    def __call__(self, info, request):
        if self.val:
            # potentially setup and bootstrap our call
            setup_request(request)

        # Always return True so that even if it isn't a valid RPC it
        # will fall through to the underlying handlers like notfound_view
        return True


class NotFoundPredicate(object):
    def __init__(self, val, config):
        self.val = val
        self.methods = config.registry.jsonrpc_methods

    def text(self):
        return 'jsonrpc method not found = {}.'.format(self.val)

    phash = text

    def __call__(self, info, request):
        return hasattr(request, 'rpc_method')


class MethodPredicate(object):
    def __init__(self, val, config):
        self.method = val

    def text(self):
        return 'jsonrpc method = %s' % self.method

    phash = text

    def __call__(self, context, request):
        # we need to explicitly return False here, so pyramid doesn't try to
        # execute our view directly. We need our main handler to execute things
        return getattr(request, 'rpc_method') == self.method


def add_jsonrpc_method(config, view, **kwargs):
    # pop the method name
    method = kwargs.pop('method', None)

    if method is None:
        raise ConfigurationError(
            'Cannot register a JSON-RPC method without specifying the "method"')

    # we define a custom predicate to be able to detect conflicting methods;
    # these predicates are a kind of "translation" from the decorator variables
    # to internal predicate names

    kwargs['jsonrpc_method'] = method

    # register our view into the global view store for validation
    config.registry.jsonrpc_methods[method] = view

    # we're using our main request_view handler here, so each method
    # has a unified handler for itself
    config.add_view(request_view, route_name='apiv2', **kwargs)


class jsonrpc_method(object):
    """
    Decorator that works similarly to the @add_view_config decorator,
    but is tailored for our JSON-RPC.
    """

    venusian = venusian  # for testing injection

    def __init__(self, method=None, **kwargs):
        self.method = method
        self.kwargs = kwargs

    def __call__(self, wrapped):
        kwargs = self.kwargs.copy()
        kwargs['method'] = self.method or wrapped.__name__
        depth = kwargs.pop('_depth', 0)

        def callback(context, name, ob):
            config = context.config.with_package(info.module)
            config.add_jsonrpc_method(view=ob, **kwargs)

        info = venusian.attach(wrapped, callback, category='pyramid',
                               depth=depth + 1)
        if info.scope == 'class':
            # ensure that attr is set if decorating a class method
            kwargs.setdefault('attr', wrapped.__name__)

        kwargs['_info'] = info.codeinfo  # fbo action_method
        return wrapped


class jsonrpc_deprecated_method(object):
    """
    Marks a method as deprecated, adds a log.warning, and injects a special
    key into the request variable to mark the method as deprecated.
    Also injects a special docstring that extract_docs will catch to mark
    the method as deprecated.

    :param use_method: specify which method should be used instead of
        the decorated one

    Use like::

        @jsonrpc_method()
        @jsonrpc_deprecated_method(use_method='new_func', deprecated_at_version='3.0.0')
        def old_func(request, apiuser, arg1, arg2):
            ...
    """

    def __init__(self, use_method, deprecated_at_version):
        self.use_method = use_method
        self.deprecated_at_version = deprecated_at_version
        self.deprecated_msg = ''

    def __call__(self, func):
        self.deprecated_msg = 'Please use method `{method}` instead.'.format(
            method=self.use_method)

        docstring = """\n
        .. deprecated:: {version}

            {deprecation_message}

        {original_docstring}
        """
        func.__doc__ = docstring.format(
            version=self.deprecated_at_version,
            deprecation_message=self.deprecated_msg,
            original_docstring=func.__doc__)
        return decorator.decorator(self.__wrapper, func)

    def __wrapper(self, func, *fargs, **fkwargs):
        log.warning('DEPRECATED API CALL on function %s, please '
                    'use `%s` instead', func, self.use_method)
        # alter function docstring to mark as deprecated, this is picked up
        # via the fabric file that generates the API DOC.
        result = func(*fargs, **fkwargs)

        request = fargs[0]
        request.rpc_deprecation = 'DEPRECATED METHOD ' + self.deprecated_msg
        return result


def add_api_methods(config):
    from rhodecode.api.views import (
        deprecated_api, gist_api, pull_request_api, repo_api, repo_group_api,
        server_api, search_api, testing_api, user_api, user_group_api)

    config.scan('rhodecode.api.views')


def includeme(config):
    plugin_module = 'rhodecode.api'
    plugin_settings = get_plugin_settings(
        plugin_module, config.registry.settings)

    if not hasattr(config.registry, 'jsonrpc_methods'):
        config.registry.jsonrpc_methods = OrderedDict()

    # match filter by given method only
    config.add_view_predicate('jsonrpc_method', MethodPredicate)
    config.add_view_predicate('jsonrpc_method_not_found', NotFoundPredicate)

    config.add_renderer(DEFAULT_RENDERER, ExtJsonRenderer(
        serializer=json.dumps, indent=4))
    config.add_directive('add_jsonrpc_method', add_jsonrpc_method)

    config.add_route_predicate(
        'jsonrpc_call', RoutePredicate)

    config.add_route(
        'apiv2', plugin_settings.get('url', DEFAULT_URL), jsonrpc_call=True)

    # register some exception handling view
    config.add_view(exception_view, context=JSONRPCBaseError)
    config.add_notfound_view(exception_view, jsonrpc_method_not_found=True)

    add_api_methods(config)
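
The second file below defines API views that are registered through the machinery above. As a hedged sketch (the method name and argument are illustrative, not from this changeset), a function decorated with @jsonrpc_method() is stored under its own name in config.registry.jsonrpc_methods and dispatched through the unified request_view handler:

    @jsonrpc_method()
    def my_method(request, apiuser, repoid):
        return {'repoid': repoid}

Both 'request' and 'apiuser' must appear in the signature, since request_view rejects methods that lack them.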
@@ -1,419 +1,419 b''
# -*- coding: utf-8 -*-

# Copyright (C) 2011-2020 RhodeCode GmbH
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License, version 3
# (only), as published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# This program is dual-licensed. If you wish to learn more about the
# RhodeCode Enterprise Edition, including its added features, Support services,
# and proprietary license terms, please see https://rhodecode.com/licenses/

import logging
import itertools
import base64

from rhodecode.api import (
    jsonrpc_method, JSONRPCError, JSONRPCForbidden, find_methods)

from rhodecode.api.utils import (
    Optional, OAttr, has_superadmin_permission, get_user_or_error)
from rhodecode.lib.utils import repo2db_mapper
from rhodecode.lib import system_info
from rhodecode.lib import user_sessions
from rhodecode.lib import exc_tracking
from rhodecode.lib.ext_json import json
from rhodecode.lib.utils2 import safe_int
from rhodecode.model.db import UserIpMap
from rhodecode.model.scm import ScmModel
from rhodecode.model.settings import VcsSettingsModel
from rhodecode.apps.file_store import utils
from rhodecode.apps.file_store.exceptions import FileNotAllowedException, \
    FileOverSizeException

log = logging.getLogger(__name__)


@jsonrpc_method()
def get_server_info(request, apiuser):
    """
    Returns the |RCE| server information.

    This includes the running version of |RCE| and all installed
    packages. This command takes the following options:

    :param apiuser: This is filled automatically from the |authtoken|.
    :type apiuser: AuthUser

    Example output:

    .. code-block:: bash

        id : <id_given_in_input>
        result : {
            'modules': [<module name>,...]
            'py_version': <python version>,
            'platform': <platform type>,
            'rhodecode_version': <rhodecode version>
        }
        error : null
    """

    if not has_superadmin_permission(apiuser):
        raise JSONRPCForbidden()

    server_info = ScmModel().get_server_info(request.environ)
    # rhodecode-index requires those

    server_info['index_storage'] = server_info['search']['value']['location']
    server_info['storage'] = server_info['storage']['value']['path']

    return server_info


@jsonrpc_method()
def get_repo_store(request, apiuser):
    """
    Returns the |RCE| repository storage information.

    :param apiuser: This is filled automatically from the |authtoken|.
    :type apiuser: AuthUser

    Example output:

    .. code-block:: bash

        id : <id_given_in_input>
        result : {
            'modules': [<module name>,...]
            'py_version': <python version>,
            'platform': <platform type>,
            'rhodecode_version': <rhodecode version>
        }
        error : null
    """

    if not has_superadmin_permission(apiuser):
        raise JSONRPCForbidden()

    path = VcsSettingsModel().get_repos_location()
    return {"path": path}


@jsonrpc_method()
def get_ip(request, apiuser, userid=Optional(OAttr('apiuser'))):
    """
    Displays the IP Address as seen from the |RCE| server.

    * This command displays the IP Address, as well as all the defined IP
      addresses for the specified user. If the ``userid`` is not set, the
      data returned is for the user calling the method.

    This command can only be run using an |authtoken| with admin rights to
    the specified repository.

    This command takes the following options:

    :param apiuser: This is filled automatically from the |authtoken|.
    :type apiuser: AuthUser
    :param userid: Sets the userid for which associated IP Address data
        is returned.
    :type userid: Optional(str or int)

    Example output:

    .. code-block:: bash

        id : <id_given_in_input>
        result : {
            "server_ip_addr": "<ip_from_client>",
            "user_ips": [
                {
                    "ip_addr": "<ip_with_mask>",
                    "ip_range": ["<start_ip>", "<end_ip>"],
                },
                ...
            ]
        }

    """
    if not has_superadmin_permission(apiuser):
        raise JSONRPCForbidden()

    userid = Optional.extract(userid, evaluate_locals=locals())
    userid = getattr(userid, 'user_id', userid)

    user = get_user_or_error(userid)
    ips = UserIpMap.query().filter(UserIpMap.user == user).all()
    return {
        'server_ip_addr': request.rpc_ip_addr,
        'user_ips': ips
    }


@jsonrpc_method()
def rescan_repos(request, apiuser, remove_obsolete=Optional(False)):
    """
    Triggers a rescan of the specified repositories.

    * If the ``remove_obsolete`` option is set, it also deletes repositories
      that are found in the database but not on the file system, so-called
      "clean zombies".

    This command can only be run using an |authtoken| with admin rights to
    the specified repository.

    This command takes the following options:

    :param apiuser: This is filled automatically from the |authtoken|.
    :type apiuser: AuthUser
    :param remove_obsolete: Deletes repositories from the database that
        are not found on the filesystem.
    :type remove_obsolete: Optional(``True`` | ``False``)

    Example output:

    .. code-block:: bash

        id : <id_given_in_input>
        result : {
            'added': [<added repository name>,...]
            'removed': [<removed repository name>,...]
        }
        error : null

    Example error output:

    .. code-block:: bash

        id : <id_given_in_input>
        result : null
        error : {
            'Error occurred during rescan repositories action'
        }

    """
    if not has_superadmin_permission(apiuser):
        raise JSONRPCForbidden()

    try:
        rm_obsolete = Optional.extract(remove_obsolete)
        added, removed = repo2db_mapper(ScmModel().repo_scan(),
                                        remove_obsolete=rm_obsolete)
        return {'added': added, 'removed': removed}
    except Exception:
        log.exception('Failed to run repo rescann')
        raise JSONRPCError(
            'Error occurred during rescan repositories action'
        )


@jsonrpc_method()
def cleanup_sessions(request, apiuser, older_then=Optional(60)):
    """
    Triggers a session cleanup action.

    If the ``older_then`` option is set, only sessions that haven't been
    accessed in the given number of days will be removed.

    This command can only be run using an |authtoken| with admin rights to
    the specified repository.

    This command takes the following options:

    :param apiuser: This is filled automatically from the |authtoken|.
    :type apiuser: AuthUser
    :param older_then: Deletes sessions that haven't been accessed
        in the given number of days.
    :type older_then: Optional(int)

    Example output:

    .. code-block:: bash

        id : <id_given_in_input>
        result: {
            "backend": "<type of backend>",
            "sessions_removed": <number_of_removed_sessions>
        }
        error : null

    Example error output:

    .. code-block:: bash

        id : <id_given_in_input>
        result : null
        error : {
            'Error occurred during session cleanup'
        }

    """
    if not has_superadmin_permission(apiuser):
        raise JSONRPCForbidden()

    older_then = safe_int(Optional.extract(older_then)) or 60
    older_than_seconds = 60 * 60 * 24 * older_then

    config = system_info.rhodecode_config().get_value()['value']['config']
    session_model = user_sessions.get_session_handler(
        config.get('beaker.session.type', 'memory'))(config)

    backend = session_model.SESSION_TYPE
    try:
        cleaned = session_model.clean_sessions(
            older_than_seconds=older_than_seconds)
        return {'sessions_removed': cleaned, 'backend': backend}
    except user_sessions.CleanupCommand as msg:
        return {'cleanup_command': msg.message, 'backend': backend}
    except Exception as e:
        log.exception('Failed session cleanup')
        raise JSONRPCError(
            'Error occurred during session cleanup'
        )


285 @jsonrpc_method()
285 @jsonrpc_method()
286 def get_method(request, apiuser, pattern=Optional('*')):
286 def get_method(request, apiuser, pattern=Optional('*')):
287 """
287 """
288 Returns list of all available API methods. By default match pattern
288 Returns list of all available API methods. By default match pattern
289 os "*" but any other pattern can be specified. eg *comment* will return
289 os "*" but any other pattern can be specified. eg *comment* will return
290 all methods with comment inside them. If just single method is matched
290 all methods with comment inside them. If just single method is matched
291 returned data will also include method specification
291 returned data will also include method specification
292
292
293 This command can only be run using an |authtoken| with admin rights to
293 This command can only be run using an |authtoken| with admin rights to
294 the specified repository.
294 the specified repository.
295
295
296 This command takes the following options:
296 This command takes the following options:
297
297
298 :param apiuser: This is filled automatically from the |authtoken|.
298 :param apiuser: This is filled automatically from the |authtoken|.
299 :type apiuser: AuthUser
299 :type apiuser: AuthUser
300 :param pattern: pattern to match method names against
300 :param pattern: pattern to match method names against
301 :type pattern: Optional("*")
301 :type pattern: Optional("*")
302
302
303 Example output:
303 Example output:
304
304
305 .. code-block:: bash
305 .. code-block:: bash
306
306
307 id : <id_given_in_input>
307 id : <id_given_in_input>
308 "result": [
308 "result": [
309 "changeset_comment",
309 "changeset_comment",
310 "comment_pull_request",
310 "comment_pull_request",
311 "comment_commit"
311 "comment_commit"
312 ]
312 ]
313 error : null
313 error : null
314
314
315 .. code-block:: bash
315 .. code-block:: bash
316
316
317 id : <id_given_in_input>
317 id : <id_given_in_input>
318 "result": [
318 "result": [
319 "comment_commit",
319 "comment_commit",
320 {
320 {
321 "apiuser": "<RequiredType>",
321 "apiuser": "<RequiredType>",
322 "comment_type": "<Optional:u'note'>",
322 "comment_type": "<Optional:u'note'>",
323 "commit_id": "<RequiredType>",
323 "commit_id": "<RequiredType>",
324 "message": "<RequiredType>",
324 "message": "<RequiredType>",
325 "repoid": "<RequiredType>",
325 "repoid": "<RequiredType>",
326 "request": "<RequiredType>",
326 "request": "<RequiredType>",
327 "resolves_comment_id": "<Optional:None>",
327 "resolves_comment_id": "<Optional:None>",
328 "status": "<Optional:None>",
328 "status": "<Optional:None>",
329 "userid": "<Optional:<OptionalAttr:apiuser>>"
329 "userid": "<Optional:<OptionalAttr:apiuser>>"
330 }
330 }
331 ]
331 ]
332 error : null
332 error : null
333 """
333 """
334 from rhodecode.config.patches import inspect_getargspec
334 from rhodecode.config.patches import inspect_getargspec
335 inspect = inspect_getargspec()
335 inspect = inspect_getargspec()
336
336
337 if not has_superadmin_permission(apiuser):
337 if not has_superadmin_permission(apiuser):
338 raise JSONRPCForbidden()
338 raise JSONRPCForbidden()
339
339
340 pattern = Optional.extract(pattern)
340 pattern = Optional.extract(pattern)
341
341
342 matches = find_methods(request.registry.jsonrpc_methods, pattern)
342 matches = find_methods(request.registry.jsonrpc_methods, pattern)
343
343
344 args_desc = []
344 args_desc = []
345 if len(matches) == 1:
345 if len(matches) == 1:
346 func = matches[list(matches.keys())[0]]
346 func = matches[list(matches.keys())[0]]
347
347
348 argspec = inspect.getargspec(func)
348 argspec = inspect.getargspec(func)
349 arglist = argspec[0]
349 arglist = argspec[0]
350 defaults = list(map(repr, argspec[3] or []))
350 defaults = list(map(repr, argspec[3] or []))
351
351
352 default_empty = '<RequiredType>'
352 default_empty = '<RequiredType>'
353
353
354 # kw arguments required by this method
354 # kw arguments required by this method
355 func_kwargs = dict(itertools.izip_longest(
355 func_kwargs = dict(itertools.zip_longest(
356 reversed(arglist), reversed(defaults), fillvalue=default_empty))
356 reversed(arglist), reversed(defaults), fillvalue=default_empty))
357 args_desc.append(func_kwargs)
357 args_desc.append(func_kwargs)
358
358
359 return list(matches.keys()) + args_desc
359 return list(matches.keys()) + args_desc
360
360
361
361
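A note on the izip_longest -> zip_longest rename above: the surrounding pairing
logic depends on reversing both sequences so that trailing defaults line up with
the last keyword arguments. A minimal sketch of that behaviour (the argument
names and default value are illustrative):

.. code-block:: python

    import itertools

    arglist = ['request', 'apiuser', 'pattern']   # inspect.getargspec(func)[0]
    defaults = ["'*'"]                            # repr() of each default value

    # Reversing aligns defaults with the *last* arguments; anything left
    # unpaired is filled with the '<RequiredType>' marker.
    func_kwargs = dict(itertools.zip_longest(
        reversed(arglist), reversed(defaults), fillvalue='<RequiredType>'))
    # {'pattern': "'*'", 'apiuser': '<RequiredType>', 'request': '<RequiredType>'}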
362 @jsonrpc_method()
362 @jsonrpc_method()
363 def store_exception(request, apiuser, exc_data_json, prefix=Optional('rhodecode')):
363 def store_exception(request, apiuser, exc_data_json, prefix=Optional('rhodecode')):
364 """
364 """
365 Stores a sent exception inside the built-in exception tracker of the |RCE| server.
365 Stores a sent exception inside the built-in exception tracker of the |RCE| server.
366
366
367 This command can only be run using an |authtoken| with super-admin rights;
367 This command can only be run using an |authtoken| with super-admin rights;
368 it is not scoped to any repository.
368 it is not scoped to any repository.
369
369
370 This command takes the following options:
370 This command takes the following options:
371
371
372 :param apiuser: This is filled automatically from the |authtoken|.
372 :param apiuser: This is filled automatically from the |authtoken|.
373 :type apiuser: AuthUser
373 :type apiuser: AuthUser
374
374
375 :param exc_data_json: JSON data with the exception, e.g.
375 :param exc_data_json: JSON data with the exception, e.g.
376 {"exc_traceback": "Value `1` is not allowed", "exc_type_name": "ValueError"}
376 {"exc_traceback": "Value `1` is not allowed", "exc_type_name": "ValueError"}
377 :type exc_data_json: JSON data
377 :type exc_data_json: JSON data
378
378
379 :param prefix: prefix for the error type, e.g. 'rhodecode', 'vcsserver', 'rhodecode-tools'
379 :param prefix: prefix for the error type, e.g. 'rhodecode', 'vcsserver', 'rhodecode-tools'
380 :type prefix: Optional("rhodecode")
380 :type prefix: Optional("rhodecode")
381
381
382 Example output:
382 Example output:
383
383
384 .. code-block:: bash
384 .. code-block:: bash
385
385
386 id : <id_given_in_input>
386 id : <id_given_in_input>
387 "result": {
387 "result": {
388 "exc_id": 139718459226384,
388 "exc_id": 139718459226384,
389 "exc_url": "http://localhost:8080/_admin/settings/exceptions/139718459226384"
389 "exc_url": "http://localhost:8080/_admin/settings/exceptions/139718459226384"
390 }
390 }
391 error : null
391 error : null
392 """
392 """
393 if not has_superadmin_permission(apiuser):
393 if not has_superadmin_permission(apiuser):
394 raise JSONRPCForbidden()
394 raise JSONRPCForbidden()
395
395
396 prefix = Optional.extract(prefix)
396 prefix = Optional.extract(prefix)
397 exc_id = exc_tracking.generate_id()
397 exc_id = exc_tracking.generate_id()
398
398
399 try:
399 try:
400 exc_data = json.loads(exc_data_json)
400 exc_data = json.loads(exc_data_json)
401 except Exception:
401 except Exception:
402 log.error('Failed to parse JSON: %r', exc_data_json)
402 log.error('Failed to parse JSON: %r', exc_data_json)
403 raise JSONRPCError('Failed to parse JSON data from exc_data_json field. '
403 raise JSONRPCError('Failed to parse JSON data from exc_data_json field. '
404 'Please make sure it contains valid JSON.')
404 'Please make sure it contains valid JSON.')
405
405
406 try:
406 try:
407 exc_traceback = exc_data['exc_traceback']
407 exc_traceback = exc_data['exc_traceback']
408 exc_type_name = exc_data['exc_type_name']
408 exc_type_name = exc_data['exc_type_name']
409 except KeyError as err:
409 except KeyError as err:
410 raise JSONRPCError('Missing exc_traceback or exc_type_name '
410 raise JSONRPCError('Missing exc_traceback or exc_type_name '
411 'in the exc_data_json field. Missing: {}'.format(err))
411 'in the exc_data_json field. Missing: {}'.format(err))
412
412
413 exc_tracking._store_exception(
413 exc_tracking._store_exception(
414 exc_id=exc_id, exc_traceback=exc_traceback,
414 exc_id=exc_id, exc_traceback=exc_traceback,
415 exc_type_name=exc_type_name, prefix=prefix)
415 exc_type_name=exc_type_name, prefix=prefix)
416
416
417 exc_url = request.route_url(
417 exc_url = request.route_url(
418 'admin_settings_exception_tracker_show', exception_id=exc_id)
418 'admin_settings_exception_tracker_show', exception_id=exc_id)
419 return {'exc_id': exc_id, 'exc_url': exc_url}
419 return {'exc_id': exc_id, 'exc_url': exc_url}
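For reference, a hedged sketch of calling this method over the JSON-RPC endpoint;
the server URL and auth token below are placeholders, not real values:

.. code-block:: python

    import json
    import urllib.request

    payload = {
        'id': 1,
        'auth_token': '<SECRET_AUTH_TOKEN>',  # super-admin token (placeholder)
        'method': 'store_exception',
        'args': {
            'exc_data_json': json.dumps({
                'exc_traceback': 'Value `1` is not allowed',
                'exc_type_name': 'ValueError'}),
            'prefix': 'rhodecode',
        },
    }
    req = urllib.request.Request(
        'https://rhodecode.local/_admin/api',  # placeholder server URL
        data=json.dumps(payload).encode('utf-8'),
        headers={'Content-Type': 'application/json'})
    print(urllib.request.urlopen(req).read())
    # -> b'{"id": 1, "result": {"exc_id": ..., "exc_url": ...}, "error": null}'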
@@ -1,479 +1,479 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2
2
3 # Copyright (C) 2016-2020 RhodeCode GmbH
3 # Copyright (C) 2016-2020 RhodeCode GmbH
4 #
4 #
5 # This program is free software: you can redistribute it and/or modify
5 # This program is free software: you can redistribute it and/or modify
6 # it under the terms of the GNU Affero General Public License, version 3
6 # it under the terms of the GNU Affero General Public License, version 3
7 # (only), as published by the Free Software Foundation.
7 # (only), as published by the Free Software Foundation.
8 #
8 #
9 # This program is distributed in the hope that it will be useful,
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 # GNU General Public License for more details.
12 # GNU General Public License for more details.
13 #
13 #
14 # You should have received a copy of the GNU Affero General Public License
14 # You should have received a copy of the GNU Affero General Public License
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 #
16 #
17 # This program is dual-licensed. If you wish to learn more about the
17 # This program is dual-licensed. If you wish to learn more about the
18 # RhodeCode Enterprise Edition, including its added features, Support services,
18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20
20
21 import re
21 import re
22 import logging
22 import logging
23 import formencode
23 import formencode
24 import formencode.htmlfill
24 import formencode.htmlfill
25 import datetime
25 import datetime
26 from pyramid.interfaces import IRoutesMapper
26 from pyramid.interfaces import IRoutesMapper
27
27
28 from pyramid.httpexceptions import HTTPFound
28 from pyramid.httpexceptions import HTTPFound
29 from pyramid.renderers import render
29 from pyramid.renderers import render
30 from pyramid.response import Response
30 from pyramid.response import Response
31
31
32 from rhodecode.apps._base import BaseAppView, DataGridAppView
32 from rhodecode.apps._base import BaseAppView, DataGridAppView
33 from rhodecode.apps.ssh_support import SshKeyFileChangeEvent
33 from rhodecode.apps.ssh_support import SshKeyFileChangeEvent
34 from rhodecode import events
34 from rhodecode import events
35
35
36 from rhodecode.lib import helpers as h
36 from rhodecode.lib import helpers as h
37 from rhodecode.lib.auth import (
37 from rhodecode.lib.auth import (
38 LoginRequired, HasPermissionAllDecorator, CSRFRequired)
38 LoginRequired, HasPermissionAllDecorator, CSRFRequired)
39 from rhodecode.lib.utils2 import aslist, safe_unicode
39 from rhodecode.lib.utils2 import aslist, safe_unicode
40 from rhodecode.model.db import (
40 from rhodecode.model.db import (
41 or_, coalesce, User, UserIpMap, UserSshKeys)
41 or_, coalesce, User, UserIpMap, UserSshKeys)
42 from rhodecode.model.forms import (
42 from rhodecode.model.forms import (
43 ApplicationPermissionsForm, ObjectPermissionsForm, UserPermissionsForm)
43 ApplicationPermissionsForm, ObjectPermissionsForm, UserPermissionsForm)
44 from rhodecode.model.meta import Session
44 from rhodecode.model.meta import Session
45 from rhodecode.model.permission import PermissionModel
45 from rhodecode.model.permission import PermissionModel
46 from rhodecode.model.settings import SettingsModel
46 from rhodecode.model.settings import SettingsModel
47
47
48
48
49 log = logging.getLogger(__name__)
49 log = logging.getLogger(__name__)
50
50
51
51
52 class AdminPermissionsView(BaseAppView, DataGridAppView):
52 class AdminPermissionsView(BaseAppView, DataGridAppView):
53 def load_default_context(self):
53 def load_default_context(self):
54 c = self._get_local_tmpl_context()
54 c = self._get_local_tmpl_context()
55 PermissionModel().set_global_permission_choices(
55 PermissionModel().set_global_permission_choices(
56 c, gettext_translator=self.request.translate)
56 c, gettext_translator=self.request.translate)
57 return c
57 return c
58
58
59 @LoginRequired()
59 @LoginRequired()
60 @HasPermissionAllDecorator('hg.admin')
60 @HasPermissionAllDecorator('hg.admin')
61 def permissions_application(self):
61 def permissions_application(self):
62 c = self.load_default_context()
62 c = self.load_default_context()
63 c.active = 'application'
63 c.active = 'application'
64
64
65 c.user = User.get_default_user(refresh=True)
65 c.user = User.get_default_user(refresh=True)
66
66
67 app_settings = c.rc_config
67 app_settings = c.rc_config
68
68
69 defaults = {
69 defaults = {
70 'anonymous': c.user.active,
70 'anonymous': c.user.active,
71 'default_register_message': app_settings.get(
71 'default_register_message': app_settings.get(
72 'rhodecode_register_message')
72 'rhodecode_register_message')
73 }
73 }
74 defaults.update(c.user.get_default_perms())
74 defaults.update(c.user.get_default_perms())
75
75
76 data = render('rhodecode:templates/admin/permissions/permissions.mako',
76 data = render('rhodecode:templates/admin/permissions/permissions.mako',
77 self._get_template_context(c), self.request)
77 self._get_template_context(c), self.request)
78 html = formencode.htmlfill.render(
78 html = formencode.htmlfill.render(
79 data,
79 data,
80 defaults=defaults,
80 defaults=defaults,
81 encoding="UTF-8",
81 encoding="UTF-8",
82 force_defaults=False
82 force_defaults=False
83 )
83 )
84 return Response(html)
84 return Response(html)
85
85
86 @LoginRequired()
86 @LoginRequired()
87 @HasPermissionAllDecorator('hg.admin')
87 @HasPermissionAllDecorator('hg.admin')
88 @CSRFRequired()
88 @CSRFRequired()
89 def permissions_application_update(self):
89 def permissions_application_update(self):
90 _ = self.request.translate
90 _ = self.request.translate
91 c = self.load_default_context()
91 c = self.load_default_context()
92 c.active = 'application'
92 c.active = 'application'
93
93
94 _form = ApplicationPermissionsForm(
94 _form = ApplicationPermissionsForm(
95 self.request.translate,
95 self.request.translate,
96 [x[0] for x in c.register_choices],
96 [x[0] for x in c.register_choices],
97 [x[0] for x in c.password_reset_choices],
97 [x[0] for x in c.password_reset_choices],
98 [x[0] for x in c.extern_activate_choices])()
98 [x[0] for x in c.extern_activate_choices])()
99
99
100 try:
100 try:
101 form_result = _form.to_python(dict(self.request.POST))
101 form_result = _form.to_python(dict(self.request.POST))
102 form_result.update({'perm_user_name': User.DEFAULT_USER})
102 form_result.update({'perm_user_name': User.DEFAULT_USER})
103 PermissionModel().update_application_permissions(form_result)
103 PermissionModel().update_application_permissions(form_result)
104
104
105 settings = [
105 settings = [
106 ('register_message', 'default_register_message'),
106 ('register_message', 'default_register_message'),
107 ]
107 ]
108 for setting, form_key in settings:
108 for setting, form_key in settings:
109 sett = SettingsModel().create_or_update_setting(
109 sett = SettingsModel().create_or_update_setting(
110 setting, form_result[form_key])
110 setting, form_result[form_key])
111 Session().add(sett)
111 Session().add(sett)
112
112
113 Session().commit()
113 Session().commit()
114 h.flash(_('Application permissions updated successfully'),
114 h.flash(_('Application permissions updated successfully'),
115 category='success')
115 category='success')
116
116
117 except formencode.Invalid as errors:
117 except formencode.Invalid as errors:
118 defaults = errors.value
118 defaults = errors.value
119
119
120 data = render(
120 data = render(
121 'rhodecode:templates/admin/permissions/permissions.mako',
121 'rhodecode:templates/admin/permissions/permissions.mako',
122 self._get_template_context(c), self.request)
122 self._get_template_context(c), self.request)
123 html = formencode.htmlfill.render(
123 html = formencode.htmlfill.render(
124 data,
124 data,
125 defaults=defaults,
125 defaults=defaults,
126 errors=errors.error_dict or {},
126 errors=errors.error_dict or {},
127 prefix_error=False,
127 prefix_error=False,
128 encoding="UTF-8",
128 encoding="UTF-8",
129 force_defaults=False
129 force_defaults=False
130 )
130 )
131 return Response(html)
131 return Response(html)
132
132
133 except Exception:
133 except Exception:
134 log.exception("Exception during update of permissions")
134 log.exception("Exception during update of permissions")
135 h.flash(_('Error occurred during update of permissions'),
135 h.flash(_('Error occurred during update of permissions'),
136 category='error')
136 category='error')
137
137
138 affected_user_ids = [User.get_default_user_id()]
138 affected_user_ids = [User.get_default_user_id()]
139 PermissionModel().trigger_permission_flush(affected_user_ids)
139 PermissionModel().trigger_permission_flush(affected_user_ids)
140
140
141 raise HTTPFound(h.route_path('admin_permissions_application'))
141 raise HTTPFound(h.route_path('admin_permissions_application'))
142
142
143 @LoginRequired()
143 @LoginRequired()
144 @HasPermissionAllDecorator('hg.admin')
144 @HasPermissionAllDecorator('hg.admin')
145 def permissions_objects(self):
145 def permissions_objects(self):
146 c = self.load_default_context()
146 c = self.load_default_context()
147 c.active = 'objects'
147 c.active = 'objects'
148
148
149 c.user = User.get_default_user(refresh=True)
149 c.user = User.get_default_user(refresh=True)
150 defaults = {}
150 defaults = {}
151 defaults.update(c.user.get_default_perms())
151 defaults.update(c.user.get_default_perms())
152
152
153 data = render(
153 data = render(
154 'rhodecode:templates/admin/permissions/permissions.mako',
154 'rhodecode:templates/admin/permissions/permissions.mako',
155 self._get_template_context(c), self.request)
155 self._get_template_context(c), self.request)
156 html = formencode.htmlfill.render(
156 html = formencode.htmlfill.render(
157 data,
157 data,
158 defaults=defaults,
158 defaults=defaults,
159 encoding="UTF-8",
159 encoding="UTF-8",
160 force_defaults=False
160 force_defaults=False
161 )
161 )
162 return Response(html)
162 return Response(html)
163
163
164 @LoginRequired()
164 @LoginRequired()
165 @HasPermissionAllDecorator('hg.admin')
165 @HasPermissionAllDecorator('hg.admin')
166 @CSRFRequired()
166 @CSRFRequired()
167 def permissions_objects_update(self):
167 def permissions_objects_update(self):
168 _ = self.request.translate
168 _ = self.request.translate
169 c = self.load_default_context()
169 c = self.load_default_context()
170 c.active = 'objects'
170 c.active = 'objects'
171
171
172 _form = ObjectPermissionsForm(
172 _form = ObjectPermissionsForm(
173 self.request.translate,
173 self.request.translate,
174 [x[0] for x in c.repo_perms_choices],
174 [x[0] for x in c.repo_perms_choices],
175 [x[0] for x in c.group_perms_choices],
175 [x[0] for x in c.group_perms_choices],
176 [x[0] for x in c.user_group_perms_choices],
176 [x[0] for x in c.user_group_perms_choices],
177 )()
177 )()
178
178
179 try:
179 try:
180 form_result = _form.to_python(dict(self.request.POST))
180 form_result = _form.to_python(dict(self.request.POST))
181 form_result.update({'perm_user_name': User.DEFAULT_USER})
181 form_result.update({'perm_user_name': User.DEFAULT_USER})
182 PermissionModel().update_object_permissions(form_result)
182 PermissionModel().update_object_permissions(form_result)
183
183
184 Session().commit()
184 Session().commit()
185 h.flash(_('Object permissions updated successfully'),
185 h.flash(_('Object permissions updated successfully'),
186 category='success')
186 category='success')
187
187
188 except formencode.Invalid as errors:
188 except formencode.Invalid as errors:
189 defaults = errors.value
189 defaults = errors.value
190
190
191 data = render(
191 data = render(
192 'rhodecode:templates/admin/permissions/permissions.mako',
192 'rhodecode:templates/admin/permissions/permissions.mako',
193 self._get_template_context(c), self.request)
193 self._get_template_context(c), self.request)
194 html = formencode.htmlfill.render(
194 html = formencode.htmlfill.render(
195 data,
195 data,
196 defaults=defaults,
196 defaults=defaults,
197 errors=errors.error_dict or {},
197 errors=errors.error_dict or {},
198 prefix_error=False,
198 prefix_error=False,
199 encoding="UTF-8",
199 encoding="UTF-8",
200 force_defaults=False
200 force_defaults=False
201 )
201 )
202 return Response(html)
202 return Response(html)
203 except Exception:
203 except Exception:
204 log.exception("Exception during update of permissions")
204 log.exception("Exception during update of permissions")
205 h.flash(_('Error occurred during update of permissions'),
205 h.flash(_('Error occurred during update of permissions'),
206 category='error')
206 category='error')
207
207
208 affected_user_ids = [User.get_default_user_id()]
208 affected_user_ids = [User.get_default_user_id()]
209 PermissionModel().trigger_permission_flush(affected_user_ids)
209 PermissionModel().trigger_permission_flush(affected_user_ids)
210
210
211 raise HTTPFound(h.route_path('admin_permissions_object'))
211 raise HTTPFound(h.route_path('admin_permissions_object'))
212
212
213 @LoginRequired()
213 @LoginRequired()
214 @HasPermissionAllDecorator('hg.admin')
214 @HasPermissionAllDecorator('hg.admin')
215 def permissions_branch(self):
215 def permissions_branch(self):
216 c = self.load_default_context()
216 c = self.load_default_context()
217 c.active = 'branch'
217 c.active = 'branch'
218
218
219 c.user = User.get_default_user(refresh=True)
219 c.user = User.get_default_user(refresh=True)
220 defaults = {}
220 defaults = {}
221 defaults.update(c.user.get_default_perms())
221 defaults.update(c.user.get_default_perms())
222
222
223 data = render(
223 data = render(
224 'rhodecode:templates/admin/permissions/permissions.mako',
224 'rhodecode:templates/admin/permissions/permissions.mako',
225 self._get_template_context(c), self.request)
225 self._get_template_context(c), self.request)
226 html = formencode.htmlfill.render(
226 html = formencode.htmlfill.render(
227 data,
227 data,
228 defaults=defaults,
228 defaults=defaults,
229 encoding="UTF-8",
229 encoding="UTF-8",
230 force_defaults=False
230 force_defaults=False
231 )
231 )
232 return Response(html)
232 return Response(html)
233
233
234 @LoginRequired()
234 @LoginRequired()
235 @HasPermissionAllDecorator('hg.admin')
235 @HasPermissionAllDecorator('hg.admin')
236 def permissions_global(self):
236 def permissions_global(self):
237 c = self.load_default_context()
237 c = self.load_default_context()
238 c.active = 'global'
238 c.active = 'global'
239
239
240 c.user = User.get_default_user(refresh=True)
240 c.user = User.get_default_user(refresh=True)
241 defaults = {}
241 defaults = {}
242 defaults.update(c.user.get_default_perms())
242 defaults.update(c.user.get_default_perms())
243
243
244 data = render(
244 data = render(
245 'rhodecode:templates/admin/permissions/permissions.mako',
245 'rhodecode:templates/admin/permissions/permissions.mako',
246 self._get_template_context(c), self.request)
246 self._get_template_context(c), self.request)
247 html = formencode.htmlfill.render(
247 html = formencode.htmlfill.render(
248 data,
248 data,
249 defaults=defaults,
249 defaults=defaults,
250 encoding="UTF-8",
250 encoding="UTF-8",
251 force_defaults=False
251 force_defaults=False
252 )
252 )
253 return Response(html)
253 return Response(html)
254
254
255 @LoginRequired()
255 @LoginRequired()
256 @HasPermissionAllDecorator('hg.admin')
256 @HasPermissionAllDecorator('hg.admin')
257 @CSRFRequired()
257 @CSRFRequired()
258 def permissions_global_update(self):
258 def permissions_global_update(self):
259 _ = self.request.translate
259 _ = self.request.translate
260 c = self.load_default_context()
260 c = self.load_default_context()
261 c.active = 'global'
261 c.active = 'global'
262
262
263 _form = UserPermissionsForm(
263 _form = UserPermissionsForm(
264 self.request.translate,
264 self.request.translate,
265 [x[0] for x in c.repo_create_choices],
265 [x[0] for x in c.repo_create_choices],
266 [x[0] for x in c.repo_create_on_write_choices],
266 [x[0] for x in c.repo_create_on_write_choices],
267 [x[0] for x in c.repo_group_create_choices],
267 [x[0] for x in c.repo_group_create_choices],
268 [x[0] for x in c.user_group_create_choices],
268 [x[0] for x in c.user_group_create_choices],
269 [x[0] for x in c.fork_choices],
269 [x[0] for x in c.fork_choices],
270 [x[0] for x in c.inherit_default_permission_choices])()
270 [x[0] for x in c.inherit_default_permission_choices])()
271
271
272 try:
272 try:
273 form_result = _form.to_python(dict(self.request.POST))
273 form_result = _form.to_python(dict(self.request.POST))
274 form_result.update({'perm_user_name': User.DEFAULT_USER})
274 form_result.update({'perm_user_name': User.DEFAULT_USER})
275 PermissionModel().update_user_permissions(form_result)
275 PermissionModel().update_user_permissions(form_result)
276
276
277 Session().commit()
277 Session().commit()
278 h.flash(_('Global permissions updated successfully'),
278 h.flash(_('Global permissions updated successfully'),
279 category='success')
279 category='success')
280
280
281 except formencode.Invalid as errors:
281 except formencode.Invalid as errors:
282 defaults = errors.value
282 defaults = errors.value
283
283
284 data = render(
284 data = render(
285 'rhodecode:templates/admin/permissions/permissions.mako',
285 'rhodecode:templates/admin/permissions/permissions.mako',
286 self._get_template_context(c), self.request)
286 self._get_template_context(c), self.request)
287 html = formencode.htmlfill.render(
287 html = formencode.htmlfill.render(
288 data,
288 data,
289 defaults=defaults,
289 defaults=defaults,
290 errors=errors.error_dict or {},
290 errors=errors.error_dict or {},
291 prefix_error=False,
291 prefix_error=False,
292 encoding="UTF-8",
292 encoding="UTF-8",
293 force_defaults=False
293 force_defaults=False
294 )
294 )
295 return Response(html)
295 return Response(html)
296 except Exception:
296 except Exception:
297 log.exception("Exception during update of permissions")
297 log.exception("Exception during update of permissions")
298 h.flash(_('Error occurred during update of permissions'),
298 h.flash(_('Error occurred during update of permissions'),
299 category='error')
299 category='error')
300
300
301 affected_user_ids = [User.get_default_user_id()]
301 affected_user_ids = [User.get_default_user_id()]
302 PermissionModel().trigger_permission_flush(affected_user_ids)
302 PermissionModel().trigger_permission_flush(affected_user_ids)
303
303
304 raise HTTPFound(h.route_path('admin_permissions_global'))
304 raise HTTPFound(h.route_path('admin_permissions_global'))
305
305
306 @LoginRequired()
306 @LoginRequired()
307 @HasPermissionAllDecorator('hg.admin')
307 @HasPermissionAllDecorator('hg.admin')
308 def permissions_ips(self):
308 def permissions_ips(self):
309 c = self.load_default_context()
309 c = self.load_default_context()
310 c.active = 'ips'
310 c.active = 'ips'
311
311
312 c.user = User.get_default_user(refresh=True)
312 c.user = User.get_default_user(refresh=True)
313 c.user_ip_map = (
313 c.user_ip_map = (
314 UserIpMap.query().filter(UserIpMap.user == c.user).all())
314 UserIpMap.query().filter(UserIpMap.user == c.user).all())
315
315
316 return self._get_template_context(c)
316 return self._get_template_context(c)
317
317
318 @LoginRequired()
318 @LoginRequired()
319 @HasPermissionAllDecorator('hg.admin')
319 @HasPermissionAllDecorator('hg.admin')
320 def permissions_overview(self):
320 def permissions_overview(self):
321 c = self.load_default_context()
321 c = self.load_default_context()
322 c.active = 'perms'
322 c.active = 'perms'
323
323
324 c.user = User.get_default_user(refresh=True)
324 c.user = User.get_default_user(refresh=True)
325 c.perm_user = c.user.AuthUser()
325 c.perm_user = c.user.AuthUser()
326 return self._get_template_context(c)
326 return self._get_template_context(c)
327
327
328 @LoginRequired()
328 @LoginRequired()
329 @HasPermissionAllDecorator('hg.admin')
329 @HasPermissionAllDecorator('hg.admin')
330 def auth_token_access(self):
330 def auth_token_access(self):
331 from rhodecode import CONFIG
331 from rhodecode import CONFIG
332
332
333 c = self.load_default_context()
333 c = self.load_default_context()
334 c.active = 'auth_token_access'
334 c.active = 'auth_token_access'
335
335
336 c.user = User.get_default_user(refresh=True)
336 c.user = User.get_default_user(refresh=True)
337 c.perm_user = c.user.AuthUser()
337 c.perm_user = c.user.AuthUser()
338
338
339 mapper = self.request.registry.queryUtility(IRoutesMapper)
339 mapper = self.request.registry.queryUtility(IRoutesMapper)
340 c.view_data = []
340 c.view_data = []
341
341
342 _argument_prog = re.compile('\{(.*?)\}|:\((.*)\)')
342 _argument_prog = re.compile(r'\{(.*?)\}|:\((.*)\)')
343 introspector = self.request.registry.introspector
343 introspector = self.request.registry.introspector
344
344
345 view_intr = {}
345 view_intr = {}
346 for view_data in introspector.get_category('views'):
346 for view_data in introspector.get_category('views'):
347 intr = view_data['introspectable']
347 intr = view_data['introspectable']
348
348
349 if 'route_name' in intr and intr['attr']:
349 if 'route_name' in intr and intr['attr']:
350 view_intr[intr['route_name']] = '{}:{}'.format(
350 view_intr[intr['route_name']] = '{}:{}'.format(
351 str(intr['derived_callable'].__name__), intr['attr']
351 str(intr['derived_callable'].__name__), intr['attr']
352 )
352 )
353
353
354 c.whitelist_key = 'api_access_controllers_whitelist'
354 c.whitelist_key = 'api_access_controllers_whitelist'
355 c.whitelist_file = CONFIG.get('__file__')
355 c.whitelist_file = CONFIG.get('__file__')
356 whitelist_views = aslist(
356 whitelist_views = aslist(
357 CONFIG.get(c.whitelist_key), sep=',')
357 CONFIG.get(c.whitelist_key), sep=',')
358
358
359 for route_info in mapper.get_routes():
359 for route_info in mapper.get_routes():
360 if not route_info.name.startswith('__'):
360 if not route_info.name.startswith('__'):
361 routepath = route_info.pattern
361 routepath = route_info.pattern
362
362
363 def replace(matchobj):
363 def replace(matchobj):
364 if matchobj.group(1):
364 if matchobj.group(1):
365 return "{%s}" % matchobj.group(1).split(':')[0]
365 return "{%s}" % matchobj.group(1).split(':')[0]
366 else:
366 else:
367 return "{%s}" % matchobj.group(2)
367 return "{%s}" % matchobj.group(2)
368
368
369 routepath = _argument_prog.sub(replace, routepath)
369 routepath = _argument_prog.sub(replace, routepath)
370
370
371 if not routepath.startswith('/'):
371 if not routepath.startswith('/'):
372 routepath = '/' + routepath
372 routepath = '/' + routepath
373
373
374 view_fqn = view_intr.get(route_info.name, 'NOT AVAILABLE')
374 view_fqn = view_intr.get(route_info.name, 'NOT AVAILABLE')
375 active = view_fqn in whitelist_views
375 active = view_fqn in whitelist_views
376 c.view_data.append((route_info.name, view_fqn, routepath, active))
376 c.view_data.append((route_info.name, view_fqn, routepath, active))
377
377
378 c.whitelist_views = whitelist_views
378 c.whitelist_views = whitelist_views
379 return self._get_template_context(c)
379 return self._get_template_context(c)
380
380
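The raw-string fix above silences py3's invalid-escape-sequence warning without
changing the pattern. As a small illustration of what the replace() callback does
to a route pattern (the example route is made up):

.. code-block:: python

    import re

    _argument_prog = re.compile(r'\{(.*?)\}|:\((.*)\)')

    def replace(matchobj):
        # Keep only the argument name, dropping any inline regex constraint.
        if matchobj.group(1):
            return "{%s}" % matchobj.group(1).split(':')[0]
        return "{%s}" % matchobj.group(2)

    print(_argument_prog.sub(replace, '/{repo_name:.*?}/raw/{commit_id}'))
    # -> /{repo_name}/raw/{commit_id}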
381 def ssh_enabled(self):
381 def ssh_enabled(self):
382 return self.request.registry.settings.get(
382 return self.request.registry.settings.get(
383 'ssh.generate_authorized_keyfile')
383 'ssh.generate_authorized_keyfile')
384
384
385 @LoginRequired()
385 @LoginRequired()
386 @HasPermissionAllDecorator('hg.admin')
386 @HasPermissionAllDecorator('hg.admin')
387 def ssh_keys(self):
387 def ssh_keys(self):
388 c = self.load_default_context()
388 c = self.load_default_context()
389 c.active = 'ssh_keys'
389 c.active = 'ssh_keys'
390 c.ssh_enabled = self.ssh_enabled()
390 c.ssh_enabled = self.ssh_enabled()
391 return self._get_template_context(c)
391 return self._get_template_context(c)
392
392
393 @LoginRequired()
393 @LoginRequired()
394 @HasPermissionAllDecorator('hg.admin')
394 @HasPermissionAllDecorator('hg.admin')
395 def ssh_keys_data(self):
395 def ssh_keys_data(self):
396 _ = self.request.translate
396 _ = self.request.translate
397 self.load_default_context()
397 self.load_default_context()
398 column_map = {
398 column_map = {
399 'fingerprint': 'ssh_key_fingerprint',
399 'fingerprint': 'ssh_key_fingerprint',
400 'username': User.username
400 'username': User.username
401 }
401 }
402 draw, start, limit = self._extract_chunk(self.request)
402 draw, start, limit = self._extract_chunk(self.request)
403 search_q, order_by, order_dir = self._extract_ordering(
403 search_q, order_by, order_dir = self._extract_ordering(
404 self.request, column_map=column_map)
404 self.request, column_map=column_map)
405
405
406 ssh_keys_data_total_count = UserSshKeys.query()\
406 ssh_keys_data_total_count = UserSshKeys.query()\
407 .count()
407 .count()
408
408
409 # json generate
409 # json generate
410 base_q = UserSshKeys.query().join(UserSshKeys.user)
410 base_q = UserSshKeys.query().join(UserSshKeys.user)
411
411
412 if search_q:
412 if search_q:
413 like_expression = u'%{}%'.format(safe_unicode(search_q))
413 like_expression = u'%{}%'.format(safe_unicode(search_q))
414 base_q = base_q.filter(or_(
414 base_q = base_q.filter(or_(
415 User.username.ilike(like_expression),
415 User.username.ilike(like_expression),
416 UserSshKeys.ssh_key_fingerprint.ilike(like_expression),
416 UserSshKeys.ssh_key_fingerprint.ilike(like_expression),
417 ))
417 ))
418
418
419 users_data_total_filtered_count = base_q.count()
419 users_data_total_filtered_count = base_q.count()
420
420
421 sort_col = self._get_order_col(order_by, UserSshKeys)
421 sort_col = self._get_order_col(order_by, UserSshKeys)
422 if sort_col:
422 if sort_col:
423 if order_dir == 'asc':
423 if order_dir == 'asc':
424 # handle null values properly to order by NULL last
424 # handle null values properly to order by NULL last
425 if order_by in ['created_on']:
425 if order_by in ['created_on']:
426 sort_col = coalesce(sort_col, datetime.date.max)
426 sort_col = coalesce(sort_col, datetime.date.max)
427 sort_col = sort_col.asc()
427 sort_col = sort_col.asc()
428 else:
428 else:
429 # handle null values properly to order by NULL last
429 # handle null values properly to order by NULL last
430 if order_by in ['created_on']:
430 if order_by in ['created_on']:
431 sort_col = coalesce(sort_col, datetime.date.min)
431 sort_col = coalesce(sort_col, datetime.date.min)
432 sort_col = sort_col.desc()
432 sort_col = sort_col.desc()
433
433
434 base_q = base_q.order_by(sort_col)
434 base_q = base_q.order_by(sort_col)
435 base_q = base_q.offset(start).limit(limit)
435 base_q = base_q.offset(start).limit(limit)
436
436
437 ssh_keys = base_q.all()
437 ssh_keys = base_q.all()
438
438
439 ssh_keys_data = []
439 ssh_keys_data = []
440 for ssh_key in ssh_keys:
440 for ssh_key in ssh_keys:
441 ssh_keys_data.append({
441 ssh_keys_data.append({
442 "username": h.gravatar_with_user(self.request, ssh_key.user.username),
442 "username": h.gravatar_with_user(self.request, ssh_key.user.username),
443 "fingerprint": ssh_key.ssh_key_fingerprint,
443 "fingerprint": ssh_key.ssh_key_fingerprint,
444 "description": ssh_key.description,
444 "description": ssh_key.description,
445 "created_on": h.format_date(ssh_key.created_on),
445 "created_on": h.format_date(ssh_key.created_on),
446 "accessed_on": h.format_date(ssh_key.accessed_on),
446 "accessed_on": h.format_date(ssh_key.accessed_on),
447 "action": h.link_to(
447 "action": h.link_to(
448 _('Edit'), h.route_path('edit_user_ssh_keys',
448 _('Edit'), h.route_path('edit_user_ssh_keys',
449 user_id=ssh_key.user.user_id))
449 user_id=ssh_key.user.user_id))
450 })
450 })
451
451
452 data = ({
452 data = ({
453 'draw': draw,
453 'draw': draw,
454 'data': ssh_keys_data,
454 'data': ssh_keys_data,
455 'recordsTotal': ssh_keys_data_total_count,
455 'recordsTotal': ssh_keys_data_total_count,
456 'recordsFiltered': users_data_total_filtered_count,
456 'recordsFiltered': users_data_total_filtered_count,
457 })
457 })
458
458
459 return data
459 return data
460
460
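The grid endpoint answers with the usual DataTables envelope; the shape is shown
below with illustrative values only:

.. code-block:: python

    {
        'draw': 1,             # echoed back from the client request
        'data': [
            {'username': '<gravatar markup>', 'fingerprint': 'SHA256:...',
             'description': '', 'created_on': '2020-01-01',
             'accessed_on': '2020-01-02', 'action': '<edit link markup>'},
        ],
        'recordsTotal': 10,    # all ssh keys
        'recordsFiltered': 1,  # rows left after the search filter
    }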
461 @LoginRequired()
461 @LoginRequired()
462 @HasPermissionAllDecorator('hg.admin')
462 @HasPermissionAllDecorator('hg.admin')
463 @CSRFRequired()
463 @CSRFRequired()
464 def ssh_keys_update(self):
464 def ssh_keys_update(self):
465 _ = self.request.translate
465 _ = self.request.translate
466 self.load_default_context()
466 self.load_default_context()
467
467
468 ssh_enabled = self.ssh_enabled()
468 ssh_enabled = self.ssh_enabled()
469 key_file = self.request.registry.settings.get(
469 key_file = self.request.registry.settings.get(
470 'ssh.authorized_keys_file_path')
470 'ssh.authorized_keys_file_path')
471 if ssh_enabled:
471 if ssh_enabled:
472 events.trigger(SshKeyFileChangeEvent(), self.request.registry)
472 events.trigger(SshKeyFileChangeEvent(), self.request.registry)
473 h.flash(_('Updated SSH keys file: {}').format(key_file),
473 h.flash(_('Updated SSH keys file: {}').format(key_file),
474 category='success')
474 category='success')
475 else:
475 else:
476 h.flash(_('SSH key support is disabled in .ini file'),
476 h.flash(_('SSH key support is disabled in .ini file'),
477 category='warning')
477 category='warning')
478
478
479 raise HTTPFound(h.route_path('admin_permissions_ssh_keys'))
479 raise HTTPFound(h.route_path('admin_permissions_ssh_keys'))
@@ -1,58 +1,57 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2
2
3 # Copyright (C) 2016-2020 RhodeCode GmbH
3 # Copyright (C) 2016-2020 RhodeCode GmbH
4 #
4 #
5 # This program is free software: you can redistribute it and/or modify
5 # This program is free software: you can redistribute it and/or modify
6 # it under the terms of the GNU Affero General Public License, version 3
6 # it under the terms of the GNU Affero General Public License, version 3
7 # (only), as published by the Free Software Foundation.
7 # (only), as published by the Free Software Foundation.
8 #
8 #
9 # This program is distributed in the hope that it will be useful,
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 # GNU General Public License for more details.
12 # GNU General Public License for more details.
13 #
13 #
14 # You should have received a copy of the GNU Affero General Public License
14 # You should have received a copy of the GNU Affero General Public License
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 #
16 #
17 # This program is dual-licensed. If you wish to learn more about the
17 # This program is dual-licensed. If you wish to learn more about the
18 # RhodeCode Enterprise Edition, including its added features, Support services,
18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20
20
21
21 import io
22 import uuid
22 import uuid
23 from io import StringIO
24 import pathlib2
23 import pathlib2
25
24
26
25
27 def get_file_storage(settings):
26 def get_file_storage(settings):
28 from rhodecode.apps.file_store.backends.local_store import LocalFileStorage
27 from rhodecode.apps.file_store.backends.local_store import LocalFileStorage
29 from rhodecode.apps.file_store import config_keys
28 from rhodecode.apps.file_store import config_keys
30 store_path = settings.get(config_keys.store_path)
29 store_path = settings.get(config_keys.store_path)
31 return LocalFileStorage(base_path=store_path)
30 return LocalFileStorage(base_path=store_path)
32
31
33
32
34 def splitext(filename):
33 def splitext(filename):
35 ext = ''.join(pathlib2.Path(filename).suffixes)
34 ext = ''.join(pathlib2.Path(filename).suffixes)
36 return filename, ext
35 return filename, ext
37
36
38
37
39 def uid_filename(filename, randomized=True):
38 def uid_filename(filename, randomized=True):
40 """
39 """
41 Generates a randomized or stable (uuid) filename,
40 Generates a randomized or stable (uuid) filename,
42 preserving the original extension.
41 preserving the original extension.
43
42
44 :param filename: the original filename
43 :param filename: the original filename
45 :param randomized: define if filename should be stable (sha1 based) or randomized
44 :param randomized: define if filename should be stable (sha1 based) or randomized
46 """
45 """
47
46
48 _, ext = splitext(filename)
47 _, ext = splitext(filename)
49 if randomized:
48 if randomized:
50 uid = uuid.uuid4()
49 uid = uuid.uuid4()
51 else:
50 else:
52 hash_key = '{}.{}'.format(filename, 'store')
51 hash_key = '{}.{}'.format(filename, 'store')
53 uid = uuid.uuid5(uuid.NAMESPACE_URL, hash_key)
52 uid = uuid.uuid5(uuid.NAMESPACE_URL, hash_key)
54 return str(uid) + ext.lower()
53 return str(uid) + ext.lower()
55
54
56
55
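A short illustration of the two naming modes (the uuid shown is made up):

.. code-block:: python

    uid_filename('report.tar.gz')
    # -> 'e4d2b7a0-9c1f-4b2e-8a7d-0f3c5d6e7a8b.tar.gz'  (fresh uuid4 every call)

    uid_filename('report.tar.gz', randomized=False)
    # -> the same uuid5-based name on every call for the same input filename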
57 def bytes_to_file_obj(bytes_data):
56 def bytes_to_file_obj(bytes_data):
58 return StringIO.StringIO(bytes_data)
57 return io.BytesIO(bytes_data)
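A note on the line just above: in py3, io.StringIO accepts only str, so wrapping
raw bytes requires io.BytesIO (a minimal check of both behaviours):

.. code-block:: python

    import io

    io.BytesIO(b'abc').read()   # -> b'abc'
    # io.StringIO(b'abc')       # TypeError: initial_value must be str or None, not bytes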
@@ -1,580 +1,580 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2
2
3 # Copyright (C) 2010-2020 RhodeCode GmbH
3 # Copyright (C) 2010-2020 RhodeCode GmbH
4 #
4 #
5 # This program is free software: you can redistribute it and/or modify
5 # This program is free software: you can redistribute it and/or modify
6 # it under the terms of the GNU Affero General Public License, version 3
6 # it under the terms of the GNU Affero General Public License, version 3
7 # (only), as published by the Free Software Foundation.
7 # (only), as published by the Free Software Foundation.
8 #
8 #
9 # This program is distributed in the hope that it will be useful,
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 # GNU General Public License for more details.
12 # GNU General Public License for more details.
13 #
13 #
14 # You should have received a copy of the GNU Affero General Public License
14 # You should have received a copy of the GNU Affero General Public License
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 #
16 #
17 # This program is dual-licensed. If you wish to learn more about the
17 # This program is dual-licensed. If you wish to learn more about the
18 # RhodeCode Enterprise Edition, including its added features, Support services,
18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20
20
21 import urllib.parse
21 import urllib.parse
22
22
23 import mock
23 import mock
24 import pytest
24 import pytest
25
25
26 from rhodecode.tests import (
26 from rhodecode.tests import (
27 assert_session_flash, HG_REPO, TEST_USER_ADMIN_LOGIN,
27 assert_session_flash, HG_REPO, TEST_USER_ADMIN_LOGIN,
28 no_newline_id_generator)
28 no_newline_id_generator)
29 from rhodecode.tests.fixture import Fixture
29 from rhodecode.tests.fixture import Fixture
30 from rhodecode.lib.auth import check_password
30 from rhodecode.lib.auth import check_password
31 from rhodecode.lib import helpers as h
31 from rhodecode.lib import helpers as h
32 from rhodecode.model.auth_token import AuthTokenModel
32 from rhodecode.model.auth_token import AuthTokenModel
33 from rhodecode.model.db import User, Notification, UserApiKeys
33 from rhodecode.model.db import User, Notification, UserApiKeys
34 from rhodecode.model.meta import Session
34 from rhodecode.model.meta import Session
35
35
36 fixture = Fixture()
36 fixture = Fixture()
37
37
38 whitelist_view = ['RepoCommitsView:repo_commit_raw']
38 whitelist_view = ['RepoCommitsView:repo_commit_raw']
39
39
40
40
41 def route_path(name, params=None, **kwargs):
41 def route_path(name, params=None, **kwargs):
42 import urllib.request, urllib.parse, urllib.error
42 import urllib.request, urllib.parse, urllib.error
43 from rhodecode.apps._base import ADMIN_PREFIX
43 from rhodecode.apps._base import ADMIN_PREFIX
44
44
45 base_url = {
45 base_url = {
46 'login': ADMIN_PREFIX + '/login',
46 'login': ADMIN_PREFIX + '/login',
47 'logout': ADMIN_PREFIX + '/logout',
47 'logout': ADMIN_PREFIX + '/logout',
48 'register': ADMIN_PREFIX + '/register',
48 'register': ADMIN_PREFIX + '/register',
49 'reset_password':
49 'reset_password':
50 ADMIN_PREFIX + '/password_reset',
50 ADMIN_PREFIX + '/password_reset',
51 'reset_password_confirmation':
51 'reset_password_confirmation':
52 ADMIN_PREFIX + '/password_reset_confirmation',
52 ADMIN_PREFIX + '/password_reset_confirmation',
53
53
54 'admin_permissions_application':
54 'admin_permissions_application':
55 ADMIN_PREFIX + '/permissions/application',
55 ADMIN_PREFIX + '/permissions/application',
56 'admin_permissions_application_update':
56 'admin_permissions_application_update':
57 ADMIN_PREFIX + '/permissions/application/update',
57 ADMIN_PREFIX + '/permissions/application/update',
58
58
59 'repo_commit_raw': '/{repo_name}/raw-changeset/{commit_id}'
59 'repo_commit_raw': '/{repo_name}/raw-changeset/{commit_id}'
60
60
61 }[name].format(**kwargs)
61 }[name].format(**kwargs)
62
62
63 if params:
63 if params:
64 base_url = '{}?{}'.format(base_url, urllib.parse.urlencode(params))
64 base_url = '{}?{}'.format(base_url, urllib.parse.urlencode(params))
65 return base_url
65 return base_url
66
66
67
67
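The helper is used throughout the tests below, e.g.:

.. code-block:: python

    route_path('login', params={'came_from': '/_admin/users'})
    # -> '/_admin/login?came_from=%2F_admin%2Fusers'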
68 @pytest.mark.usefixtures('app')
68 @pytest.mark.usefixtures('app')
69 class TestLoginController(object):
69 class TestLoginController(object):
70 destroy_users = set()
70 destroy_users = set()
71
71
72 @classmethod
72 @classmethod
73 def teardown_class(cls):
73 def teardown_class(cls):
74 fixture.destroy_users(cls.destroy_users)
74 fixture.destroy_users(cls.destroy_users)
75
75
76 def teardown_method(self, method):
76 def teardown_method(self, method):
77 for n in Notification.query().all():
77 for n in Notification.query().all():
78 Session().delete(n)
78 Session().delete(n)
79
79
80 Session().commit()
80 Session().commit()
81 assert Notification.query().all() == []
81 assert Notification.query().all() == []
82
82
83 def test_index(self):
83 def test_index(self):
84 response = self.app.get(route_path('login'))
84 response = self.app.get(route_path('login'))
85 assert response.status == '200 OK'
85 assert response.status == '200 OK'
86 # Test response...
86 # Test response...
87
87
88 def test_login_admin_ok(self):
88 def test_login_admin_ok(self):
89 response = self.app.post(route_path('login'),
89 response = self.app.post(route_path('login'),
90 {'username': 'test_admin',
90 {'username': 'test_admin',
91 'password': 'test12'}, status=302)
91 'password': 'test12'}, status=302)
92 response = response.follow()
92 response = response.follow()
93 session = response.get_session_from_response()
93 session = response.get_session_from_response()
94 username = session['rhodecode_user'].get('username')
94 username = session['rhodecode_user'].get('username')
95 assert username == 'test_admin'
95 assert username == 'test_admin'
96 response.mustcontain('logout')
96 response.mustcontain('logout')
97
97
98 def test_login_regular_ok(self):
98 def test_login_regular_ok(self):
99 response = self.app.post(route_path('login'),
99 response = self.app.post(route_path('login'),
100 {'username': 'test_regular',
100 {'username': 'test_regular',
101 'password': 'test12'}, status=302)
101 'password': 'test12'}, status=302)
102
102
103 response = response.follow()
103 response = response.follow()
104 session = response.get_session_from_response()
104 session = response.get_session_from_response()
105 username = session['rhodecode_user'].get('username')
105 username = session['rhodecode_user'].get('username')
106 assert username == 'test_regular'
106 assert username == 'test_regular'
107 response.mustcontain('logout')
107 response.mustcontain('logout')
108
108
109 def test_login_regular_forbidden_when_super_admin_restriction(self):
109 def test_login_regular_forbidden_when_super_admin_restriction(self):
110 from rhodecode.authentication.plugins.auth_rhodecode import RhodeCodeAuthPlugin
110 from rhodecode.authentication.plugins.auth_rhodecode import RhodeCodeAuthPlugin
111 with fixture.auth_restriction(self.app._pyramid_registry,
111 with fixture.auth_restriction(self.app._pyramid_registry,
112 RhodeCodeAuthPlugin.AUTH_RESTRICTION_SUPER_ADMIN):
112 RhodeCodeAuthPlugin.AUTH_RESTRICTION_SUPER_ADMIN):
113 response = self.app.post(route_path('login'),
113 response = self.app.post(route_path('login'),
114 {'username': 'test_regular',
114 {'username': 'test_regular',
115 'password': 'test12'})
115 'password': 'test12'})
116
116
117 response.mustcontain('invalid user name')
117 response.mustcontain('invalid user name')
118 response.mustcontain('invalid password')
118 response.mustcontain('invalid password')
119
119
120 def test_login_regular_forbidden_when_scope_restriction(self):
120 def test_login_regular_forbidden_when_scope_restriction(self):
121 from rhodecode.authentication.plugins.auth_rhodecode import RhodeCodeAuthPlugin
121 from rhodecode.authentication.plugins.auth_rhodecode import RhodeCodeAuthPlugin
122 with fixture.scope_restriction(self.app._pyramid_registry,
122 with fixture.scope_restriction(self.app._pyramid_registry,
123 RhodeCodeAuthPlugin.AUTH_RESTRICTION_SCOPE_VCS):
123 RhodeCodeAuthPlugin.AUTH_RESTRICTION_SCOPE_VCS):
124 response = self.app.post(route_path('login'),
124 response = self.app.post(route_path('login'),
125 {'username': 'test_regular',
125 {'username': 'test_regular',
126 'password': 'test12'})
126 'password': 'test12'})
127
127
128 response.mustcontain('invalid user name')
128 response.mustcontain('invalid user name')
129 response.mustcontain('invalid password')
129 response.mustcontain('invalid password')
130
130
131 def test_login_ok_came_from(self):
131 def test_login_ok_came_from(self):
132 test_came_from = '/_admin/users?branch=stable'
132 test_came_from = '/_admin/users?branch=stable'
133 _url = '{}?came_from={}'.format(route_path('login'), test_came_from)
133 _url = '{}?came_from={}'.format(route_path('login'), test_came_from)
134 response = self.app.post(
134 response = self.app.post(
135 _url, {'username': 'test_admin', 'password': 'test12'}, status=302)
135 _url, {'username': 'test_admin', 'password': 'test12'}, status=302)
136
136
137 assert 'branch=stable' in response.location
137 assert 'branch=stable' in response.location
138 response = response.follow()
138 response = response.follow()
139
139
140 assert response.status == '200 OK'
140 assert response.status == '200 OK'
141 response.mustcontain('Users administration')
141 response.mustcontain('Users administration')
142
142
143 def test_redirect_to_login_with_get_args(self):
143 def test_redirect_to_login_with_get_args(self):
144 with fixture.anon_access(False):
144 with fixture.anon_access(False):
145 kwargs = {'branch': 'stable'}
145 kwargs = {'branch': 'stable'}
146 response = self.app.get(
146 response = self.app.get(
147 h.route_path('repo_summary', repo_name=HG_REPO, _query=kwargs),
147 h.route_path('repo_summary', repo_name=HG_REPO, _query=kwargs),
148 status=302)
148 status=302)
149
149
150 response_query = urllib.parse.urlparse.parse_qsl(response.location)
150 response_query = urllib.parse.parse_qsl(response.location)
151 assert 'branch=stable' in response_query[0][1]
151 assert 'branch=stable' in response_query[0][1]
152
152
153 def test_login_form_with_get_args(self):
153 def test_login_form_with_get_args(self):
154 _url = '{}?came_from=/_admin/users,branch=stable'.format(route_path('login'))
154 _url = '{}?came_from=/_admin/users,branch=stable'.format(route_path('login'))
155 response = self.app.get(_url)
155 response = self.app.get(_url)
156 assert 'branch%3Dstable' in response.form.action
156 assert 'branch%3Dstable' in response.form.action
157
157
158 @pytest.mark.parametrize("url_came_from", [
158 @pytest.mark.parametrize("url_came_from", [
159 'data:text/html,<script>window.alert("xss")</script>',
159 'data:text/html,<script>window.alert("xss")</script>',
160 'mailto:test@rhodecode.org',
160 'mailto:test@rhodecode.org',
161 'file:///etc/passwd',
161 'file:///etc/passwd',
162 'ftp://some.ftp.server',
162 'ftp://some.ftp.server',
163 'http://other.domain',
163 'http://other.domain',
164 '/\r\nX-Forwarded-Host: http://example.org',
164 '/\r\nX-Forwarded-Host: http://example.org',
165 ], ids=no_newline_id_generator)
165 ], ids=no_newline_id_generator)
166 def test_login_bad_came_froms(self, url_came_from):
166 def test_login_bad_came_froms(self, url_came_from):
167 _url = '{}?came_from={}'.format(route_path('login'), url_came_from)
167 _url = '{}?came_from={}'.format(route_path('login'), url_came_from)
168 response = self.app.post(
168 response = self.app.post(
169 _url,
169 _url,
170 {'username': 'test_admin', 'password': 'test12'})
170 {'username': 'test_admin', 'password': 'test12'})
171 assert response.status == '302 Found'
171 assert response.status == '302 Found'
172 response = response.follow()
172 response = response.follow()
173 assert response.status == '200 OK'
173 assert response.status == '200 OK'
174 assert response.request.path == '/'
174 assert response.request.path == '/'
175
175
176 def test_login_short_password(self):
176 def test_login_short_password(self):
177 response = self.app.post(route_path('login'),
177 response = self.app.post(route_path('login'),
178 {'username': 'test_admin',
178 {'username': 'test_admin',
179 'password': 'as'})
179 'password': 'as'})
180 assert response.status == '200 OK'
180 assert response.status == '200 OK'
181
181
182 response.mustcontain('Enter 3 characters or more')
182 response.mustcontain('Enter 3 characters or more')
183
183
184 def test_login_wrong_non_ascii_password(self, user_regular):
184 def test_login_wrong_non_ascii_password(self, user_regular):
185 response = self.app.post(
185 response = self.app.post(
186 route_path('login'),
186 route_path('login'),
187 {'username': user_regular.username,
187 {'username': user_regular.username,
188 'password': u'invalid-non-asci\xe4'.encode('utf8')})
188 'password': u'invalid-non-asci\xe4'.encode('utf8')})
189
189
190 response.mustcontain('invalid user name')
190 response.mustcontain('invalid user name')
191 response.mustcontain('invalid password')
191 response.mustcontain('invalid password')
192
192
193 def test_login_with_non_ascii_password(self, user_util):
193 def test_login_with_non_ascii_password(self, user_util):
194 password = u'valid-non-ascii\xe4'
194 password = u'valid-non-ascii\xe4'
195 user = user_util.create_user(password=password)
195 user = user_util.create_user(password=password)
196 response = self.app.post(
196 response = self.app.post(
197 route_path('login'),
197 route_path('login'),
198 {'username': user.username,
198 {'username': user.username,
199 'password': password})
199 'password': password})
200 assert response.status_code == 302
200 assert response.status_code == 302
201
201
202 def test_login_wrong_username_password(self):
202 def test_login_wrong_username_password(self):
203 response = self.app.post(route_path('login'),
203 response = self.app.post(route_path('login'),
204 {'username': 'error',
204 {'username': 'error',
205 'password': 'test12'})
205 'password': 'test12'})
206
206
207 response.mustcontain('invalid user name')
207 response.mustcontain('invalid user name')
208 response.mustcontain('invalid password')
208 response.mustcontain('invalid password')
209
209
210 def test_login_admin_ok_password_migration(self, real_crypto_backend):
210 def test_login_admin_ok_password_migration(self, real_crypto_backend):
211 from rhodecode.lib import auth
211 from rhodecode.lib import auth
212
212
213 # create new user, with sha256 password
213 # create new user, with sha256 password
214 temp_user = 'test_admin_sha256'
214 temp_user = 'test_admin_sha256'
215 user = fixture.create_user(temp_user)
215 user = fixture.create_user(temp_user)
216 user.password = auth._RhodeCodeCryptoSha256().hash_create(
216 user.password = auth._RhodeCodeCryptoSha256().hash_create(
217 b'test123')
217 b'test123')
218 Session().add(user)
218 Session().add(user)
219 Session().commit()
219 Session().commit()
220 self.destroy_users.add(temp_user)
220 self.destroy_users.add(temp_user)
221 response = self.app.post(route_path('login'),
221 response = self.app.post(route_path('login'),
222 {'username': temp_user,
222 {'username': temp_user,
223 'password': 'test123'}, status=302)
223 'password': 'test123'}, status=302)
224
224
225 response = response.follow()
225 response = response.follow()
226 session = response.get_session_from_response()
226 session = response.get_session_from_response()
227 username = session['rhodecode_user'].get('username')
227 username = session['rhodecode_user'].get('username')
228 assert username == temp_user
228 assert username == temp_user
229 response.mustcontain('logout')
229 response.mustcontain('logout')
230
230
231 # new password should be bcrypted, after log-in and transfer
231 # new password should be bcrypted, after log-in and transfer
        user = User.get_by_username(temp_user)
        assert user.password.startswith('$')

    # REGISTRATIONS
    def test_register(self):
        response = self.app.get(route_path('register'))
        response.mustcontain('Create an Account')

    def test_register_err_same_username(self):
        uname = 'test_admin'
        response = self.app.post(
            route_path('register'),
            {
                'username': uname,
                'password': 'test12',
                'password_confirmation': 'test12',
                'email': 'goodmail@domain.com',
                'firstname': 'test',
                'lastname': 'test'
            }
        )

        assertr = response.assert_response()
        msg = 'Username "%(username)s" already exists'
        msg = msg % {'username': uname}
        assertr.element_contains('#username+.error-message', msg)

    def test_register_err_same_email(self):
        response = self.app.post(
            route_path('register'),
            {
                'username': 'test_admin_0',
                'password': 'test12',
                'password_confirmation': 'test12',
                'email': 'test_admin@mail.com',
                'firstname': 'test',
                'lastname': 'test'
            }
        )

        assertr = response.assert_response()
        msg = 'This e-mail address is already taken'
        assertr.element_contains('#email+.error-message', msg)

    def test_register_err_same_email_case_sensitive(self):
        response = self.app.post(
            route_path('register'),
            {
                'username': 'test_admin_1',
                'password': 'test12',
                'password_confirmation': 'test12',
                'email': 'TesT_Admin@mail.COM',
                'firstname': 'test',
                'lastname': 'test'
            }
        )
        assertr = response.assert_response()
        msg = 'This e-mail address is already taken'
        assertr.element_contains('#email+.error-message', msg)

    def test_register_err_wrong_data(self):
        response = self.app.post(
            route_path('register'),
            {
                'username': 'xs',
                'password': 'test',
                'password_confirmation': 'test',
                'email': 'goodmailm',
                'firstname': 'test',
                'lastname': 'test'
            }
        )
        assert response.status == '200 OK'
        response.mustcontain('An email address must contain a single @')
        response.mustcontain('Enter a value 6 characters long or more')

    def test_register_err_username(self):
        response = self.app.post(
            route_path('register'),
            {
                'username': 'error user',
                'password': 'test12',
                'password_confirmation': 'test12',
                'email': 'goodmailm',
                'firstname': 'test',
                'lastname': 'test'
            }
        )

        response.mustcontain('An email address must contain a single @')
        response.mustcontain(
            'Username may only contain '
            'alphanumeric characters underscores, '
            'periods or dashes and must begin with '
            'alphanumeric character')

    def test_register_err_case_sensitive(self):
        usr = 'Test_Admin'
        response = self.app.post(
            route_path('register'),
            {
                'username': usr,
                'password': 'test12',
                'password_confirmation': 'test12',
                'email': 'goodmailm',
                'firstname': 'test',
                'lastname': 'test'
            }
        )

        assertr = response.assert_response()
        msg = 'Username "%(username)s" already exists'
        msg = msg % {'username': usr}
        assertr.element_contains('#username+.error-message', msg)

    def test_register_special_chars(self):
        response = self.app.post(
            route_path('register'),
            {
                'username': 'xxxaxn',
                'password': 'ąćźżąśśśś',
                'password_confirmation': 'ąćźżąśśśś',
                'email': 'goodmailm@test.plx',
                'firstname': 'test',
                'lastname': 'test'
            }
        )

        msg = 'Invalid characters (non-ascii) in password'
        response.mustcontain(msg)

    def test_register_password_mismatch(self):
        response = self.app.post(
            route_path('register'),
            {
                'username': 'xs',
                'password': '123qwe',
                'password_confirmation': 'qwe123',
                'email': 'goodmailm@test.plxa',
                'firstname': 'test',
                'lastname': 'test'
            }
        )
        msg = 'Passwords do not match'
        response.mustcontain(msg)

    def test_register_ok(self):
        username = 'test_regular4'
        password = 'qweqwe'
        email = 'marcin@test.com'
        name = 'testname'
        lastname = 'testlastname'

        # this initializes a session
        response = self.app.get(route_path('register'))
        response.mustcontain('Create an Account')

        response = self.app.post(
            route_path('register'),
            {
                'username': username,
                'password': password,
                'password_confirmation': password,
                'email': email,
                'firstname': name,
                'lastname': lastname,
                'admin': True
            },
            status=302
        )  # the posted 'admin' flag must be ignored server-side (asserted below)

        assert_session_flash(
            response, 'You have successfully registered with RhodeCode. You can log-in now.')

        ret = Session().query(User).filter(
            User.username == 'test_regular4').one()
        assert ret.username == username
        assert check_password(password, ret.password)
        assert ret.email == email
        assert ret.name == name
        assert ret.lastname == lastname
        assert ret.auth_tokens is not None
        assert not ret.admin

    def test_forgot_password_wrong_mail(self):
        bad_email = 'marcin@wrongmail.org'
        # this initializes a session
        self.app.get(route_path('reset_password'))

        response = self.app.post(
            route_path('reset_password'), {'email': bad_email, }
        )
        assert_session_flash(response,
                             'If such email exists, a password reset link was sent to it.')

    def test_forgot_password(self, user_util):
        # this initializes a session
        self.app.get(route_path('reset_password'))

        user = user_util.create_user()
        user_id = user.user_id
        email = user.email

        response = self.app.post(route_path('reset_password'), {'email': email, })

        assert_session_flash(response,
                             'If such email exists, a password reset link was sent to it.')

        # BAD KEY
        confirm_url = '{}?key={}'.format(route_path('reset_password_confirmation'), 'badkey')
        response = self.app.get(confirm_url, status=302)
        assert response.location.endswith(route_path('reset_password'))
        assert_session_flash(response, 'Given reset token is invalid')

        response.follow()  # cleanup flash

        # GOOD KEY
        key = UserApiKeys.query()\
            .filter(UserApiKeys.user_id == user_id)\
            .filter(UserApiKeys.role == UserApiKeys.ROLE_PASSWORD_RESET)\
            .first()

        assert key

        confirm_url = '{}?key={}'.format(route_path('reset_password_confirmation'), key.api_key)
        response = self.app.get(confirm_url)
        assert response.status == '302 Found'
        assert response.location.endswith(route_path('login'))

        assert_session_flash(
            response,
            'Your password reset was successful, '
            'a new password has been sent to your email')

        response.follow()

    def _get_api_whitelist(self, values=None):
        config = {'api_access_controllers_whitelist': values or []}
        return config
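
    # Whitelist entries name a view as 'ViewClass:view_name'; an entry can
    # additionally be bound to a single token by appending '@<token>', as
    # exercised in test_access_whitelisted_page_via_auth_token_bound_to_token.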

    @pytest.mark.parametrize("test_name, auth_token", [
        ('none', None),
        ('empty_string', ''),
        ('fake_number', '123456'),
        ('proper_auth_token', None)
    ])
    def test_access_not_whitelisted_page_via_auth_token(
            self, test_name, auth_token, user_admin):

        whitelist = self._get_api_whitelist([])
        with mock.patch.dict('rhodecode.CONFIG', whitelist):
            assert [] == whitelist['api_access_controllers_whitelist']
            if test_name == 'proper_auth_token':
                # the placeholder None from parametrize is swapped for the
                # real admin token, which is only known at runtime
                auth_token = user_admin.api_key

            with fixture.anon_access(False):
                self.app.get(
                    route_path('repo_commit_raw',
                               repo_name=HG_REPO, commit_id='tip',
                               params=dict(api_key=auth_token)),
                    status=302)

    @pytest.mark.parametrize("test_name, auth_token, code", [
        ('none', None, 302),
        ('empty_string', '', 302),
        ('fake_number', '123456', 302),
        ('proper_auth_token', None, 200)
    ])
    def test_access_whitelisted_page_via_auth_token(
            self, test_name, auth_token, code, user_admin):

        whitelist = self._get_api_whitelist(whitelist_view)

        with mock.patch.dict('rhodecode.CONFIG', whitelist):
            assert whitelist_view == whitelist['api_access_controllers_whitelist']

            if test_name == 'proper_auth_token':
                auth_token = user_admin.api_key
                assert auth_token

            with fixture.anon_access(False):
                self.app.get(
                    route_path('repo_commit_raw',
                               repo_name=HG_REPO, commit_id='tip',
                               params=dict(api_key=auth_token)),
                    status=code)

    @pytest.mark.parametrize("test_name, auth_token, code", [
        ('proper_auth_token', None, 200),
        ('wrong_auth_token', '123456', 302),
    ])
    def test_access_whitelisted_page_via_auth_token_bound_to_token(
            self, test_name, auth_token, code, user_admin):

        expected_token = auth_token
        if test_name == 'proper_auth_token':
            auth_token = user_admin.api_key
            expected_token = auth_token
            assert auth_token

        whitelist = self._get_api_whitelist([
            'RepoCommitsView:repo_commit_raw@{}'.format(expected_token)])

        with mock.patch.dict('rhodecode.CONFIG', whitelist):

            with fixture.anon_access(False):
                self.app.get(
                    route_path('repo_commit_raw',
                               repo_name=HG_REPO, commit_id='tip',
                               params=dict(api_key=auth_token)),
                    status=code)

    def test_access_page_via_extra_auth_token(self):
        whitelist = self._get_api_whitelist(whitelist_view)
        with mock.patch.dict('rhodecode.CONFIG', whitelist):
            assert whitelist_view == \
                whitelist['api_access_controllers_whitelist']

            new_auth_token = AuthTokenModel().create(
                TEST_USER_ADMIN_LOGIN, 'test')
            Session().commit()
            with fixture.anon_access(False):
                self.app.get(
                    route_path('repo_commit_raw',
                               repo_name=HG_REPO, commit_id='tip',
                               params=dict(api_key=new_auth_token.api_key)),
                    status=200)

    def test_access_page_via_expired_auth_token(self):
        whitelist = self._get_api_whitelist(whitelist_view)
        with mock.patch.dict('rhodecode.CONFIG', whitelist):
            assert whitelist_view == \
                whitelist['api_access_controllers_whitelist']

            new_auth_token = AuthTokenModel().create(
                TEST_USER_ADMIN_LOGIN, 'test')
            Session().commit()
            # patch the api key and make it expired
            new_auth_token.expires = 0
            Session().add(new_auth_token)
            Session().commit()
            with fixture.anon_access(False):
                self.app.get(
                    route_path('repo_commit_raw',
                               repo_name=HG_REPO, commit_id='tip',
                               params=dict(api_key=new_auth_token.api_key)),
                    status=302)
@@ -1,1877 +1,1877 b''
# -*- coding: utf-8 -*-

# Copyright (C) 2011-2020 RhodeCode GmbH
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License, version 3
# (only), as published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# This program is dual-licensed. If you wish to learn more about the
# RhodeCode Enterprise Edition, including its added features, Support services,
# and proprietary license terms, please see https://rhodecode.com/licenses/

import logging
import collections

import formencode
import formencode.htmlfill
import peppercorn
from pyramid.httpexceptions import (
    HTTPFound, HTTPNotFound, HTTPForbidden, HTTPBadRequest, HTTPConflict)

from pyramid.renderers import render

from rhodecode.apps._base import RepoAppView, DataGridAppView

from rhodecode.lib import helpers as h, diffs, codeblocks, channelstream
from rhodecode.lib.base import vcs_operation_context
from rhodecode.lib.diffs import load_cached_diff, cache_diff, diff_cache_exist
from rhodecode.lib.exceptions import CommentVersionMismatch
from rhodecode.lib.ext_json import json
from rhodecode.lib.auth import (
    LoginRequired, HasRepoPermissionAny, HasRepoPermissionAnyDecorator,
    NotAnonymous, CSRFRequired)
from rhodecode.lib.utils2 import str2bool, safe_str, safe_unicode, safe_int, aslist, retry
from rhodecode.lib.vcs.backends.base import (
    EmptyCommit, UpdateFailureReason, unicode_to_reference)
from rhodecode.lib.vcs.exceptions import (
    CommitDoesNotExistError, RepositoryRequirementError, EmptyRepositoryError)
from rhodecode.model.changeset_status import ChangesetStatusModel
from rhodecode.model.comment import CommentsModel
from rhodecode.model.db import (
    func, false, or_, PullRequest, ChangesetComment, ChangesetStatus, Repository,
    PullRequestReviewers)
from rhodecode.model.forms import PullRequestForm
from rhodecode.model.meta import Session
from rhodecode.model.pull_request import PullRequestModel, MergeCheck
from rhodecode.model.scm import ScmModel

log = logging.getLogger(__name__)


class RepoPullRequestsView(RepoAppView, DataGridAppView):

    def load_default_context(self):
        c = self._get_local_tmpl_context(include_app_defaults=True)
        c.REVIEW_STATUS_APPROVED = ChangesetStatus.STATUS_APPROVED
        c.REVIEW_STATUS_REJECTED = ChangesetStatus.STATUS_REJECTED
        # backward compat., we use a plain renderer for OLD PRs
        c.renderer = 'plain'
        return c

    def _get_pull_requests_list(
            self, repo_name, source, filter_type, opened_by, statuses):

        draw, start, limit = self._extract_chunk(self.request)
        search_q, order_by, order_dir = self._extract_ordering(self.request)
        _render = self.request.get_partial_renderer(
            'rhodecode:templates/data_table/_dt_elements.mako')

        # pagination

        if filter_type == 'awaiting_review':
            pull_requests = PullRequestModel().get_awaiting_review(
                repo_name,
                search_q=search_q, statuses=statuses,
                offset=start, length=limit, order_by=order_by, order_dir=order_dir)
            pull_requests_total_count = PullRequestModel().count_awaiting_review(
                repo_name,
                search_q=search_q, statuses=statuses)
        elif filter_type == 'awaiting_my_review':
            pull_requests = PullRequestModel().get_awaiting_my_review(
                repo_name, self._rhodecode_user.user_id,
                search_q=search_q, statuses=statuses,
                offset=start, length=limit, order_by=order_by, order_dir=order_dir)
            pull_requests_total_count = PullRequestModel().count_awaiting_my_review(
                repo_name, self._rhodecode_user.user_id,
                search_q=search_q, statuses=statuses)
        else:
            pull_requests = PullRequestModel().get_all(
                repo_name, search_q=search_q, source=source, opened_by=opened_by,
                statuses=statuses, offset=start, length=limit,
                order_by=order_by, order_dir=order_dir)
            pull_requests_total_count = PullRequestModel().count_all(
                repo_name, search_q=search_q, source=source, statuses=statuses,
                opened_by=opened_by)

        data = []
        comments_model = CommentsModel()
        for pr in pull_requests:
            comments_count = comments_model.get_all_comments(
                self.db_repo.repo_id, pull_request=pr,
                include_drafts=False, count_only=True)

            review_statuses = pr.reviewers_statuses(user=self._rhodecode_db_user)
            my_review_status = ChangesetStatus.STATUS_NOT_REVIEWED
            if review_statuses and review_statuses[4]:
                _review_obj, _user, _reasons, _mandatory, statuses = review_statuses
                my_review_status = statuses[0][1].status

            data.append({
                'name': _render('pullrequest_name',
                                pr.pull_request_id, pr.pull_request_state,
                                pr.work_in_progress, pr.target_repo.repo_name,
                                short=True),
                'name_raw': pr.pull_request_id,
                'status': _render('pullrequest_status',
                                  pr.calculated_review_status()),
                'my_status': _render('pullrequest_status',
                                     my_review_status),
                'title': _render('pullrequest_title', pr.title, pr.description),
                'description': h.escape(pr.description),
                'updated_on': _render('pullrequest_updated_on',
                                      h.datetime_to_time(pr.updated_on),
                                      pr.versions_count),
                'updated_on_raw': h.datetime_to_time(pr.updated_on),
                'created_on': _render('pullrequest_updated_on',
                                      h.datetime_to_time(pr.created_on)),
                'created_on_raw': h.datetime_to_time(pr.created_on),
                'state': pr.pull_request_state,
                'author': _render('pullrequest_author',
                                  pr.author.full_contact, ),
                'author_raw': pr.author.full_name,
                'comments': _render('pullrequest_comments', comments_count),
                'comments_raw': comments_count,
                'closed': pr.is_closed(),
            })

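        # the response mirrors the DataTables server-side protocol: 'draw'
        # echoes the request counter, while recordsTotal/recordsFiltered
        # drive the paginator (identical here, since filtering is already
        # applied in the queries above)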
        data = {
            'draw': draw,
            'data': data,
            'recordsTotal': pull_requests_total_count,
            'recordsFiltered': pull_requests_total_count,
        }
        return data

    @LoginRequired()
    @HasRepoPermissionAnyDecorator(
        'repository.read', 'repository.write', 'repository.admin')
    def pull_request_list(self):
        c = self.load_default_context()

        req_get = self.request.GET
        c.source = str2bool(req_get.get('source'))
        c.closed = str2bool(req_get.get('closed'))
        c.my = str2bool(req_get.get('my'))
        c.awaiting_review = str2bool(req_get.get('awaiting_review'))
        c.awaiting_my_review = str2bool(req_get.get('awaiting_my_review'))

        c.active = 'open'
        if c.my:
            c.active = 'my'
        if c.closed:
            c.active = 'closed'
        if c.awaiting_review and not c.source:
            c.active = 'awaiting'
        if c.source and not c.awaiting_review:
            c.active = 'source'
        if c.awaiting_my_review:
            c.active = 'awaiting_my'

        return self._get_template_context(c)

    @LoginRequired()
    @HasRepoPermissionAnyDecorator(
        'repository.read', 'repository.write', 'repository.admin')
    def pull_request_list_data(self):
        self.load_default_context()

        # additional filters
        req_get = self.request.GET
        source = str2bool(req_get.get('source'))
        closed = str2bool(req_get.get('closed'))
        my = str2bool(req_get.get('my'))
        awaiting_review = str2bool(req_get.get('awaiting_review'))
        awaiting_my_review = str2bool(req_get.get('awaiting_my_review'))

        filter_type = 'awaiting_review' if awaiting_review \
            else 'awaiting_my_review' if awaiting_my_review \
            else None
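        # note: awaiting_review takes precedence when both flags are set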

        opened_by = None
        if my:
            opened_by = [self._rhodecode_user.user_id]

        statuses = [PullRequest.STATUS_NEW, PullRequest.STATUS_OPEN]
        if closed:
            statuses = [PullRequest.STATUS_CLOSED]

        data = self._get_pull_requests_list(
            repo_name=self.db_repo_name, source=source,
            filter_type=filter_type, opened_by=opened_by, statuses=statuses)

        return data

    def _is_diff_cache_enabled(self, target_repo):
        caching_enabled = self._get_general_setting(
            target_repo, 'rhodecode_diff_cache')
        log.debug('Diff caching enabled: %s', caching_enabled)
        return caching_enabled

    def _get_diffset(self, source_repo_name, source_repo,
                     ancestor_commit,
                     source_ref_id, target_ref_id,
                     target_commit, source_commit, diff_limit, file_limit,
                     fulldiff, hide_whitespace_changes, diff_context, use_ancestor=True):

        target_commit_final = target_commit
        source_commit_final = source_commit

        if use_ancestor:
            # for version comparisons the caller may skip this (use_ancestor=False)
            target_ref_id = ancestor_commit.raw_id
            target_commit_final = ancestor_commit
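            # diffing against the common ancestor rather than the raw target
            # ref keeps the diff limited to the PR's own changes, even after
            # the target branch has moved ahead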

        vcs_diff = PullRequestModel().get_diff(
            source_repo, source_ref_id, target_ref_id,
            hide_whitespace_changes, diff_context)

        diff_processor = diffs.DiffProcessor(
            vcs_diff, format='newdiff', diff_limit=diff_limit,
            file_limit=file_limit, show_full_diff=fulldiff)

        _parsed = diff_processor.prepare()

        diffset = codeblocks.DiffSet(
            repo_name=self.db_repo_name,
            source_repo_name=source_repo_name,
            source_node_getter=codeblocks.diffset_node_getter(target_commit_final),
            target_node_getter=codeblocks.diffset_node_getter(source_commit_final),
        )
        diffset = self.path_filter.render_patchset_filtered(
            diffset, _parsed, target_ref_id, source_ref_id)

        return diffset

    def _get_range_diffset(self, source_scm, source_repo,
                           commit1, commit2, diff_limit, file_limit,
                           fulldiff, hide_whitespace_changes, diff_context):
        vcs_diff = source_scm.get_diff(
            commit1, commit2,
            ignore_whitespace=hide_whitespace_changes,
            context=diff_context)

        diff_processor = diffs.DiffProcessor(
            vcs_diff, format='newdiff', diff_limit=diff_limit,
            file_limit=file_limit, show_full_diff=fulldiff)

        _parsed = diff_processor.prepare()

        diffset = codeblocks.DiffSet(
            repo_name=source_repo.repo_name,
            source_node_getter=codeblocks.diffset_node_getter(commit1),
            target_node_getter=codeblocks.diffset_node_getter(commit2))

        diffset = self.path_filter.render_patchset_filtered(
            diffset, _parsed, commit1.raw_id, commit2.raw_id)

        return diffset

    def register_comments_vars(self, c, pull_request, versions, include_drafts=True):
        comments_model = CommentsModel()

        # GENERAL COMMENTS with versions #
        q = comments_model._all_general_comments_of_pull_request(pull_request)
        q = q.order_by(ChangesetComment.comment_id.asc())
        if not include_drafts:
            q = q.filter(ChangesetComment.draft == false())
        general_comments = q

        # pick comments we want to render at current version
        c.comment_versions = comments_model.aggregate_comments(
            general_comments, versions, c.at_version_num)

        # INLINE COMMENTS with versions #
        q = comments_model._all_inline_comments_of_pull_request(pull_request)
        q = q.order_by(ChangesetComment.comment_id.asc())
        if not include_drafts:
            q = q.filter(ChangesetComment.draft == false())
        inline_comments = q

        c.inline_versions = comments_model.aggregate_comments(
            inline_comments, versions, c.at_version_num, inline=True)

        # Comments inline+general
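        # 'display' holds only the comments made at the viewed version, while
        # 'until' accumulates everything up to the latest version, so pinning
        # an old version hides comments that were added later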
        if c.at_version:
            c.inline_comments_flat = c.inline_versions[c.at_version_num]['display']
            c.comments = c.comment_versions[c.at_version_num]['display']
        else:
            c.inline_comments_flat = c.inline_versions[c.at_version_num]['until']
            c.comments = c.comment_versions[c.at_version_num]['until']

        return general_comments, inline_comments

    @LoginRequired()
    @HasRepoPermissionAnyDecorator(
        'repository.read', 'repository.write', 'repository.admin')
    def pull_request_show(self):
        _ = self.request.translate
        c = self.load_default_context()

        pull_request = PullRequest.get_or_404(
            self.request.matchdict['pull_request_id'])
        pull_request_id = pull_request.pull_request_id

        c.state_progressing = pull_request.is_state_changing()
        c.pr_broadcast_channel = channelstream.pr_channel(pull_request)

        _new_state = {
            'created': PullRequest.STATE_CREATED,
        }.get(self.request.GET.get('force_state'))
        can_force_state = c.is_super_admin or HasRepoPermissionAny('repository.admin')(c.repo_name)

        if can_force_state and _new_state:
            with pull_request.set_state(PullRequest.STATE_UPDATING, final_state=_new_state):
                h.flash(
                    _('Pull Request state was force changed to `{}`').format(_new_state),
                    category='success')
                Session().commit()

            raise HTTPFound(h.route_path(
                'pullrequest_show', repo_name=self.db_repo_name,
                pull_request_id=pull_request_id))

        version = self.request.GET.get('version')
        from_version = self.request.GET.get('from_version') or version
        merge_checks = self.request.GET.get('merge_checks')
        c.fulldiff = str2bool(self.request.GET.get('fulldiff'))
        force_refresh = str2bool(self.request.GET.get('force_refresh'))
        c.range_diff_on = self.request.GET.get('range-diff') == "1"

        # fetch the global flags for ignoring whitespace and context lines
        diff_context = diffs.get_diff_context(self.request)
        hide_whitespace_changes = diffs.get_diff_whitespace_flag(self.request)

        (pull_request_latest,
         pull_request_at_ver,
         pull_request_display_obj,
         at_version) = PullRequestModel().get_pr_version(
            pull_request_id, version=version)

        pr_closed = pull_request_latest.is_closed()

        if pr_closed and (version or from_version):
            # browsing versions is not allowed for closed PRs
            raise HTTPFound(h.route_path(
                'pullrequest_show', repo_name=self.db_repo_name,
                pull_request_id=pull_request_id))

        versions = pull_request_display_obj.versions()

        c.commit_versions = PullRequestModel().pr_commits_versions(versions)

        # used to store per-commit range diffs
        c.changes = collections.OrderedDict()

        c.at_version = at_version
        c.at_version_num = (at_version
                            if at_version and at_version != PullRequest.LATEST_VER
                            else None)

        c.at_version_index = ChangesetComment.get_index_from_version(
            c.at_version_num, versions)

        (prev_pull_request_latest,
         prev_pull_request_at_ver,
         prev_pull_request_display_obj,
         prev_at_version) = PullRequestModel().get_pr_version(
            pull_request_id, version=from_version)

        c.from_version = prev_at_version
        c.from_version_num = (prev_at_version
                              if prev_at_version and prev_at_version != PullRequest.LATEST_VER
                              else None)
        c.from_version_index = ChangesetComment.get_index_from_version(
            c.from_version_num, versions)

        # define if we're in COMPARE mode or VIEW at version mode
        compare = at_version != prev_at_version
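        # e.g. ?version=2&from_version=1 renders a cross-version compare,
        # while ?version=2 alone is a plain view of that version, because
        # from_version defaults to version above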

        # the PR must have been opened against this repo,
        # i.e. its target_repo must match the repo in the URL
        if self.db_repo_name != pull_request_at_ver.target_repo.repo_name:
            log.warning('Mismatch between the current repo: %s, and target %s',
                        self.db_repo_name, pull_request_at_ver.target_repo.repo_name)
            raise HTTPNotFound()

        c.shadow_clone_url = PullRequestModel().get_shadow_clone_url(pull_request_at_ver)

        c.pull_request = pull_request_display_obj
        c.renderer = pull_request_at_ver.description_renderer or c.renderer
        c.pull_request_latest = pull_request_latest

        # inject latest version
        latest_ver = PullRequest.get_pr_display_object(pull_request_latest, pull_request_latest)
        c.versions = versions + [latest_ver]

        if compare or (at_version and at_version != PullRequest.LATEST_VER):
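            # version views and compare views are read-only snapshots,
            # so every mutating action is disabled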
416 c.allowed_to_change_status = False
416 c.allowed_to_change_status = False
417 c.allowed_to_update = False
417 c.allowed_to_update = False
418 c.allowed_to_merge = False
418 c.allowed_to_merge = False
419 c.allowed_to_delete = False
419 c.allowed_to_delete = False
420 c.allowed_to_comment = False
420 c.allowed_to_comment = False
421 c.allowed_to_close = False
421 c.allowed_to_close = False
422 else:
422 else:
423 can_change_status = PullRequestModel().check_user_change_status(
423 can_change_status = PullRequestModel().check_user_change_status(
424 pull_request_at_ver, self._rhodecode_user)
424 pull_request_at_ver, self._rhodecode_user)
425 c.allowed_to_change_status = can_change_status and not pr_closed
425 c.allowed_to_change_status = can_change_status and not pr_closed
426
426
427 c.allowed_to_update = PullRequestModel().check_user_update(
427 c.allowed_to_update = PullRequestModel().check_user_update(
428 pull_request_latest, self._rhodecode_user) and not pr_closed
428 pull_request_latest, self._rhodecode_user) and not pr_closed
429 c.allowed_to_merge = PullRequestModel().check_user_merge(
429 c.allowed_to_merge = PullRequestModel().check_user_merge(
430 pull_request_latest, self._rhodecode_user) and not pr_closed
430 pull_request_latest, self._rhodecode_user) and not pr_closed
431 c.allowed_to_delete = PullRequestModel().check_user_delete(
431 c.allowed_to_delete = PullRequestModel().check_user_delete(
432 pull_request_latest, self._rhodecode_user) and not pr_closed
432 pull_request_latest, self._rhodecode_user) and not pr_closed
433 c.allowed_to_comment = not pr_closed
433 c.allowed_to_comment = not pr_closed
434 c.allowed_to_close = c.allowed_to_merge and not pr_closed
434 c.allowed_to_close = c.allowed_to_merge and not pr_closed
435
435
436 c.forbid_adding_reviewers = False
436 c.forbid_adding_reviewers = False
437
437
438 if pull_request_latest.reviewer_data and \
438 if pull_request_latest.reviewer_data and \
439 'rules' in pull_request_latest.reviewer_data:
439 'rules' in pull_request_latest.reviewer_data:
440 rules = pull_request_latest.reviewer_data['rules'] or {}
440 rules = pull_request_latest.reviewer_data['rules'] or {}
441 try:
441 try:
442 c.forbid_adding_reviewers = rules.get('forbid_adding_reviewers')
442 c.forbid_adding_reviewers = rules.get('forbid_adding_reviewers')
443 except Exception:
443 except Exception:
444 pass
444 pass
445
445
446 # check merge capabilities
446 # check merge capabilities
447 _merge_check = MergeCheck.validate(
447 _merge_check = MergeCheck.validate(
448 pull_request_latest, auth_user=self._rhodecode_user,
448 pull_request_latest, auth_user=self._rhodecode_user,
449 translator=self.request.translate,
449 translator=self.request.translate,
450 force_shadow_repo_refresh=force_refresh)
450 force_shadow_repo_refresh=force_refresh)
451
451
452 c.pr_merge_errors = _merge_check.error_details
452 c.pr_merge_errors = _merge_check.error_details
453 c.pr_merge_possible = not _merge_check.failed
453 c.pr_merge_possible = not _merge_check.failed
454 c.pr_merge_message = _merge_check.merge_msg
454 c.pr_merge_message = _merge_check.merge_msg
455 c.pr_merge_source_commit = _merge_check.source_commit
455 c.pr_merge_source_commit = _merge_check.source_commit
456 c.pr_merge_target_commit = _merge_check.target_commit
456 c.pr_merge_target_commit = _merge_check.target_commit
457
457
458 c.pr_merge_info = MergeCheck.get_merge_conditions(
458 c.pr_merge_info = MergeCheck.get_merge_conditions(
459 pull_request_latest, translator=self.request.translate)
459 pull_request_latest, translator=self.request.translate)
460
460
461 c.pull_request_review_status = _merge_check.review_status
461 c.pull_request_review_status = _merge_check.review_status
462 if merge_checks:
462 if merge_checks:
463 self.request.override_renderer = \
463 self.request.override_renderer = \
464 'rhodecode:templates/pullrequests/pullrequest_merge_checks.mako'
464 'rhodecode:templates/pullrequests/pullrequest_merge_checks.mako'
465 return self._get_template_context(c)
465 return self._get_template_context(c)
466
466
467 c.reviewers_count = pull_request.reviewers_count
467 c.reviewers_count = pull_request.reviewers_count
468 c.observers_count = pull_request.observers_count
468 c.observers_count = pull_request.observers_count
469
469
470 # reviewers and statuses
470 # reviewers and statuses
471 c.pull_request_default_reviewers_data_json = json.dumps(pull_request.reviewer_data)
471 c.pull_request_default_reviewers_data_json = json.dumps(pull_request.reviewer_data)
472 c.pull_request_set_reviewers_data_json = collections.OrderedDict({'reviewers': []})
472 c.pull_request_set_reviewers_data_json = collections.OrderedDict({'reviewers': []})
473 c.pull_request_set_observers_data_json = collections.OrderedDict({'observers': []})
473 c.pull_request_set_observers_data_json = collections.OrderedDict({'observers': []})

        for review_obj, member, reasons, mandatory, status in pull_request_at_ver.reviewers_statuses():
            member_reviewer = h.reviewer_as_json(
                member, reasons=reasons, mandatory=mandatory,
                role=review_obj.role,
                user_group=review_obj.rule_user_group_data()
            )

            current_review_status = status[0][1].status if status else ChangesetStatus.STATUS_NOT_REVIEWED
            member_reviewer['review_status'] = current_review_status
            member_reviewer['review_status_label'] = h.commit_status_lbl(current_review_status)
            member_reviewer['allowed_to_update'] = c.allowed_to_update
            c.pull_request_set_reviewers_data_json['reviewers'].append(member_reviewer)

        c.pull_request_set_reviewers_data_json = json.dumps(c.pull_request_set_reviewers_data_json)

        for observer_obj, member in pull_request_at_ver.observers():
            member_observer = h.reviewer_as_json(
                member, reasons=[], mandatory=False,
                role=observer_obj.role,
                user_group=observer_obj.rule_user_group_data()
            )
            member_observer['allowed_to_update'] = c.allowed_to_update
            c.pull_request_set_observers_data_json['observers'].append(member_observer)

        c.pull_request_set_observers_data_json = json.dumps(c.pull_request_set_observers_data_json)

        general_comments, inline_comments = \
            self.register_comments_vars(c, pull_request_latest, versions)

        # TODOs
        c.unresolved_comments = CommentsModel() \
            .get_pull_request_unresolved_todos(pull_request_latest)
        c.resolved_comments = CommentsModel() \
            .get_pull_request_resolved_todos(pull_request_latest)

        # Drafts
        c.draft_comments = CommentsModel().get_pull_request_drafts(
            self._rhodecode_db_user.user_id,
            pull_request_latest)

        # when viewing a specific version, do not show comments that were
        # made after that version
        display_inline_comments = collections.defaultdict(
            lambda: collections.defaultdict(list))
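        # nested mapping: file path -> line number -> list of comments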
        for co in inline_comments:
            if c.at_version_num:
                # pick comments made up to and including the given version,
                # so we don't render comments that belong to a later version
                should_render = co.pull_request_version_id and \
                    co.pull_request_version_id <= c.at_version_num
            else:
                # showing all, for 'latest'
                should_render = True

            if should_render:
                display_inline_comments[co.f_path][co.line_no].append(co)

        # load diff data into template context; in compare mode the diff is
        # calculated from the changes between versions of the PR

        source_repo = pull_request_at_ver.source_repo
        source_ref_id = pull_request_at_ver.source_ref_parts.commit_id

        target_repo = pull_request_at_ver.target_repo
        target_ref_id = pull_request_at_ver.target_ref_parts.commit_id

        if compare:
            # in compare mode, switch the diff base to the latest commit from the prev version
            target_ref_id = prev_pull_request_display_obj.revisions[0]

        # even when a PR was opened from a bookmark/branch/tag, we always
        # pin the refs to plain revs, so later bookmark or branch moves
        # cannot change what the PR shows
        c.source_ref_type = 'rev'
        c.source_ref = source_ref_id

        c.target_ref_type = 'rev'
        c.target_ref = target_ref_id

        c.source_repo = source_repo
        c.target_repo = target_repo

        c.commit_ranges = []
        source_commit = EmptyCommit()
        target_commit = EmptyCommit()
        c.missing_requirements = False

        source_scm = source_repo.scm_instance()
        target_scm = target_repo.scm_instance()

        shadow_scm = None
        try:
            shadow_scm = pull_request_latest.get_shadow_repo()
        except Exception:
            log.debug('Failed to get shadow repo', exc_info=True)
        # default to the existing source repo, but prefer the shadow
        # repo if we managed to obtain one
        commits_source_repo = source_scm
        if shadow_scm:
            commits_source_repo = shadow_scm

        c.commits_source_repo = commits_source_repo
        c.ancestor = None  # set it to None, to hide it from PR view

        # an empty version means 'latest'; normalize it so the same diff
        # is not cached twice under different keys
        version_normalized = version or PullRequest.LATEST_VER
        from_version_normalized = from_version or PullRequest.LATEST_VER

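        # the diff cache key folds in the PR id, both normalized versions,
        # both refs and all diff display options, so changing any of them
        # yields a separate cache entry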
        cache_path = self.rhodecode_vcs_repo.get_create_shadow_cache_pr_path(target_repo)
        cache_file_path = diff_cache_exist(
            cache_path, 'pull_request', pull_request_id, version_normalized,
            from_version_normalized, source_ref_id, target_ref_id,
            hide_whitespace_changes, diff_context, c.fulldiff)

        caching_enabled = self._is_diff_cache_enabled(c.target_repo)
        force_recache = self.get_recache_flag()

        cached_diff = None
        if caching_enabled:
            cached_diff = load_cached_diff(cache_file_path)

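        # a usable 'commits' cache entry is the 5-tuple unpacked further
        # below: (ancestor_commit, commit_cache, missing_requirements,
        # source_commit, target_commit); the ancestor (index 0) and the
        # source commit (index 3) must both be present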
        has_proper_commit_cache = (
            cached_diff and cached_diff.get('commits')
            and len(cached_diff.get('commits', [])) == 5
            and cached_diff.get('commits')[0]
            and cached_diff.get('commits')[3])

        if not force_recache and not c.range_diff_on and has_proper_commit_cache:
            diff_commit_cache = \
                (ancestor_commit, commit_cache, missing_requirements,
                 source_commit, target_commit) = cached_diff['commits']
        else:
            # NOTE(marcink): we may need to fetch potentially unreachable commits
            # when a PR has merge errors that leave hidden commits in the shadow repo.
            maybe_unreachable = _merge_check.MERGE_CHECK in _merge_check.error_details \
                and _merge_check.merge_response
            maybe_unreachable = maybe_unreachable \
                and _merge_check.merge_response.metadata.get('unresolved_files')
            log.debug("Using unreachable commits due to MERGE_CHECK in merge simulation")
            diff_commit_cache = \
                (ancestor_commit, commit_cache, missing_requirements,
                 source_commit, target_commit) = self.get_commits(
                    commits_source_repo,
                    pull_request_at_ver,
                    source_commit,
                    source_ref_id,
                    source_scm,
                    target_commit,
                    target_ref_id,
                    target_scm,
                    maybe_unreachable=maybe_unreachable)

        # register our commit range
        for comm in commit_cache.values():
            c.commit_ranges.append(comm)

        c.missing_requirements = missing_requirements
        c.ancestor_commit = ancestor_commit
        c.statuses = source_repo.statuses(
            [x.raw_id for x in c.commit_ranges])

        # auto-collapse the commit list if it is longer than the limit
        collapse_limit = diffs.DiffProcessor._collapse_commits_over
        c.collapse_all_commits = len(c.commit_ranges) > collapse_limit
        c.compare_mode = compare

        # diff_limit is the old behavior: when hit it cuts off the whole diff,
        # while file_limit merely hides the big files from the front-end
        diff_limit = c.visual.cut_off_limit_diff
        file_limit = c.visual.cut_off_limit_file

        c.missing_commits = False
        if (c.missing_requirements
                or isinstance(source_commit, EmptyCommit)
                or source_commit == target_commit):

            c.missing_commits = True
        else:
            c.inline_comments = display_inline_comments

657 if from_version_normalized != version_normalized:
657 if from_version_normalized != version_normalized:
658 use_ancestor = False
658 use_ancestor = False
659
659
660 has_proper_diff_cache = cached_diff and cached_diff.get('commits')
660 has_proper_diff_cache = cached_diff and cached_diff.get('commits')
661 if not force_recache and has_proper_diff_cache:
661 if not force_recache and has_proper_diff_cache:
662 c.diffset = cached_diff['diff']
662 c.diffset = cached_diff['diff']
663 else:
663 else:
664 try:
664 try:
665 c.diffset = self._get_diffset(
665 c.diffset = self._get_diffset(
666 c.source_repo.repo_name, commits_source_repo,
666 c.source_repo.repo_name, commits_source_repo,
667 c.ancestor_commit,
667 c.ancestor_commit,
668 source_ref_id, target_ref_id,
668 source_ref_id, target_ref_id,
669 target_commit, source_commit,
669 target_commit, source_commit,
670 diff_limit, file_limit, c.fulldiff,
670 diff_limit, file_limit, c.fulldiff,
671 hide_whitespace_changes, diff_context,
671 hide_whitespace_changes, diff_context,
672 use_ancestor=use_ancestor
672 use_ancestor=use_ancestor
673 )
673 )
674
674
675 # save cached diff
675 # save cached diff
676 if caching_enabled:
676 if caching_enabled:
677 cache_diff(cache_file_path, c.diffset, diff_commit_cache)
677 cache_diff(cache_file_path, c.diffset, diff_commit_cache)
678 except CommitDoesNotExistError:
678 except CommitDoesNotExistError:
679 log.exception('Failed to generate diffset')
679 log.exception('Failed to generate diffset')
680 c.missing_commits = True
680 c.missing_commits = True
681
681
682 if not c.missing_commits:
682 if not c.missing_commits:
683
683
684 c.limited_diff = c.diffset.limited_diff
684 c.limited_diff = c.diffset.limited_diff
685
685
686 # calculate removed files that are bound to comments
686 # calculate removed files that are bound to comments
687 comment_deleted_files = [
687 comment_deleted_files = [
688 fname for fname in display_inline_comments
688 fname for fname in display_inline_comments
689 if fname not in c.diffset.file_stats]
689 if fname not in c.diffset.file_stats]
690
690
691 c.deleted_files_comments = collections.defaultdict(dict)
691 c.deleted_files_comments = collections.defaultdict(dict)
692 for fname, per_line_comments in display_inline_comments.items():
692 for fname, per_line_comments in display_inline_comments.items():
693 if fname in comment_deleted_files:
693 if fname in comment_deleted_files:
694 c.deleted_files_comments[fname]['stats'] = 0
694 c.deleted_files_comments[fname]['stats'] = 0
695 c.deleted_files_comments[fname]['comments'] = list()
695 c.deleted_files_comments[fname]['comments'] = list()
696 for lno, comments in per_line_comments.items():
696 for lno, comments in per_line_comments.items():
697 c.deleted_files_comments[fname]['comments'].extend(comments)
697 c.deleted_files_comments[fname]['comments'].extend(comments)
698
698
699 # maybe calculate the range diff
699 # maybe calculate the range diff
700 if c.range_diff_on:
700 if c.range_diff_on:
701 # TODO(marcink): set whitespace/context
701 # TODO(marcink): set whitespace/context
702 context_lcl = 3
702 context_lcl = 3
703 ign_whitespace_lcl = False
703 ign_whitespace_lcl = False
704
704
705 for commit in c.commit_ranges:
705 for commit in c.commit_ranges:
706 commit2 = commit
706 commit2 = commit
707 commit1 = commit.first_parent
707 commit1 = commit.first_parent
708
708
709 range_diff_cache_file_path = diff_cache_exist(
709 range_diff_cache_file_path = diff_cache_exist(
710 cache_path, 'diff', commit.raw_id,
710 cache_path, 'diff', commit.raw_id,
711 ign_whitespace_lcl, context_lcl, c.fulldiff)
711 ign_whitespace_lcl, context_lcl, c.fulldiff)
712
712
713 cached_diff = None
713 cached_diff = None
714 if caching_enabled:
714 if caching_enabled:
715 cached_diff = load_cached_diff(range_diff_cache_file_path)
715 cached_diff = load_cached_diff(range_diff_cache_file_path)
716
716
717 has_proper_diff_cache = cached_diff and cached_diff.get('diff')
717 has_proper_diff_cache = cached_diff and cached_diff.get('diff')
718 if not force_recache and has_proper_diff_cache:
718 if not force_recache and has_proper_diff_cache:
719 diffset = cached_diff['diff']
719 diffset = cached_diff['diff']
720 else:
720 else:
721 diffset = self._get_range_diffset(
721 diffset = self._get_range_diffset(
722 commits_source_repo, source_repo,
722 commits_source_repo, source_repo,
723 commit1, commit2, diff_limit, file_limit,
723 commit1, commit2, diff_limit, file_limit,
724 c.fulldiff, ign_whitespace_lcl, context_lcl
724 c.fulldiff, ign_whitespace_lcl, context_lcl
725 )
725 )
726
726
727 # save cached diff
727 # save cached diff
728 if caching_enabled:
728 if caching_enabled:
729 cache_diff(range_diff_cache_file_path, diffset, None)
729 cache_diff(range_diff_cache_file_path, diffset, None)
730
730
731 c.changes[commit.raw_id] = diffset
731 c.changes[commit.raw_id] = diffset
732
732
733 # this is a hack to properly display links, when creating PR, the
733 # this is a hack to properly display links, when creating PR, the
734 # compare view and others uses different notation, and
734 # compare view and others uses different notation, and
735 # compare_commits.mako renders links based on the target_repo.
735 # compare_commits.mako renders links based on the target_repo.
736 # We need to swap that here to generate it properly on the html side
736 # We need to swap that here to generate it properly on the html side
737 c.target_repo = c.source_repo
737 c.target_repo = c.source_repo
738
738
739 c.commit_statuses = ChangesetStatus.STATUSES
739 c.commit_statuses = ChangesetStatus.STATUSES
740
740
741 c.show_version_changes = not pr_closed
741 c.show_version_changes = not pr_closed
742 if c.show_version_changes:
742 if c.show_version_changes:
743 cur_obj = pull_request_at_ver
743 cur_obj = pull_request_at_ver
744 prev_obj = prev_pull_request_at_ver
744 prev_obj = prev_pull_request_at_ver
745
745
746 old_commit_ids = prev_obj.revisions
746 old_commit_ids = prev_obj.revisions
747 new_commit_ids = cur_obj.revisions
747 new_commit_ids = cur_obj.revisions
748 commit_changes = PullRequestModel()._calculate_commit_id_changes(
748 commit_changes = PullRequestModel()._calculate_commit_id_changes(
749 old_commit_ids, new_commit_ids)
749 old_commit_ids, new_commit_ids)
750 c.commit_changes_summary = commit_changes
750 c.commit_changes_summary = commit_changes
751
751
752 # calculate the diff for commits between versions
752 # calculate the diff for commits between versions
753 c.commit_changes = []
753 c.commit_changes = []
754
754
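            # mark() tags each commit id with a change-type flag via the
            # zip_longest fill value, producing (flag, commit_id) pairs:
            # 'a' = added, 'r' = removed, 'c' = common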
            def mark(cs, fw):
                return list(h.itertools.zip_longest([], cs, fillvalue=fw))

            for c_type, raw_id in mark(commit_changes.added, 'a') \
                    + mark(commit_changes.removed, 'r') \
                    + mark(commit_changes.common, 'c'):

                if raw_id in commit_cache:
                    commit = commit_cache[raw_id]
                else:
                    try:
                        commit = commits_source_repo.get_commit(raw_id)
                    except CommitDoesNotExistError:
                        # if we fail to fetch the commit, fall back to a
                        # "dummy" commit for display in the commit diff
                        commit = h.AttributeDict(
                            {'raw_id': raw_id,
                             'message': 'EMPTY or MISSING COMMIT'})
                c.commit_changes.append([c_type, commit])

            # current user review statuses for each version
            c.review_versions = {}
            is_reviewer = PullRequestModel().is_user_reviewer(
                pull_request, self._rhodecode_user)
            if is_reviewer:
                for co in general_comments:
                    if co.author.user_id == self._rhodecode_user.user_id:
                        status = co.status_change
                        if status:
                            _ver_pr = status[0].comment.pull_request_version_id
                            c.review_versions[_ver_pr] = status[0]

        return self._get_template_context(c)

    def get_commits(
            self, commits_source_repo, pull_request_at_ver, source_commit,
            source_ref_id, source_scm, target_commit, target_ref_id, target_scm,
            maybe_unreachable=False):
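        """
        Load the commits of the given PR version together with the source,
        target and ancestor commits, tolerating missing or unreachable ones.
        """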

        commit_cache = collections.OrderedDict()
        missing_requirements = False

        try:
            pre_load = ["author", "date", "message", "branch", "parents"]

            pull_request_commits = pull_request_at_ver.revisions
            log.debug('Loading %s commits from %s',
                      len(pull_request_commits), commits_source_repo)

            for rev in pull_request_commits:
                comm = commits_source_repo.get_commit(commit_id=rev, pre_load=pre_load,
                                                      maybe_unreachable=maybe_unreachable)
                commit_cache[comm.raw_id] = comm

            # Order here matters, we first need to get target, and then
            # the source
            target_commit = commits_source_repo.get_commit(
                commit_id=safe_str(target_ref_id))

            source_commit = commits_source_repo.get_commit(
                commit_id=safe_str(source_ref_id), maybe_unreachable=True)
        except CommitDoesNotExistError:
            log.warning('Failed to get commit from `{}` repo'.format(
                commits_source_repo), exc_info=True)
        except RepositoryRequirementError:
            log.warning('Failed to get all required data from repo', exc_info=True)
            missing_requirements = True

        pr_ancestor_id = pull_request_at_ver.common_ancestor_id

        try:
            ancestor_commit = source_scm.get_commit(pr_ancestor_id)
        except Exception:
            ancestor_commit = None

        return ancestor_commit, commit_cache, missing_requirements, source_commit, target_commit

    def assure_not_empty_repo(self):
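        """
        Guard used by the PR views: if the repository has no commits yet,
        flash a warning and redirect to the repository summary page.
        """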
        _ = self.request.translate

        try:
            self.db_repo.scm_instance().get_commit()
        except EmptyRepositoryError:
            h.flash(h.literal(_('There are no commits yet')),
                    category='warning')
            raise HTTPFound(
                h.route_path('repo_summary', repo_name=self.db_repo.repo_name))

    @LoginRequired()
    @NotAnonymous()
    @HasRepoPermissionAnyDecorator(
        'repository.read', 'repository.write', 'repository.admin')
    def pull_request_new(self):
        _ = self.request.translate
        c = self.load_default_context()

        self.assure_not_empty_repo()
        source_repo = self.db_repo

        commit_id = self.request.GET.get('commit')
        branch_ref = self.request.GET.get('branch')
        bookmark_ref = self.request.GET.get('bookmark')

        try:
            source_repo_data = PullRequestModel().generate_repo_data(
                source_repo, commit_id=commit_id,
                branch=branch_ref, bookmark=bookmark_ref,
                translator=self.request.translate)
        except CommitDoesNotExistError as e:
            log.exception(e)
            h.flash(_('Commit does not exist'), 'error')
            raise HTTPFound(
                h.route_path('pullrequest_new', repo_name=source_repo.repo_name))

        default_target_repo = source_repo

        if source_repo.parent and c.has_origin_repo_read_perm:
            parent_vcs_obj = source_repo.parent.scm_instance()
            if parent_vcs_obj and not parent_vcs_obj.is_empty():
                # change the default if we have a non-empty parent repo
                default_target_repo = source_repo.parent

        target_repo_data = PullRequestModel().generate_repo_data(
            default_target_repo, translator=self.request.translate)

        selected_source_ref = source_repo_data['refs']['selected_ref']
        title_source_ref = ''
        if selected_source_ref:
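            # selected_ref appears to be a 'type:name:commit_id' triple;
            # the middle (name) part is used for the generated title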
            title_source_ref = selected_source_ref.split(':', 2)[1]
        c.default_title = PullRequestModel().generate_pullrequest_title(
            source=source_repo.repo_name,
            source_ref=title_source_ref,
            target=default_target_repo.repo_name
        )

        c.default_repo_data = {
            'source_repo_name': source_repo.repo_name,
            'source_refs_json': json.dumps(source_repo_data),
            'target_repo_name': default_target_repo.repo_name,
            'target_refs_json': json.dumps(target_repo_data),
        }
        c.default_source_ref = selected_source_ref

        return self._get_template_context(c)

    @LoginRequired()
    @NotAnonymous()
    @HasRepoPermissionAnyDecorator(
        'repository.read', 'repository.write', 'repository.admin')
    def pull_request_repo_refs(self):
        self.load_default_context()
        target_repo_name = self.request.matchdict['target_repo_name']
        repo = Repository.get_by_repo_name(target_repo_name)
        if not repo:
            raise HTTPNotFound()

        target_perm = HasRepoPermissionAny(
            'repository.read', 'repository.write', 'repository.admin')(
            target_repo_name)
        if not target_perm:
            raise HTTPNotFound()

        return PullRequestModel().generate_repo_data(
            repo, translator=self.request.translate)

    @LoginRequired()
    @NotAnonymous()
    @HasRepoPermissionAnyDecorator(
        'repository.read', 'repository.write', 'repository.admin')
    def pullrequest_repo_targets(self):
        _ = self.request.translate
        filter_query = self.request.GET.get('query')

        # get the parents
        parent_target_repos = []
        if self.db_repo.parent:
            parents_query = Repository.query() \
                .order_by(func.length(Repository.repo_name)) \
                .filter(Repository.fork_id == self.db_repo.parent.repo_id)

            if filter_query:
                ilike_expression = u'%{}%'.format(safe_unicode(filter_query))
                parents_query = parents_query.filter(
                    Repository.repo_name.ilike(ilike_expression))
            parents = parents_query.limit(20).all()

            for parent in parents:
                parent_vcs_obj = parent.scm_instance()
                if parent_vcs_obj and not parent_vcs_obj.is_empty():
                    parent_target_repos.append(parent)

        # get other forks, and the repo itself
        query = Repository.query() \
            .order_by(func.length(Repository.repo_name)) \
            .filter(
                or_(Repository.repo_id == self.db_repo.repo_id,  # repo itself
                    Repository.fork_id == self.db_repo.repo_id)  # forks of this repo
            ) \
            .filter(~Repository.repo_id.in_([x.repo_id for x in parent_target_repos]))
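        # parent candidates collected above are excluded here so the two
        # result lists do not overlap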

        if filter_query:
            ilike_expression = u'%{}%'.format(safe_unicode(filter_query))
            query = query.filter(Repository.repo_name.ilike(ilike_expression))

        limit = max(20 - len(parent_target_repos), 5)  # not less than 5
        target_repos = query.limit(limit).all()

        all_target_repos = target_repos + parent_target_repos

        repos = []
        # this checks permissions to the repositories
        for obj in ScmModel().get_repos(all_target_repos):
            repos.append({
                'id': obj['name'],
                'text': obj['name'],
                'type': 'repo',
                'repo_id': obj['dbrepo']['repo_id'],
                'repo_type': obj['dbrepo']['repo_type'],
                'private': obj['dbrepo']['private'],
            })

        data = {
            'more': False,
            'results': [{
                'text': _('Repositories'),
                'children': repos
            }] if repos else []
        }
        return data

    @classmethod
    def get_comment_ids(cls, post_data):
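        # keep only the positive integer ids from the comma-separated
        # `comments` POST field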
        return filter(lambda e: e > 0, map(safe_int, aslist(post_data.get('comments'), ',')))

    @LoginRequired()
    @NotAnonymous()
    @HasRepoPermissionAnyDecorator(
        'repository.read', 'repository.write', 'repository.admin')
    def pullrequest_comments(self):
        self.load_default_context()

        pull_request = PullRequest.get_or_404(
            self.request.matchdict['pull_request_id'])
        pull_request_id = pull_request.pull_request_id
        version = self.request.GET.get('version')

        _render = self.request.get_partial_renderer(
            'rhodecode:templates/base/sidebar.mako')
        c = _render.get_call_context()

        (pull_request_latest,
         pull_request_at_ver,
         pull_request_display_obj,
         at_version) = PullRequestModel().get_pr_version(
            pull_request_id, version=version)
        versions = pull_request_display_obj.versions()
        latest_ver = PullRequest.get_pr_display_object(pull_request_latest, pull_request_latest)
        c.versions = versions + [latest_ver]
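        # the displayed version list always ends with an entry representing
        # the latest (unversioned) state of the pull request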

        c.at_version = at_version
        c.at_version_num = (at_version
                            if at_version and at_version != PullRequest.LATEST_VER
                            else None)

        self.register_comments_vars(c, pull_request_latest, versions, include_drafts=False)
        all_comments = c.inline_comments_flat + c.comments

        existing_ids = self.get_comment_ids(self.request.POST)
        return _render('comments_table', all_comments, len(all_comments),
                       existing_ids=existing_ids)

    @LoginRequired()
    @NotAnonymous()
    @HasRepoPermissionAnyDecorator(
        'repository.read', 'repository.write', 'repository.admin')
    def pullrequest_todos(self):
        self.load_default_context()

        pull_request = PullRequest.get_or_404(
            self.request.matchdict['pull_request_id'])
        pull_request_id = pull_request.pull_request_id
        version = self.request.GET.get('version')

        _render = self.request.get_partial_renderer(
            'rhodecode:templates/base/sidebar.mako')
        c = _render.get_call_context()
        (pull_request_latest,
         pull_request_at_ver,
         pull_request_display_obj,
         at_version) = PullRequestModel().get_pr_version(
            pull_request_id, version=version)
        versions = pull_request_display_obj.versions()
        latest_ver = PullRequest.get_pr_display_object(pull_request_latest, pull_request_latest)
        c.versions = versions + [latest_ver]

        c.at_version = at_version
        c.at_version_num = (at_version
                            if at_version and at_version != PullRequest.LATEST_VER
                            else None)

        c.unresolved_comments = CommentsModel() \
            .get_pull_request_unresolved_todos(pull_request, include_drafts=False)
        c.resolved_comments = CommentsModel() \
            .get_pull_request_resolved_todos(pull_request, include_drafts=False)

        all_comments = c.unresolved_comments + c.resolved_comments
        existing_ids = self.get_comment_ids(self.request.POST)
        return _render('comments_table', all_comments, len(c.unresolved_comments),
                       todo_comments=True, existing_ids=existing_ids)

    @LoginRequired()
    @NotAnonymous()
    @HasRepoPermissionAnyDecorator(
        'repository.read', 'repository.write', 'repository.admin')
    def pullrequest_drafts(self):
        self.load_default_context()

        pull_request = PullRequest.get_or_404(
            self.request.matchdict['pull_request_id'])
        pull_request_id = pull_request.pull_request_id
        version = self.request.GET.get('version')

        _render = self.request.get_partial_renderer(
            'rhodecode:templates/base/sidebar.mako')
        c = _render.get_call_context()

        (pull_request_latest,
         pull_request_at_ver,
         pull_request_display_obj,
         at_version) = PullRequestModel().get_pr_version(
            pull_request_id, version=version)
        versions = pull_request_display_obj.versions()
        latest_ver = PullRequest.get_pr_display_object(pull_request_latest, pull_request_latest)
        c.versions = versions + [latest_ver]

        c.at_version = at_version
        c.at_version_num = (at_version
                            if at_version and at_version != PullRequest.LATEST_VER
                            else None)

        c.draft_comments = CommentsModel() \
            .get_pull_request_drafts(self._rhodecode_db_user.user_id, pull_request)

        all_comments = c.draft_comments

        existing_ids = self.get_comment_ids(self.request.POST)
        return _render('comments_table', all_comments, len(all_comments),
                       existing_ids=existing_ids, draft_comments=True)

    @LoginRequired()
    @NotAnonymous()
    @HasRepoPermissionAnyDecorator(
        'repository.read', 'repository.write', 'repository.admin')
    @CSRFRequired()
    def pull_request_create(self):
        _ = self.request.translate
        self.assure_not_empty_repo()
        self.load_default_context()

        controls = peppercorn.parse(self.request.POST.items())

        try:
            form = PullRequestForm(
                self.request.translate, self.db_repo.repo_id)()
            _form = form.to_python(controls)
        except formencode.Invalid as errors:
            if errors.error_dict.get('revisions'):
                msg = 'Revisions: %s' % errors.error_dict['revisions']
            elif errors.error_dict.get('pullrequest_title'):
                msg = errors.error_dict.get('pullrequest_title')
            else:
                msg = _('Error creating pull request: {}').format(errors)
            log.exception(msg)
            h.flash(msg, 'error')

            # would rather just go back to form ...
            raise HTTPFound(
                h.route_path('pullrequest_new', repo_name=self.db_repo_name))

        source_repo = _form['source_repo']
        source_ref = _form['source_ref']
        target_repo = _form['target_repo']
        target_ref = _form['target_ref']
        commit_ids = _form['revisions'][::-1]
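        # NOTE: the posted revisions are reversed here, presumably so they
        # are stored oldest-first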
        common_ancestor_id = _form['common_ancestor']

        # find the ancestor for this pr
        source_db_repo = Repository.get_by_repo_name(_form['source_repo'])
        target_db_repo = Repository.get_by_repo_name(_form['target_repo'])

        if not (source_db_repo and target_db_repo):
            h.flash(_('Source repo or target repo not found'), category='error')
            raise HTTPFound(
                h.route_path('pullrequest_new', repo_name=self.db_repo_name))

        # re-check permissions again here:
        # we must have read permissions to the source repo
        source_perm = HasRepoPermissionAny(
            'repository.read', 'repository.write', 'repository.admin')(
            source_db_repo.repo_name)
        if not source_perm:
            msg = _('Not enough permissions to source repo `{}`.'.format(
                source_db_repo.repo_name))
            h.flash(msg, category='error')
            # copy the args back to redirect
            org_query = self.request.GET.mixed()
            raise HTTPFound(
                h.route_path('pullrequest_new', repo_name=self.db_repo_name,
                             _query=org_query))

        # we must have read permissions to the target repo too, and later on
        # we also want to check branch permissions here
        target_perm = HasRepoPermissionAny(
            'repository.read', 'repository.write', 'repository.admin')(
            target_db_repo.repo_name)
        if not target_perm:
            msg = _('Not enough permissions to target repo `{}`.'.format(
                target_db_repo.repo_name))
            h.flash(msg, category='error')
            # copy the args back to redirect
            org_query = self.request.GET.mixed()
            raise HTTPFound(
                h.route_path('pullrequest_new', repo_name=self.db_repo_name,
                             _query=org_query))

        source_scm = source_db_repo.scm_instance()
        target_scm = target_db_repo.scm_instance()

        source_ref_obj = unicode_to_reference(source_ref)
        target_ref_obj = unicode_to_reference(target_ref)

        source_commit = source_scm.get_commit(source_ref_obj.commit_id)
        target_commit = target_scm.get_commit(target_ref_obj.commit_id)

        ancestor = source_scm.get_common_ancestor(
            source_commit.raw_id, target_commit.raw_id, target_scm)

        # recalculate target ref based on ancestor
        target_ref = ':'.join((target_ref_obj.type, target_ref_obj.name, ancestor))
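        # the target ref keeps its type and name but is pinned to the common
        # ancestor commit, giving the PR a stable merge base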

        get_default_reviewers_data, validate_default_reviewers, validate_observers = \
            PullRequestModel().get_reviewer_functions()

        # recalculate reviewers logic, to make sure we can validate this
        reviewer_rules = get_default_reviewers_data(
            self._rhodecode_db_user,
            source_db_repo,
            source_ref_obj,
            target_db_repo,
            target_ref_obj,
            include_diff_info=False)

        reviewers = validate_default_reviewers(_form['review_members'], reviewer_rules)
        observers = validate_observers(_form['observer_members'], reviewer_rules)

        pullrequest_title = _form['pullrequest_title']
        title_source_ref = source_ref_obj.name
        if not pullrequest_title:
            pullrequest_title = PullRequestModel().generate_pullrequest_title(
                source=source_repo,
                source_ref=title_source_ref,
                target=target_repo
            )

        description = _form['pullrequest_desc']
        description_renderer = _form['description_renderer']

        try:
            pull_request = PullRequestModel().create(
                created_by=self._rhodecode_user.user_id,
                source_repo=source_repo,
                source_ref=source_ref,
                target_repo=target_repo,
                target_ref=target_ref,
                revisions=commit_ids,
                common_ancestor_id=common_ancestor_id,
                reviewers=reviewers,
                observers=observers,
                title=pullrequest_title,
                description=description,
                description_renderer=description_renderer,
                reviewer_data=reviewer_rules,
                auth_user=self._rhodecode_user
            )
            Session().commit()

            h.flash(_('Successfully opened new pull request'),
                    category='success')
        except Exception:
            msg = _('Error occurred during creation of this pull request.')
            log.exception(msg)
            h.flash(msg, category='error')

            # copy the args back to redirect
            org_query = self.request.GET.mixed()
            raise HTTPFound(
                h.route_path('pullrequest_new', repo_name=self.db_repo_name,
                             _query=org_query))

        raise HTTPFound(
            h.route_path('pullrequest_show', repo_name=target_repo,
                         pull_request_id=pull_request.pull_request_id))

    @LoginRequired()
    @NotAnonymous()
    @HasRepoPermissionAnyDecorator(
        'repository.read', 'repository.write', 'repository.admin')
    @CSRFRequired()
    def pull_request_update(self):
        pull_request = PullRequest.get_or_404(
            self.request.matchdict['pull_request_id'])
        _ = self.request.translate

        c = self.load_default_context()
        redirect_url = None
        # run this check first, so we know as early as possible in the flow
        # whether the pull request is currently being updated
        is_state_changing = pull_request.is_state_changing()

        if pull_request.is_closed():
            log.debug('update: forbidden because pull request is closed')
            msg = _(u'Cannot update closed pull requests.')
            h.flash(msg, category='error')
            return {'response': True,
                    'redirect_url': redirect_url}

        c.pr_broadcast_channel = channelstream.pr_channel(pull_request)

        # only owner or admin can update it
        allowed_to_update = PullRequestModel().check_user_update(
            pull_request, self._rhodecode_user)

        if allowed_to_update:
            controls = peppercorn.parse(self.request.POST.items())
            force_refresh = str2bool(self.request.POST.get('force_refresh', 'false'))
            do_update_commits = str2bool(self.request.POST.get('update_commits', 'false'))

            if 'review_members' in controls:
                self._update_reviewers(
                    c,
                    pull_request, controls['review_members'],
                    pull_request.reviewer_data,
                    PullRequestReviewers.ROLE_REVIEWER)
            elif 'observer_members' in controls:
                self._update_reviewers(
                    c,
                    pull_request, controls['observer_members'],
                    pull_request.reviewer_data,
                    PullRequestReviewers.ROLE_OBSERVER)
            elif do_update_commits:
                if is_state_changing:
                    log.debug('commits update: forbidden because pull request is in state %s',
                              pull_request.pull_request_state)
                    msg = _(u'Cannot update pull requests commits in state other than `{}`. '
                            u'Current state is: `{}`').format(
                        PullRequest.STATE_CREATED, pull_request.pull_request_state)
                    h.flash(msg, category='error')
                    return {'response': True,
                            'redirect_url': redirect_url}

                self._update_commits(c, pull_request)
                if force_refresh:
                    redirect_url = h.route_path(
                        'pullrequest_show', repo_name=self.db_repo_name,
                        pull_request_id=pull_request.pull_request_id,
                        _query={"force_refresh": 1})
            elif str2bool(self.request.POST.get('edit_pull_request', 'false')):
                self._edit_pull_request(pull_request)
            else:
                log.error('Unhandled update data.')
                raise HTTPBadRequest()

            return {'response': True,
                    'redirect_url': redirect_url}
        raise HTTPForbidden()

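    # Illustrative sketch (not part of the original view code): how
    # peppercorn.parse() turns the flat POST pairs into the nested `controls`
    # mapping read in pull_request_update() above. The field values are
    # hypothetical examples; the real forms submit richer reviewer entries.
    #
    #   import peppercorn
    #   fields = [
    #       ('__start__', 'review_members:sequence'),
    #       ('user_id', '2'),
    #       ('user_id', '5'),
    #       ('__end__', 'review_members:sequence'),
    #   ]
    #   peppercorn.parse(fields)
    #   # -> {'review_members': ['2', '5']}
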
    def _edit_pull_request(self, pull_request):
        """
        Edit title and description
        """
        _ = self.request.translate

        try:
            PullRequestModel().edit(
                pull_request,
                self.request.POST.get('title'),
                self.request.POST.get('description'),
                self.request.POST.get('description_renderer'),
                self._rhodecode_user)
        except ValueError:
            msg = _(u'Cannot update closed pull requests.')
            h.flash(msg, category='error')
            return
        else:
            Session().commit()

            msg = _(u'Pull request title & description updated.')
            h.flash(msg, category='success')
            return

    def _update_commits(self, c, pull_request):
        _ = self.request.translate
        log.debug('pull-request: running update commits actions')

        @retry(exception=Exception, n_tries=3, delay=2)
        def commits_update():
            return PullRequestModel().update_commits(
                pull_request, self._rhodecode_db_user)

        with pull_request.set_state(PullRequest.STATE_UPDATING):
            resp = commits_update()  # retry x3

        if resp.executed:
            if resp.target_changed and resp.source_changed:
                changed = 'target and source repositories'
            elif resp.target_changed and not resp.source_changed:
                changed = 'target repository'
            elif not resp.target_changed and resp.source_changed:
                changed = 'source repository'
            else:
                changed = 'nothing'

            msg = _(u'Pull request updated to "{source_commit_id}" with '
                    u'{count_added} added, {count_removed} removed commits. '
                    u'Source of changes: {change_source}.')
            msg = msg.format(
                source_commit_id=pull_request.source_ref_parts.commit_id,
                count_added=len(resp.changes.added),
                count_removed=len(resp.changes.removed),
                change_source=changed)
            h.flash(msg, category='success')
            channelstream.pr_update_channelstream_push(
                self.request, c.pr_broadcast_channel, self._rhodecode_user, msg)
        else:
            msg = PullRequestModel.UPDATE_STATUS_MESSAGES[resp.reason]
            warning_reasons = [
                UpdateFailureReason.NO_CHANGE,
                UpdateFailureReason.WRONG_REF_TYPE,
            ]
            category = 'warning' if resp.reason in warning_reasons else 'error'
            h.flash(msg, category=category)

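    # Illustrative sketch (an assumption, not the actual rhodecode helper):
    # the behaviour implied by the @retry(exception=Exception, n_tries=3,
    # delay=2) decorator used in _update_commits() above -- call the wrapped
    # function, retry up to `n_tries` times when `exception` is raised,
    # sleep `delay` seconds between attempts, and re-raise after the last
    # try. Shown commented out so it does not shadow the imported helper.
    #
    #   import functools
    #   import time
    #
    #   def retry(exception=Exception, n_tries=3, delay=2):
    #       def decorator(func):
    #           @functools.wraps(func)
    #           def wrapper(*args, **kwargs):
    #               for attempt in range(1, n_tries + 1):
    #                   try:
    #                       return func(*args, **kwargs)
    #                   except exception:
    #                       if attempt == n_tries:
    #                           raise
    #                       time.sleep(delay)
    #           return wrapper
    #       return decorator
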
    def _update_reviewers(self, c, pull_request, review_members, reviewer_rules, role):
        _ = self.request.translate

        get_default_reviewers_data, validate_default_reviewers, validate_observers = \
            PullRequestModel().get_reviewer_functions()

        if role == PullRequestReviewers.ROLE_REVIEWER:
            try:
                reviewers = validate_default_reviewers(review_members, reviewer_rules)
            except ValueError as e:
                log.error('Reviewers Validation: {}'.format(e))
                h.flash(e, category='error')
                return

            old_calculated_status = pull_request.calculated_review_status()
            PullRequestModel().update_reviewers(
                pull_request, reviewers, self._rhodecode_db_user)

            Session().commit()

            msg = _('Pull request reviewers updated.')
            h.flash(msg, category='success')
            channelstream.pr_update_channelstream_push(
                self.request, c.pr_broadcast_channel, self._rhodecode_user, msg)

            # trigger the status-change hook if the reviewer update changed
            # the calculated review status
            calculated_status = pull_request.calculated_review_status()
            if old_calculated_status != calculated_status:
                PullRequestModel().trigger_pull_request_hook(
                    pull_request, self._rhodecode_user, 'review_status_change',
                    data={'status': calculated_status})

        elif role == PullRequestReviewers.ROLE_OBSERVER:
            try:
                observers = validate_observers(review_members, reviewer_rules)
            except ValueError as e:
                log.error('Observers Validation: {}'.format(e))
                h.flash(e, category='error')
                return

            PullRequestModel().update_observers(
                pull_request, observers, self._rhodecode_db_user)

            Session().commit()
            msg = _('Pull request observers updated.')
            h.flash(msg, category='success')
            channelstream.pr_update_channelstream_push(
                self.request, c.pr_broadcast_channel, self._rhodecode_user, msg)

    @LoginRequired()
    @NotAnonymous()
    @HasRepoPermissionAnyDecorator(
        'repository.read', 'repository.write', 'repository.admin')
    @CSRFRequired()
    def pull_request_merge(self):
        """
        Merge will perform a server-side merge of the specified
        pull request, if the pull request is approved and mergeable.
        After successful merging, the pull request is automatically
        closed, with a relevant comment.
        """
        pull_request = PullRequest.get_or_404(
            self.request.matchdict['pull_request_id'])
        _ = self.request.translate

        if pull_request.is_state_changing():
            log.debug('merge: forbidden because pull request is in state %s',
                      pull_request.pull_request_state)
            msg = _(u'Cannot merge pull requests in state other than `{}`. '
                    u'Current state is: `{}`').format(
                PullRequest.STATE_CREATED, pull_request.pull_request_state)
            h.flash(msg, category='error')
            raise HTTPFound(
                h.route_path('pullrequest_show',
                             repo_name=pull_request.target_repo.repo_name,
                             pull_request_id=pull_request.pull_request_id))

        self.load_default_context()

        with pull_request.set_state(PullRequest.STATE_UPDATING):
            check = MergeCheck.validate(
                pull_request, auth_user=self._rhodecode_user,
                translator=self.request.translate)
            merge_possible = not check.failed

            for err_type, error_msg in check.errors:
                h.flash(error_msg, category=err_type)

            if merge_possible:
                log.debug("Pre-conditions checked, trying to merge.")
                extras = vcs_operation_context(
                    self.request.environ, repo_name=pull_request.target_repo.repo_name,
                    username=self._rhodecode_db_user.username, action='push',
                    scm=pull_request.target_repo.repo_type)
                with pull_request.set_state(PullRequest.STATE_UPDATING):
                    self._merge_pull_request(
                        pull_request, self._rhodecode_db_user, extras)
            else:
                log.debug("Pre-conditions failed, NOT merging.")

        raise HTTPFound(
            h.route_path('pullrequest_show',
                         repo_name=pull_request.target_repo.repo_name,
                         pull_request_id=pull_request.pull_request_id))

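    # Illustrative sketch (an assumption about the semantics of
    # PullRequest.set_state() used above): a context manager that flips the
    # pull request into a transient state for the duration of the block and
    # restores the previous (or a given final) state on exit, so concurrent
    # requests can detect the transition via is_state_changing().
    #
    #   import contextlib
    #
    #   @contextlib.contextmanager
    #   def set_state(pr, state, final_state=None):
    #       old_state = pr.pull_request_state
    #       pr.pull_request_state = state
    #       try:
    #           yield
    #       finally:
    #           pr.pull_request_state = final_state or old_state
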
    def _merge_pull_request(self, pull_request, user, extras):
        _ = self.request.translate
        merge_resp = PullRequestModel().merge_repo(pull_request, user, extras=extras)

        if merge_resp.executed:
            log.debug("The merge was successful, closing the pull request.")
            PullRequestModel().close_pull_request(
                pull_request.pull_request_id, user)
            Session().commit()
            msg = _('Pull request was successfully merged and closed.')
            h.flash(msg, category='success')
        else:
            log.debug(
                "The merge was not successful. Merge response: %s", merge_resp)
            msg = merge_resp.merge_status_message
            h.flash(msg, category='error')

    @LoginRequired()
    @NotAnonymous()
    @HasRepoPermissionAnyDecorator(
        'repository.read', 'repository.write', 'repository.admin')
    @CSRFRequired()
    def pull_request_delete(self):
        _ = self.request.translate

        pull_request = PullRequest.get_or_404(
            self.request.matchdict['pull_request_id'])
        self.load_default_context()

        pr_closed = pull_request.is_closed()
        allowed_to_delete = PullRequestModel().check_user_delete(
            pull_request, self._rhodecode_user) and not pr_closed

        # only the owner can delete it!
        if allowed_to_delete:
            PullRequestModel().delete(pull_request, self._rhodecode_user)
            Session().commit()
            h.flash(_('Successfully deleted pull request'),
                    category='success')
            raise HTTPFound(h.route_path('pullrequest_show_all',
                                         repo_name=self.db_repo_name))

        log.warning('user %s tried to delete pull request without access',
                    self._rhodecode_user)
        raise HTTPNotFound()

    def _pull_request_comments_create(self, pull_request, comments):
        _ = self.request.translate
        data = {}
        if not comments:
            return
        pull_request_id = pull_request.pull_request_id

        all_drafts = len([x for x in comments if str2bool(x['is_draft'])]) == len(comments)

        for entry in comments:
            c = self.load_default_context()
            comment_type = entry['comment_type']
            text = entry['text']
            status = entry['status']
            is_draft = str2bool(entry['is_draft'])
            resolves_comment_id = entry['resolves_comment_id']
            close_pull_request = entry['close_pull_request']
            f_path = entry['f_path']
            line_no = entry['line']
            target_elem_id = 'file-{}'.format(h.safeid(h.safe_unicode(f_path)))

            # the logic works as follows: if we submit a close-pull-request
            # comment, use the `close_pull_request_with_comment` function;
            # otherwise handle the regular comment logic

            if close_pull_request:
                # only owner or admin or person with write permissions
                allowed_to_close = PullRequestModel().check_user_update(
                    pull_request, self._rhodecode_user)
                if not allowed_to_close:
                    log.debug('comment: forbidden because not allowed to close '
                              'pull request %s', pull_request_id)
                    raise HTTPForbidden()

                # This also triggers `review_status_change`
                comment, status = PullRequestModel().close_pull_request_with_comment(
                    pull_request, self._rhodecode_user, self.db_repo, message=text,
                    auth_user=self._rhodecode_user)
                Session().flush()
                is_inline = comment.is_inline

                PullRequestModel().trigger_pull_request_hook(
                    pull_request, self._rhodecode_user, 'comment',
                    data={'comment': comment})

            else:
                # regular comment case, could be inline, or one with status.
                # for that one we also check permissions.
                # Additionally, ENSURE that if a draft is somehow sent we are
                # then unable to change the status
                allowed_to_change_status = PullRequestModel().check_user_change_status(
                    pull_request, self._rhodecode_user) and not is_draft

                if status and allowed_to_change_status:
                    message = (_('Status change %(transition_icon)s %(status)s')
                               % {'transition_icon': '>',
                                  'status': ChangesetStatus.get_status_lbl(status)})
                    text = text or message

                comment = CommentsModel().create(
                    text=text,
                    repo=self.db_repo.repo_id,
                    user=self._rhodecode_user.user_id,
                    pull_request=pull_request,
                    f_path=f_path,
                    line_no=line_no,
                    status_change=(ChangesetStatus.get_status_lbl(status)
                                   if status and allowed_to_change_status else None),
                    status_change_type=(status
                                        if status and allowed_to_change_status else None),
                    comment_type=comment_type,
                    is_draft=is_draft,
                    resolves_comment_id=resolves_comment_id,
                    auth_user=self._rhodecode_user,
                    send_email=not is_draft,  # skip notification for draft comments
                )
                is_inline = comment.is_inline

                if allowed_to_change_status:
                    # calculate old status before we change it
                    old_calculated_status = pull_request.calculated_review_status()

                    # get status if set
                    if status:
                        ChangesetStatusModel().set_status(
                            self.db_repo.repo_id,
                            status,
                            self._rhodecode_user.user_id,
                            comment,
                            pull_request=pull_request
                        )

                    Session().flush()
                    # flush/refresh is required so relationships loaded on the
                    # comment are accessible
                    Session().refresh(comment)

                    # skip notifications for drafts
                    if not is_draft:
                        PullRequestModel().trigger_pull_request_hook(
                            pull_request, self._rhodecode_user, 'comment',
                            data={'comment': comment})

                    # we now calculate the status of the pull request, and
                    # based on that calculation we set the commits status
                    calculated_status = pull_request.calculated_review_status()
                    if old_calculated_status != calculated_status:
                        PullRequestModel().trigger_pull_request_hook(
                            pull_request, self._rhodecode_user, 'review_status_change',
                            data={'status': calculated_status})

            comment_id = comment.comment_id
            data[comment_id] = {
                'target_id': target_elem_id
            }
            Session().flush()

            c.co = comment
            c.at_version_num = None
            c.is_new = True
            rendered_comment = render(
                'rhodecode:templates/changeset/changeset_comment_block.mako',
                self._get_template_context(c), self.request)

            data[comment_id].update(comment.get_dict())
            data[comment_id].update({'rendered_text': rendered_comment})

        Session().commit()

        # skip channelstream for draft comments
        if not all_drafts:
            comment_broadcast_channel = channelstream.comment_channel(
                self.db_repo_name, pull_request_obj=pull_request)

            comment_data = data
            posted_comment_type = 'inline' if is_inline else 'general'
            if len(data) == 1:
                msg = _('posted {} new {} comment').format(len(data), posted_comment_type)
            else:
                msg = _('posted {} new {} comments').format(len(data), posted_comment_type)

            channelstream.comment_channelstream_push(
                self.request, comment_broadcast_channel, self._rhodecode_user, msg,
                comment_data=comment_data)

        return data

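    # Illustrative sketch: the shape of one entry in the `comments` list
    # consumed by _pull_request_comments_create() above, mirroring the keys
    # read inside the loop. All values are hypothetical examples.
    #
    #   example_entry = {
    #       'comment_type': 'note',         # or e.g. 'todo'
    #       'text': 'looks good to me',
    #       'status': 'approved',           # None for a plain comment
    #       'is_draft': 'false',            # string form, parsed via str2bool()
    #       'resolves_comment_id': None,
    #       'close_pull_request': None,
    #       'f_path': 'src/app.py',         # inline comments only
    #       'line': 'n10',                  # inline comments only
    #   }
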
    @LoginRequired()
    @NotAnonymous()
    @HasRepoPermissionAnyDecorator(
        'repository.read', 'repository.write', 'repository.admin')
    @CSRFRequired()
    def pull_request_comment_create(self):
        _ = self.request.translate

        pull_request = PullRequest.get_or_404(self.request.matchdict['pull_request_id'])

        if pull_request.is_closed():
            log.debug('comment: forbidden because pull request is closed')
            raise HTTPForbidden()

        allowed_to_comment = PullRequestModel().check_user_comment(
            pull_request, self._rhodecode_user)
        if not allowed_to_comment:
            log.debug('comment: forbidden because pull request is from forbidden repo')
            raise HTTPForbidden()

        comment_data = {
            'comment_type': self.request.POST.get('comment_type'),
            'text': self.request.POST.get('text'),
            'status': self.request.POST.get('changeset_status', None),
            'is_draft': self.request.POST.get('draft'),
            'resolves_comment_id': self.request.POST.get('resolves_comment_id', None),
            'close_pull_request': self.request.POST.get('close_pull_request'),
            'f_path': self.request.POST.get('f_path'),
            'line': self.request.POST.get('line'),
        }
        data = self._pull_request_comments_create(pull_request, [comment_data])

        return data

    @LoginRequired()
    @NotAnonymous()
    @HasRepoPermissionAnyDecorator(
        'repository.read', 'repository.write', 'repository.admin')
    @CSRFRequired()
    def pull_request_comment_delete(self):
        pull_request = PullRequest.get_or_404(
            self.request.matchdict['pull_request_id'])

        comment = ChangesetComment.get_or_404(
            self.request.matchdict['comment_id'])
        comment_id = comment.comment_id

        if comment.immutable:
            # don't allow deleting comments that are immutable
            raise HTTPForbidden()

        if pull_request.is_closed():
            log.debug('comment: forbidden because pull request is closed')
            raise HTTPForbidden()

        if not comment:
            log.debug('Comment with id:%s not found, skipping', comment_id)
            # the comment was probably already deleted by another call
            return True

        if comment.pull_request.is_closed():
            # don't allow deleting comments on closed pull request
            raise HTTPForbidden()

        is_repo_admin = h.HasRepoPermissionAny('repository.admin')(self.db_repo_name)
        super_admin = h.HasPermissionAny('hg.admin')()
        comment_owner = comment.author.user_id == self._rhodecode_user.user_id
        is_repo_comment = comment.repo.repo_name == self.db_repo_name
        comment_repo_admin = is_repo_admin and is_repo_comment

        if comment.draft and not comment_owner:
            # draft comments can only ever be deleted by their owner
            raise HTTPNotFound()

        if super_admin or comment_owner or comment_repo_admin:
            old_calculated_status = comment.pull_request.calculated_review_status()
            CommentsModel().delete(comment=comment, auth_user=self._rhodecode_user)
            Session().commit()
            calculated_status = comment.pull_request.calculated_review_status()
            if old_calculated_status != calculated_status:
                PullRequestModel().trigger_pull_request_hook(
                    comment.pull_request, self._rhodecode_user, 'review_status_change',
                    data={'status': calculated_status})
            return True
        else:
            log.warning('No permissions for user %s to delete comment_id: %s',
                        self._rhodecode_db_user, comment_id)
            raise HTTPNotFound()

    @LoginRequired()
    @NotAnonymous()
    @HasRepoPermissionAnyDecorator(
        'repository.read', 'repository.write', 'repository.admin')
    @CSRFRequired()
    def pull_request_comment_edit(self):
        self.load_default_context()

        pull_request = PullRequest.get_or_404(
            self.request.matchdict['pull_request_id']
        )
        comment = ChangesetComment.get_or_404(
            self.request.matchdict['comment_id']
        )
        comment_id = comment.comment_id

        if comment.immutable:
            # don't allow editing comments that are immutable
            raise HTTPForbidden()

        if pull_request.is_closed():
            log.debug('comment: forbidden because pull request is closed')
            raise HTTPForbidden()

        if comment.pull_request.is_closed():
            # don't allow editing comments on closed pull request
            raise HTTPForbidden()

        is_repo_admin = h.HasRepoPermissionAny('repository.admin')(self.db_repo_name)
        super_admin = h.HasPermissionAny('hg.admin')()
        comment_owner = comment.author.user_id == self._rhodecode_user.user_id
        is_repo_comment = comment.repo.repo_name == self.db_repo_name
        comment_repo_admin = is_repo_admin and is_repo_comment

        if super_admin or comment_owner or comment_repo_admin:
            text = self.request.POST.get('text')
            # default to '' so a missing version field is rejected below
            # instead of raising AttributeError on None.isdigit()
            version = self.request.POST.get('version', '')
            if text == comment.text:
                log.warning(
                    'Comment(PR): '
                    'Trying to create new version '
                    'with the same comment body for comment {}'.format(
                        comment_id,
                    )
                )
                raise HTTPNotFound()

            if version.isdigit():
                version = int(version)
            else:
                log.warning(
                    'Comment(PR): Wrong version type {} {} '
                    'for comment {}'.format(
                        version,
                        type(version),
                        comment_id,
                    )
                )
                raise HTTPNotFound()

            try:
                comment_history = CommentsModel().edit(
                    comment_id=comment_id,
                    text=text,
                    auth_user=self._rhodecode_user,
                    version=version,
                )
            except CommentVersionMismatch:
                raise HTTPConflict()

            if not comment_history:
                raise HTTPNotFound()

            Session().commit()
            if not comment.draft:
                PullRequestModel().trigger_pull_request_hook(
                    pull_request, self._rhodecode_user, 'comment_edit',
                    data={'comment': comment})

            return {
                'comment_history_id': comment_history.comment_history_id,
                'comment_id': comment.comment_id,
                'comment_version': comment_history.version,
                'comment_author_username': comment_history.author.username,
                'comment_author_gravatar': h.gravatar_url(comment_history.author.email, 16),
                'comment_created_on': h.age_component(comment_history.created_on,
                                                      time_is_local=True),
            }
        else:
            log.warning('No permissions for user %s to edit comment_id: %s',
                        self._rhodecode_db_user, comment_id)
            raise HTTPNotFound()
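
    # Illustrative sketch (an assumption about the CommentsModel().edit()
    # contract used above): optimistic locking -- the client submits the
    # comment version it last saw, the edit only applies when it matches the
    # stored version, and a mismatch surfaces as CommentVersionMismatch,
    # which the view maps to HTTP 409 Conflict.
    #
    #   def edit_guarded(stored_version, submitted_version, apply_change):
    #       if submitted_version != stored_version:
    #           raise CommentVersionMismatch()
    #       return apply_change()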
@@ -1,127 +1,125 @@
# -*- coding: utf-8 -*-

# Copyright (C) 2017-2020 RhodeCode GmbH
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License, version 3
# (only), as published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# This program is dual-licensed. If you wish to learn more about the
# RhodeCode Enterprise Edition, including its added features, Support services,
# and proprietary license terms, please see https://rhodecode.com/licenses/

import logging

from pyramid.httpexceptions import HTTPFound, HTTPNotFound

import formencode

from rhodecode.apps._base import RepoAppView
from rhodecode.lib import audit_logger
from rhodecode.lib import helpers as h
from rhodecode.lib.auth import (
    LoginRequired, HasRepoPermissionAnyDecorator, CSRFRequired)
from rhodecode.model.forms import IssueTrackerPatternsForm
from rhodecode.model.meta import Session
from rhodecode.model.settings import SettingsModel

log = logging.getLogger(__name__)


class RepoSettingsIssueTrackersView(RepoAppView):
    def load_default_context(self):
        c = self._get_local_tmpl_context()
        return c

    @LoginRequired()
    @HasRepoPermissionAnyDecorator('repository.admin')
    def repo_issuetracker(self):
        c = self.load_default_context()
        c.active = 'issuetracker'
        c.data = 'data'

        c.settings_model = self.db_repo_patterns
        c.global_patterns = c.settings_model.get_global_settings()
        c.repo_patterns = c.settings_model.get_repo_settings()

        return self._get_template_context(c)

    @LoginRequired()
    @HasRepoPermissionAnyDecorator('repository.admin')
    @CSRFRequired()
    def repo_issuetracker_test(self):
        return h.urlify_commit_message(
            self.request.POST.get('test_text', ''),
            self.db_repo_name)

    @LoginRequired()
    @HasRepoPermissionAnyDecorator('repository.admin')
    @CSRFRequired()
    def repo_issuetracker_delete(self):
        _ = self.request.translate
        uid = self.request.POST.get('uid')
        repo_settings = self.db_repo_patterns
        try:
            repo_settings.delete_entries(uid)
        except Exception:
            h.flash(_('Error occurred during deleting issue tracker entry'),
                    category='error')
            raise HTTPNotFound()

        SettingsModel().invalidate_settings_cache()
        h.flash(_('Removed issue tracker entry.'), category='success')

        return {'deleted': uid}

    def _update_patterns(self, form, repo_settings):
        for uid in form['delete_patterns']:
            repo_settings.delete_entries(uid)

        for pattern_data in form['patterns']:
            for setting_key, pattern, type_ in pattern_data:
                sett = repo_settings.create_or_update_setting(
                    setting_key, pattern.strip(), type_)
                Session().add(sett)

        Session().commit()

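    # Illustrative sketch: the structure _update_patterns() above expects
    # from the validated form data. The uids, setting keys, and pattern
    # values are hypothetical examples.
    #
    #   form = {
    #       'delete_patterns': ['abc123'],
    #       'patterns': [
    #           [('issuetracker_pat_abc123', r'#(?P<issue_id>\d+)', 'unicode'),
    #            ('issuetracker_url_abc123',
    #             'https://tracker.example.com/issue/${issue_id}', 'unicode')],
    #       ],
    #   }
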
    @LoginRequired()
    @HasRepoPermissionAnyDecorator('repository.admin')
    @CSRFRequired()
    def repo_issuetracker_update(self):
        _ = self.request.translate
        # save inheritance setting
        repo_settings = self.db_repo_patterns
        inherited = (
            self.request.POST.get('inherit_global_issuetracker') == "inherited")
        repo_settings.inherit_global_settings = inherited
        Session().commit()

        try:
            form = IssueTrackerPatternsForm(self.request.translate)().to_python(self.request.POST)
        except formencode.Invalid as errors:
            log.exception('Failed to add new pattern')
            h.flash(_('Invalid issue tracker pattern: {}').format(errors),
                    category='error')
            raise HTTPFound(
                h.route_path('edit_repo_issuetracker',
                             repo_name=self.db_repo_name))

        if form:
            self._update_patterns(form, repo_settings)

        h.flash(_('Updated issue tracker entries'), category='success')
        raise HTTPFound(
            h.route_path('edit_repo_issuetracker', repo_name=self.db_repo_name))

@@ -1,290 +1,290 @@
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2
2
3 # Copyright (C) 2011-2020 RhodeCode GmbH
3 # Copyright (C) 2011-2020 RhodeCode GmbH
4 #
4 #
5 # This program is free software: you can redistribute it and/or modify
5 # This program is free software: you can redistribute it and/or modify
6 # it under the terms of the GNU Affero General Public License, version 3
6 # it under the terms of the GNU Affero General Public License, version 3
7 # (only), as published by the Free Software Foundation.
7 # (only), as published by the Free Software Foundation.
8 #
8 #
9 # This program is distributed in the hope that it will be useful,
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 # GNU General Public License for more details.
12 # GNU General Public License for more details.
13 #
13 #
14 # You should have received a copy of the GNU Affero General Public License
14 # You should have received a copy of the GNU Affero General Public License
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 #
16 #
17 # This program is dual-licensed. If you wish to learn more about the
17 # This program is dual-licensed. If you wish to learn more about the
18 # RhodeCode Enterprise Edition, including its added features, Support services,
18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20
20
21 import logging
21 import logging
22 import string
22 import string
23 import time
23 import time
24
24
25 import rhodecode
25 import rhodecode
26
26
27
27
28
28
29 from rhodecode.lib.view_utils import get_format_ref_id
29 from rhodecode.lib.view_utils import get_format_ref_id
30 from rhodecode.apps._base import RepoAppView
30 from rhodecode.apps._base import RepoAppView
31 from rhodecode.config.conf import (LANGUAGES_EXTENSIONS_MAP)
31 from rhodecode.config.conf import (LANGUAGES_EXTENSIONS_MAP)
32 from rhodecode.lib import helpers as h, rc_cache
32 from rhodecode.lib import helpers as h, rc_cache
33 from rhodecode.lib.utils2 import safe_str, safe_int
33 from rhodecode.lib.utils2 import safe_str, safe_int
34 from rhodecode.lib.auth import LoginRequired, HasRepoPermissionAnyDecorator
34 from rhodecode.lib.auth import LoginRequired, HasRepoPermissionAnyDecorator
35 from rhodecode.lib.ext_json import json
35 from rhodecode.lib.ext_json import json
36 from rhodecode.lib.vcs.backends.base import EmptyCommit
36 from rhodecode.lib.vcs.backends.base import EmptyCommit
37 from rhodecode.lib.vcs.exceptions import (
37 from rhodecode.lib.vcs.exceptions import (
38 CommitError, EmptyRepositoryError, CommitDoesNotExistError)
38 CommitError, EmptyRepositoryError, CommitDoesNotExistError)
39 from rhodecode.model.db import Statistics, CacheKey, User
39 from rhodecode.model.db import Statistics, CacheKey, User
40 from rhodecode.model.meta import Session
40 from rhodecode.model.meta import Session
41 from rhodecode.model.scm import ScmModel
41 from rhodecode.model.scm import ScmModel
42
42
43 log = logging.getLogger(__name__)
43 log = logging.getLogger(__name__)
44
44
45
45
46 class RepoSummaryView(RepoAppView):
46 class RepoSummaryView(RepoAppView):
47
47
48 def load_default_context(self):
48 def load_default_context(self):
49 c = self._get_local_tmpl_context(include_app_defaults=True)
49 c = self._get_local_tmpl_context(include_app_defaults=True)
50 c.rhodecode_repo = None
50 c.rhodecode_repo = None
51 if not c.repository_requirements_missing:
51 if not c.repository_requirements_missing:
52 c.rhodecode_repo = self.rhodecode_vcs_repo
52 c.rhodecode_repo = self.rhodecode_vcs_repo
53 return c
53 return c
54
54
55 def _load_commits_context(self, c):
55 def _load_commits_context(self, c):
56 p = safe_int(self.request.GET.get('page'), 1)
56 p = safe_int(self.request.GET.get('page'), 1)
57 size = safe_int(self.request.GET.get('size'), 10)
57 size = safe_int(self.request.GET.get('size'), 10)
58
58
59 def url_generator(page_num):
59 def url_generator(page_num):
60 query_params = {
60 query_params = {
61 'page': page_num,
61 'page': page_num,
62 'size': size
62 'size': size
63 }
63 }
64 return h.route_path(
64 return h.route_path(
65 'repo_summary_commits',
65 'repo_summary_commits',
66 repo_name=c.rhodecode_db_repo.repo_name, _query=query_params)
66 repo_name=c.rhodecode_db_repo.repo_name, _query=query_params)
67
67
68 pre_load = self.get_commit_preload_attrs()
68 pre_load = self.get_commit_preload_attrs()
69
69
70 try:
70 try:
71 collection = self.rhodecode_vcs_repo.get_commits(
71 collection = self.rhodecode_vcs_repo.get_commits(
72 pre_load=pre_load, translate_tags=False)
72 pre_load=pre_load, translate_tags=False)
73 except EmptyRepositoryError:
73 except EmptyRepositoryError:
74 collection = self.rhodecode_vcs_repo
74 collection = self.rhodecode_vcs_repo
75
75
76 c.repo_commits = h.RepoPage(
76 c.repo_commits = h.RepoPage(
77 collection, page=p, items_per_page=size, url_maker=url_generator)
77 collection, page=p, items_per_page=size, url_maker=url_generator)
78 page_ids = [x.raw_id for x in c.repo_commits]
78 page_ids = [x.raw_id for x in c.repo_commits]
79 c.comments = self.db_repo.get_comments(page_ids)
79 c.comments = self.db_repo.get_comments(page_ids)
80 c.statuses = self.db_repo.statuses(page_ids)
80 c.statuses = self.db_repo.statuses(page_ids)
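        # Illustration (hypothetical values): url_generator(2) with size=10
        # renders something like '/<repo_name>/summary-commits?page=2&size=10';
        # the exact path is whatever the 'repo_summary_commits' route resolves to.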
81
81
82 def _prepare_and_set_clone_url(self, c):
82 def _prepare_and_set_clone_url(self, c):
83 username = ''
83 username = ''
84 if self._rhodecode_user.username != User.DEFAULT_USER:
84 if self._rhodecode_user.username != User.DEFAULT_USER:
85 username = safe_str(self._rhodecode_user.username)
85 username = safe_str(self._rhodecode_user.username)
86
86
87 _def_clone_uri = c.clone_uri_tmpl
87 _def_clone_uri = c.clone_uri_tmpl
88 _def_clone_uri_id = c.clone_uri_id_tmpl
88 _def_clone_uri_id = c.clone_uri_id_tmpl
89 _def_clone_uri_ssh = c.clone_uri_ssh_tmpl
89 _def_clone_uri_ssh = c.clone_uri_ssh_tmpl
90
90
91 c.clone_repo_url = self.db_repo.clone_url(
91 c.clone_repo_url = self.db_repo.clone_url(
92 user=username, uri_tmpl=_def_clone_uri)
92 user=username, uri_tmpl=_def_clone_uri)
93 c.clone_repo_url_id = self.db_repo.clone_url(
93 c.clone_repo_url_id = self.db_repo.clone_url(
94 user=username, uri_tmpl=_def_clone_uri_id)
94 user=username, uri_tmpl=_def_clone_uri_id)
95 c.clone_repo_url_ssh = self.db_repo.clone_url(
95 c.clone_repo_url_ssh = self.db_repo.clone_url(
96 uri_tmpl=_def_clone_uri_ssh, ssh=True)
96 uri_tmpl=_def_clone_uri_ssh, ssh=True)
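        # For illustration only: the uri_tmpl values are patterns along the
        # lines of '{scheme}://{user}@{netloc}/{repo}' (the exact placeholder
        # names depend on instance configuration), so for user 'bob' the
        # rendered clone URL could look like 'https://bob@example.com/my-repo'.
        # The sample values here are made up.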
97
97
98 @LoginRequired()
98 @LoginRequired()
99 @HasRepoPermissionAnyDecorator(
99 @HasRepoPermissionAnyDecorator(
100 'repository.read', 'repository.write', 'repository.admin')
100 'repository.read', 'repository.write', 'repository.admin')
101 def summary_commits(self):
101 def summary_commits(self):
102 c = self.load_default_context()
102 c = self.load_default_context()
103 self._prepare_and_set_clone_url(c)
103 self._prepare_and_set_clone_url(c)
104 self._load_commits_context(c)
104 self._load_commits_context(c)
105 return self._get_template_context(c)
105 return self._get_template_context(c)
106
106
107 @LoginRequired()
107 @LoginRequired()
108 @HasRepoPermissionAnyDecorator(
108 @HasRepoPermissionAnyDecorator(
109 'repository.read', 'repository.write', 'repository.admin')
109 'repository.read', 'repository.write', 'repository.admin')
110 def summary(self):
110 def summary(self):
111 c = self.load_default_context()
111 c = self.load_default_context()
112
112
113 # Prepare the clone URL
113 # Prepare the clone URL
114 self._prepare_and_set_clone_url(c)
114 self._prepare_and_set_clone_url(c)
115
115
116 # If enabled, get statistics data
116 # If enabled, get statistics data
117 c.show_stats = bool(self.db_repo.enable_statistics)
117 c.show_stats = bool(self.db_repo.enable_statistics)
118
118
119 stats = Session().query(Statistics) \
119 stats = Session().query(Statistics) \
120 .filter(Statistics.repository == self.db_repo) \
120 .filter(Statistics.repository == self.db_repo) \
121 .scalar()
121 .scalar()
122
122
123 c.stats_percentage = 0
123 c.stats_percentage = 0
124
124
125 if stats and stats.languages:
125 if stats and stats.languages:
126 c.no_data = False is self.db_repo.enable_statistics
126 c.no_data = not self.db_repo.enable_statistics
127 lang_stats_d = json.loads(stats.languages)
127 lang_stats_d = json.loads(stats.languages)
128
128
129 # Sort first by decreasing count and second by the file extension,
129 # Sort first by decreasing count and second by the file extension,
130 # so we have a consistent output.
130 # so we have a consistent output.
131 lang_stats_items = sorted(lang_stats_d.items(),
131 lang_stats_items = sorted(lang_stats_d.items(),
132 key=lambda k: (-k[1], k[0]))[:10]
132 key=lambda k: (-k[1], k[0]))[:10]
133 lang_stats = [(x, {"count": y,
133 lang_stats = [(x, {"count": y,
134 "desc": LANGUAGES_EXTENSIONS_MAP.get(x)})
134 "desc": LANGUAGES_EXTENSIONS_MAP.get(x)})
135 for x, y in lang_stats_items]
135 for x, y in lang_stats_items]
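            # Worked example of the key above: given {'py': 10, 'js': 10, 'rs': 3},
            # key=lambda k: (-k[1], k[0]) sorts to [('js', 10), ('py', 10), ('rs', 3)]:
            # descending by count, with ties broken alphabetically by extension.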
136
136
137 c.trending_languages = json.dumps(lang_stats)
137 c.trending_languages = json.dumps(lang_stats)
138 else:
138 else:
139 c.no_data = True
139 c.no_data = True
140 c.trending_languages = json.dumps({})
140 c.trending_languages = json.dumps({})
141
141
142 scm_model = ScmModel()
142 scm_model = ScmModel()
143 c.enable_downloads = self.db_repo.enable_downloads
143 c.enable_downloads = self.db_repo.enable_downloads
144 c.repository_followers = scm_model.get_followers(self.db_repo)
144 c.repository_followers = scm_model.get_followers(self.db_repo)
145 c.repository_forks = scm_model.get_forks(self.db_repo)
145 c.repository_forks = scm_model.get_forks(self.db_repo)
146
146
147 # the first interaction with the VCS instance happens after this point...
147 # the first interaction with the VCS instance happens after this point...
148 if c.repository_requirements_missing:
148 if c.repository_requirements_missing:
149 self.request.override_renderer = \
149 self.request.override_renderer = \
150 'rhodecode:templates/summary/missing_requirements.mako'
150 'rhodecode:templates/summary/missing_requirements.mako'
151 return self._get_template_context(c)
151 return self._get_template_context(c)
152
152
153 c.readme_data, c.readme_file = \
153 c.readme_data, c.readme_file = \
154 self._get_readme_data(self.db_repo, c.visual.default_renderer)
154 self._get_readme_data(self.db_repo, c.visual.default_renderer)
155
155
156 # loads the summary commits template context
156 # loads the summary commits template context
157 self._load_commits_context(c)
157 self._load_commits_context(c)
158
158
159 return self._get_template_context(c)
159 return self._get_template_context(c)
160
160
161 @LoginRequired()
161 @LoginRequired()
162 @HasRepoPermissionAnyDecorator(
162 @HasRepoPermissionAnyDecorator(
163 'repository.read', 'repository.write', 'repository.admin')
163 'repository.read', 'repository.write', 'repository.admin')
164 def repo_stats(self):
164 def repo_stats(self):
165 show_stats = bool(self.db_repo.enable_statistics)
165 show_stats = bool(self.db_repo.enable_statistics)
166 repo_id = self.db_repo.repo_id
166 repo_id = self.db_repo.repo_id
167
167
168 landing_commit = self.db_repo.get_landing_commit()
168 landing_commit = self.db_repo.get_landing_commit()
169 if isinstance(landing_commit, EmptyCommit):
169 if isinstance(landing_commit, EmptyCommit):
170 return {'size': 0, 'code_stats': {}}
170 return {'size': 0, 'code_stats': {}}
171
171
172 cache_seconds = safe_int(rhodecode.CONFIG.get('rc_cache.cache_repo.expiration_time'))
172 cache_seconds = safe_int(rhodecode.CONFIG.get('rc_cache.cache_repo.expiration_time'))
173 cache_on = cache_seconds > 0
173 cache_on = cache_seconds > 0
174
174
175 log.debug(
175 log.debug(
176 'Computing REPO STATS for repo_id %s commit_id `%s` '
176 'Computing REPO STATS for repo_id %s commit_id `%s` '
177 'with caching: %s[TTL: %ss]' % (
177 'with caching: %s [TTL: %ss]',
178 repo_id, landing_commit, cache_on, cache_seconds or 0))
178 repo_id, landing_commit, cache_on, cache_seconds or 0)
179
179
180 cache_namespace_uid = 'cache_repo.{}'.format(repo_id)
180 cache_namespace_uid = 'cache_repo.{}'.format(repo_id)
181 region = rc_cache.get_or_create_region('cache_repo', cache_namespace_uid)
181 region = rc_cache.get_or_create_region('cache_repo', cache_namespace_uid)
182
182
183 @region.conditional_cache_on_arguments(namespace=cache_namespace_uid,
183 @region.conditional_cache_on_arguments(namespace=cache_namespace_uid,
184 condition=cache_on)
184 condition=cache_on)
185 def compute_stats(repo_id, commit_id, _show_stats):
185 def compute_stats(repo_id, commit_id, _show_stats):
186 code_stats = {}
186 code_stats = {}
187 size = 0
187 size = 0
188 try:
188 try:
189 commit = self.db_repo.get_commit(commit_id)
189 commit = self.db_repo.get_commit(commit_id)
190
190
191 for node in commit.get_filenodes_generator():
191 for node in commit.get_filenodes_generator():
192 size += node.size
192 size += node.size
193 if not _show_stats:
193 if not _show_stats:
194 continue
194 continue
195 ext = string.lower(node.extension)
195 ext = node.extension.lower()
196 ext_info = LANGUAGES_EXTENSIONS_MAP.get(ext)
196 ext_info = LANGUAGES_EXTENSIONS_MAP.get(ext)
197 if ext_info:
197 if ext_info:
198 if ext in code_stats:
198 if ext in code_stats:
199 code_stats[ext]['count'] += 1
199 code_stats[ext]['count'] += 1
200 else:
200 else:
201 code_stats[ext] = {"count": 1, "desc": ext_info}
201 code_stats[ext] = {"count": 1, "desc": ext_info}
202 except (EmptyRepositoryError, CommitDoesNotExistError):
202 except (EmptyRepositoryError, CommitDoesNotExistError):
203 pass
203 pass
204 return {'size': h.format_byte_size_binary(size),
204 return {'size': h.format_byte_size_binary(size),
205 'code_stats': code_stats}
205 'code_stats': code_stats}
206
206
207 stats = compute_stats(self.db_repo.repo_id, landing_commit.raw_id, show_stats)
207 stats = compute_stats(self.db_repo.repo_id, landing_commit.raw_id, show_stats)
208 return stats
208 return stats
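    # The caching recipe above, reduced to a skeleton (sketch only, reusing
    # names from this file; `do_work` is hypothetical):
    #
    #   uid = 'cache_repo.{}'.format(repo_id)
    #   region = rc_cache.get_or_create_region('cache_repo', uid)
    #
    #   @region.conditional_cache_on_arguments(namespace=uid, condition=cache_on)
    #   def do_work(repo_id, commit_id):
    #       ...  # recomputed only on a cache miss or after the TTL expires
    #
    # The decorated function's arguments form the cache key, which is why the
    # repo id and the landing commit id are passed in explicitly above rather
    # than closed over.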
209
209
210 @LoginRequired()
210 @LoginRequired()
211 @HasRepoPermissionAnyDecorator(
211 @HasRepoPermissionAnyDecorator(
212 'repository.read', 'repository.write', 'repository.admin')
212 'repository.read', 'repository.write', 'repository.admin')
213 def repo_refs_data(self):
213 def repo_refs_data(self):
214 _ = self.request.translate
214 _ = self.request.translate
215 self.load_default_context()
215 self.load_default_context()
216
216
217 repo = self.rhodecode_vcs_repo
217 repo = self.rhodecode_vcs_repo
218 refs_to_create = [
218 refs_to_create = [
219 (_("Branch"), repo.branches, 'branch'),
219 (_("Branch"), repo.branches, 'branch'),
220 (_("Tag"), repo.tags, 'tag'),
220 (_("Tag"), repo.tags, 'tag'),
221 (_("Bookmark"), repo.bookmarks, 'book'),
221 (_("Bookmark"), repo.bookmarks, 'book'),
222 ]
222 ]
223 res = self._create_reference_data(repo, self.db_repo_name, refs_to_create)
223 res = self._create_reference_data(repo, self.db_repo_name, refs_to_create)
224 data = {
224 data = {
225 'more': False,
225 'more': False,
226 'results': res
226 'results': res
227 }
227 }
228 return data
228 return data
229
229
230 @LoginRequired()
230 @LoginRequired()
231 @HasRepoPermissionAnyDecorator(
231 @HasRepoPermissionAnyDecorator(
232 'repository.read', 'repository.write', 'repository.admin')
232 'repository.read', 'repository.write', 'repository.admin')
233 def repo_refs_changelog_data(self):
233 def repo_refs_changelog_data(self):
234 _ = self.request.translate
234 _ = self.request.translate
235 self.load_default_context()
235 self.load_default_context()
236
236
237 repo = self.rhodecode_vcs_repo
237 repo = self.rhodecode_vcs_repo
238
238
239 refs_to_create = [
239 refs_to_create = [
240 (_("Branches"), repo.branches, 'branch'),
240 (_("Branches"), repo.branches, 'branch'),
241 (_("Closed branches"), repo.branches_closed, 'branch_closed'),
241 (_("Closed branches"), repo.branches_closed, 'branch_closed'),
242 # TODO: enable when vcs can handle bookmarks filters
242 # TODO: enable when vcs can handle bookmarks filters
243 # (_("Bookmarks"), repo.bookmarks, "book"),
243 # (_("Bookmarks"), repo.bookmarks, "book"),
244 ]
244 ]
245 res = self._create_reference_data(
245 res = self._create_reference_data(
246 repo, self.db_repo_name, refs_to_create)
246 repo, self.db_repo_name, refs_to_create)
247 data = {
247 data = {
248 'more': False,
248 'more': False,
249 'results': res
249 'results': res
250 }
250 }
251 return data
251 return data
252
252
253 def _create_reference_data(self, repo, full_repo_name, refs_to_create):
253 def _create_reference_data(self, repo, full_repo_name, refs_to_create):
254 format_ref_id = get_format_ref_id(repo)
254 format_ref_id = get_format_ref_id(repo)
255
255
256 result = []
256 result = []
257 for title, refs, ref_type in refs_to_create:
257 for title, refs, ref_type in refs_to_create:
258 if refs:
258 if refs:
259 result.append({
259 result.append({
260 'text': title,
260 'text': title,
261 'children': self._create_reference_items(
261 'children': self._create_reference_items(
262 repo, full_repo_name, refs, ref_type,
262 repo, full_repo_name, refs, ref_type,
263 format_ref_id),
263 format_ref_id),
264 })
264 })
265 return result
265 return result
266
266
267 def _create_reference_items(self, repo, full_repo_name, refs, ref_type, format_ref_id):
267 def _create_reference_items(self, repo, full_repo_name, refs, ref_type, format_ref_id):
268 result = []
268 result = []
269 is_svn = h.is_svn(repo)
269 is_svn = h.is_svn(repo)
270 for ref_name, raw_id in refs.items():
270 for ref_name, raw_id in refs.items():
271 files_url = self._create_files_url(
271 files_url = self._create_files_url(
272 repo, full_repo_name, ref_name, raw_id, is_svn)
272 repo, full_repo_name, ref_name, raw_id, is_svn)
273 result.append({
273 result.append({
274 'text': ref_name,
274 'text': ref_name,
275 'id': format_ref_id(ref_name, raw_id),
275 'id': format_ref_id(ref_name, raw_id),
276 'raw_id': raw_id,
276 'raw_id': raw_id,
277 'type': ref_type,
277 'type': ref_type,
278 'files_url': files_url,
278 'files_url': files_url,
279 'idx': 0,
279 'idx': 0,
280 })
280 })
281 return result
281 return result
282
282
283 def _create_files_url(self, repo, full_repo_name, ref_name, raw_id, is_svn):
283 def _create_files_url(self, repo, full_repo_name, ref_name, raw_id, is_svn):
284 use_commit_id = '/' in ref_name or is_svn
284 use_commit_id = '/' in ref_name or is_svn
285 return h.route_path(
285 return h.route_path(
286 'repo_files',
286 'repo_files',
287 repo_name=full_repo_name,
287 repo_name=full_repo_name,
288 f_path=ref_name if is_svn else '',
288 f_path=ref_name if is_svn else '',
289 commit_id=raw_id if use_commit_id else ref_name,
289 commit_id=raw_id if use_commit_id else ref_name,
290 _query=dict(at=ref_name))
290 _query=dict(at=ref_name))
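
# Standalone sketch of the files-URL decision above, using made-up ref names
# (plain Python, no RhodeCode imports needed):
def _pick_commit_or_ref(ref_name, raw_id, is_svn):
    # refs containing '/' (e.g. 'feature/x') and all SVN refs are addressed
    # by raw commit id, since such names would be ambiguous in a URL path
    use_commit_id = '/' in ref_name or is_svn
    return raw_id if use_commit_id else ref_name

assert _pick_commit_or_ref('master', 'abc123', is_svn=False) == 'master'
assert _pick_commit_or_ref('feature/x', 'abc123', is_svn=False) == 'abc123'
assert _pick_commit_or_ref('trunk', 'abc123', is_svn=True) == 'abc123'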
@@ -1,172 +1,172 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2
2
3 # Copyright (C) 2012-2020 RhodeCode GmbH
3 # Copyright (C) 2012-2020 RhodeCode GmbH
4 #
4 #
5 # This program is free software: you can redistribute it and/or modify
5 # This program is free software: you can redistribute it and/or modify
6 # it under the terms of the GNU Affero General Public License, version 3
6 # it under the terms of the GNU Affero General Public License, version 3
7 # (only), as published by the Free Software Foundation.
7 # (only), as published by the Free Software Foundation.
8 #
8 #
9 # This program is distributed in the hope that it will be useful,
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 # GNU General Public License for more details.
12 # GNU General Public License for more details.
13 #
13 #
14 # You should have received a copy of the GNU Affero General Public License
14 # You should have received a copy of the GNU Affero General Public License
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 #
16 #
17 # This program is dual-licensed. If you wish to learn more about the
17 # This program is dual-licensed. If you wish to learn more about the
18 # RhodeCode Enterprise Edition, including its added features, Support services,
18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20
20
21 """
21 """
22 RhodeCode authentication library for PAM
22 RhodeCode authentication library for PAM
23 """
23 """
24
24
25 import colander
25 import colander
26 import grp
26 import grp
27 import logging
27 import logging
28 import pam
28 import pam
29 import pwd
29 import pwd
30 import re
30 import re
31 import socket
31 import socket
32
32
33 from rhodecode.translation import _
33 from rhodecode.translation import _
34 from rhodecode.authentication.base import (
34 from rhodecode.authentication.base import (
35 RhodeCodeExternalAuthPlugin, hybrid_property)
35 RhodeCodeExternalAuthPlugin, hybrid_property)
36 from rhodecode.authentication.schema import AuthnPluginSettingsSchemaBase
36 from rhodecode.authentication.schema import AuthnPluginSettingsSchemaBase
37 from rhodecode.authentication.routes import AuthnPluginResourceBase
37 from rhodecode.authentication.routes import AuthnPluginResourceBase
38 from rhodecode.lib.colander_utils import strip_whitespace
38 from rhodecode.lib.colander_utils import strip_whitespace
39
39
40 log = logging.getLogger(__name__)
40 log = logging.getLogger(__name__)
41
41
42
42
43 def plugin_factory(plugin_id, *args, **kwargs):
43 def plugin_factory(plugin_id, *args, **kwargs):
44 """
44 """
45 Factory function that is called during plugin discovery.
45 Factory function that is called during plugin discovery.
46 It returns the plugin instance.
46 It returns the plugin instance.
47 """
47 """
48 plugin = RhodeCodeAuthPlugin(plugin_id)
48 plugin = RhodeCodeAuthPlugin(plugin_id)
49 return plugin
49 return plugin
50
50
51
51
52 class PamAuthnResource(AuthnPluginResourceBase):
52 class PamAuthnResource(AuthnPluginResourceBase):
53 pass
53 pass
54
54
55
55
56 class PamSettingsSchema(AuthnPluginSettingsSchemaBase):
56 class PamSettingsSchema(AuthnPluginSettingsSchemaBase):
57 service = colander.SchemaNode(
57 service = colander.SchemaNode(
58 colander.String(),
58 colander.String(),
59 default='login',
59 default='login',
60 description=_('PAM service name to use for authentication.'),
60 description=_('PAM service name to use for authentication.'),
61 preparer=strip_whitespace,
61 preparer=strip_whitespace,
62 title=_('PAM service name'),
62 title=_('PAM service name'),
63 widget='string')
63 widget='string')
64 gecos = colander.SchemaNode(
64 gecos = colander.SchemaNode(
65 colander.String(),
65 colander.String(),
66 default='(?P<last_name>.+),\s*(?P<first_name>\w+)',
66 default=r'(?P<last_name>.+),\s*(?P<first_name>\w+)',
67 description=_('Regular expression for extracting user name/email etc. '
67 description=_('Regular expression for extracting user name/email etc. '
68 'from Unix userinfo.'),
68 'from Unix userinfo.'),
69 preparer=strip_whitespace,
69 preparer=strip_whitespace,
70 title=_('Gecos Regex'),
70 title=_('Gecos Regex'),
71 widget='string')
71 widget='string')
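
# A quick illustration of what the default gecos regex extracts (the sample
# GECOS value below is made up):
#
#   import re
#   m = re.search(r'(?P<last_name>.+),\s*(?P<first_name>\w+)', 'Doe, John')
#   if m:
#       print(m.group('first_name'), m.group('last_name'))  # John Doe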
72
72
73
73
74 class RhodeCodeAuthPlugin(RhodeCodeExternalAuthPlugin):
74 class RhodeCodeAuthPlugin(RhodeCodeExternalAuthPlugin):
75 uid = 'pam'
75 uid = 'pam'
76 # PAM authentication can be slow. Repository operations involve a lot of
76 # PAM authentication can be slow. Repository operations involve a lot of
77 # auth calls. A little caching helps speed up push/pull operations significantly.
77 # auth calls. A little caching helps speed up push/pull operations significantly.
78 AUTH_CACHE_TTL = 4
78 AUTH_CACHE_TTL = 4
79
79
80 def includeme(self, config):
80 def includeme(self, config):
81 config.add_authn_plugin(self)
81 config.add_authn_plugin(self)
82 config.add_authn_resource(self.get_id(), PamAuthnResource(self))
82 config.add_authn_resource(self.get_id(), PamAuthnResource(self))
83 config.add_view(
83 config.add_view(
84 'rhodecode.authentication.views.AuthnPluginViewBase',
84 'rhodecode.authentication.views.AuthnPluginViewBase',
85 attr='settings_get',
85 attr='settings_get',
86 renderer='rhodecode:templates/admin/auth/plugin_settings.mako',
86 renderer='rhodecode:templates/admin/auth/plugin_settings.mako',
87 request_method='GET',
87 request_method='GET',
88 route_name='auth_home',
88 route_name='auth_home',
89 context=PamAuthnResource)
89 context=PamAuthnResource)
90 config.add_view(
90 config.add_view(
91 'rhodecode.authentication.views.AuthnPluginViewBase',
91 'rhodecode.authentication.views.AuthnPluginViewBase',
92 attr='settings_post',
92 attr='settings_post',
93 renderer='rhodecode:templates/admin/auth/plugin_settings.mako',
93 renderer='rhodecode:templates/admin/auth/plugin_settings.mako',
94 request_method='POST',
94 request_method='POST',
95 route_name='auth_home',
95 route_name='auth_home',
96 context=PamAuthnResource)
96 context=PamAuthnResource)
97
97
98 def get_display_name(self, load_from_settings=False):
98 def get_display_name(self, load_from_settings=False):
99 return _('PAM')
99 return _('PAM')
100
100
101 @classmethod
101 @classmethod
102 def docs(cls):
102 def docs(cls):
103 return "https://docs.rhodecode.com/RhodeCode-Enterprise/auth/auth-pam.html"
103 return "https://docs.rhodecode.com/RhodeCode-Enterprise/auth/auth-pam.html"
104
104
105 @hybrid_property
105 @hybrid_property
106 def name(self):
106 def name(self):
107 return u"pam"
107 return "pam"
108
108
109 def get_settings_schema(self):
109 def get_settings_schema(self):
110 return PamSettingsSchema()
110 return PamSettingsSchema()
111
111
112 def use_fake_password(self):
112 def use_fake_password(self):
113 return True
113 return True
114
114
115 def auth(self, userobj, username, password, settings, **kwargs):
115 def auth(self, userobj, username, password, settings, **kwargs):
116 if not username or not password:
116 if not username or not password:
117 log.debug('Empty username or password skipping...')
117 log.debug('Empty username or password, skipping...')
118 return None
118 return None
119 _pam = pam.pam()
119 _pam = pam.pam()
120 auth_result = _pam.authenticate(username, password, settings["service"])
120 auth_result = _pam.authenticate(username, password, settings["service"])
121
121
122 if not auth_result:
122 if not auth_result:
123 log.error("PAM was unable to authenticate user: %s", username)
123 log.error("PAM was unable to authenticate user: %s", username)
124 return None
124 return None
125
125
126 log.debug('Got PAM response %s', auth_result)
126 log.debug('Got PAM response %s', auth_result)
127
127
128 # old attrs fetched from RhodeCode database
128 # old attrs fetched from RhodeCode database
129 default_email = "%s@%s" % (username, socket.gethostname())
129 default_email = "%s@%s" % (username, socket.gethostname())
130 admin = getattr(userobj, 'admin', False)
130 admin = getattr(userobj, 'admin', False)
131 active = getattr(userobj, 'active', True)
131 active = getattr(userobj, 'active', True)
132 email = getattr(userobj, 'email', '') or default_email
132 email = getattr(userobj, 'email', '') or default_email
133 username = getattr(userobj, 'username', username)
133 username = getattr(userobj, 'username', username)
134 firstname = getattr(userobj, 'firstname', '')
134 firstname = getattr(userobj, 'firstname', '')
135 lastname = getattr(userobj, 'lastname', '')
135 lastname = getattr(userobj, 'lastname', '')
136 extern_type = getattr(userobj, 'extern_type', '')
136 extern_type = getattr(userobj, 'extern_type', '')
137
137
138 user_attrs = {
138 user_attrs = {
139 'username': username,
139 'username': username,
140 'firstname': firstname,
140 'firstname': firstname,
141 'lastname': lastname,
141 'lastname': lastname,
142 'groups': [g.gr_name for g in grp.getgrall()
142 'groups': [g.gr_name for g in grp.getgrall()
143 if username in g.gr_mem],
143 if username in g.gr_mem],
144 'user_group_sync': True,
144 'user_group_sync': True,
145 'email': email,
145 'email': email,
146 'admin': admin,
146 'admin': admin,
147 'active': active,
147 'active': active,
148 'active_from_extern': None,
148 'active_from_extern': None,
149 'extern_name': username,
149 'extern_name': username,
150 'extern_type': extern_type,
150 'extern_type': extern_type,
151 }
151 }
152
152
153 try:
153 try:
154 user_data = pwd.getpwnam(username)
154 user_data = pwd.getpwnam(username)
155 regex = settings["gecos"]
155 regex = settings["gecos"]
156 match = re.search(regex, user_data.pw_gecos)
156 match = re.search(regex, user_data.pw_gecos)
157 if match:
157 if match:
158 user_attrs["firstname"] = match.group('first_name')
158 user_attrs["firstname"] = match.group('first_name')
159 user_attrs["lastname"] = match.group('last_name')
159 user_attrs["lastname"] = match.group('last_name')
160 except Exception:
160 except Exception:
161 log.warning("Cannot extract additional info for PAM user")
161 log.warning("Cannot extract additional info for PAM user")
162 pass
162 pass
163
163
164 log.debug("pamuser: %s", user_attrs)
164 log.debug("pamuser: %s", user_attrs)
165 log.info('user `%s` authenticated correctly', user_attrs['username'],
165 log.info('user `%s` authenticated correctly', user_attrs['username'],
166 extra={"action": "user_auth_ok", "auth_module": "auth_pam", "username": user_attrs["username"]})
166 extra={"action": "user_auth_ok", "auth_module": "auth_pam", "username": user_attrs["username"]})
167 return user_attrs
167 return user_attrs
168
168
169
169
170 def includeme(config):
170 def includeme(config):
171 plugin_id = 'egg:rhodecode-enterprise-ce#{}'.format(RhodeCodeAuthPlugin.uid)
171 plugin_id = 'egg:rhodecode-enterprise-ce#{}'.format(RhodeCodeAuthPlugin.uid)
172 plugin_factory(plugin_id).includeme(config)
172 plugin_factory(plugin_id).includeme(config)
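
# Standalone sketch of the python-pam call used in auth() above; the
# username, password and service values are placeholders, and running this
# requires a host with PAM configured:
import pam

_pam = pam.pam()
ok = _pam.authenticate('some-user', 'some-password', 'login')  # -> bool
print('authenticated' if ok else 'rejected')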
@@ -1,89 +1,90 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2
2
3 # Copyright (C) 2010-2020 RhodeCode GmbH
3 # Copyright (C) 2010-2020 RhodeCode GmbH
4 #
4 #
5 # This program is free software: you can redistribute it and/or modify
5 # This program is free software: you can redistribute it and/or modify
6 # it under the terms of the GNU Affero General Public License, version 3
6 # it under the terms of the GNU Affero General Public License, version 3
7 # (only), as published by the Free Software Foundation.
7 # (only), as published by the Free Software Foundation.
8 #
8 #
9 # This program is distributed in the hope that it will be useful,
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 # GNU General Public License for more details.
12 # GNU General Public License for more details.
13 #
13 #
14 # You should have received a copy of the GNU Affero General Public License
14 # You should have received a copy of the GNU Affero General Public License
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 #
16 #
17 # This program is dual-licensed. If you wish to learn more about the
17 # This program is dual-licensed. If you wish to learn more about the
18 # RhodeCode Enterprise Edition, including its added features, Support services,
18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20
20
21 import os
21 import os
22 import logging
22 import logging
23 import rhodecode
23 import rhodecode
24 import collections
24
25
25 from rhodecode.config import utils
26 from rhodecode.config import utils
26
27
27 from rhodecode.lib.utils import load_rcextensions
28 from rhodecode.lib.utils import load_rcextensions
28 from rhodecode.lib.utils2 import str2bool
29 from rhodecode.lib.utils2 import str2bool
29 from rhodecode.lib.vcs import connect_vcs
30 from rhodecode.lib.vcs import connect_vcs
30
31
31 log = logging.getLogger(__name__)
32 log = logging.getLogger(__name__)
32
33
33
34
34 def load_pyramid_environment(global_config, settings):
35 def load_pyramid_environment(global_config, settings):
35 # Some parts of the code expect a merge of global and app settings.
36 # Some parts of the code expect a merge of global and app settings.
36 settings_merged = global_config.copy()
37 settings_merged = global_config.copy()
37 settings_merged.update(settings)
38 settings_merged.update(settings)
38
39
39 # TODO(marcink): probably not required anymore
40 # TODO(marcink): probably not required anymore
40 # configure channelstream,
41 # configure channelstream,
41 settings_merged['channelstream_config'] = {
42 settings_merged['channelstream_config'] = {
42 'enabled': str2bool(settings_merged.get('channelstream.enabled', False)),
43 'enabled': str2bool(settings_merged.get('channelstream.enabled', False)),
43 'server': settings_merged.get('channelstream.server'),
44 'server': settings_merged.get('channelstream.server'),
44 'secret': settings_merged.get('channelstream.secret')
45 'secret': settings_merged.get('channelstream.secret')
45 }
46 }
46
47
47 # If this is a test run we prepare the test environment like
48 # If this is a test run we prepare the test environment like
48 # creating a test database, test search index and test repositories.
49 # creating a test database, test search index and test repositories.
49 # This has to be done before the database connection is initialized.
50 # This has to be done before the database connection is initialized.
50 if settings['is_test']:
51 if settings['is_test']:
51 rhodecode.is_test = True
52 rhodecode.is_test = True
52 rhodecode.disable_error_handler = True
53 rhodecode.disable_error_handler = True
53 from rhodecode import authentication
54 from rhodecode import authentication
54 authentication.plugin_default_auth_ttl = 0
55 authentication.plugin_default_auth_ttl = 0
55
56
56 utils.initialize_test_environment(settings_merged)
57 utils.initialize_test_environment(settings_merged)
57
58
58 # Initialize the database connection.
59 # Initialize the database connection.
59 utils.initialize_database(settings_merged)
60 utils.initialize_database(settings_merged)
60
61
61 load_rcextensions(root_path=settings_merged['here'])
62 load_rcextensions(root_path=settings_merged['here'])
62
63
63 # Limit backends to `vcs.backends` from configuration, and preserve the order
64 # Limit backends to `vcs.backends` from configuration, and preserve the order
64 for alias in rhodecode.BACKENDS.keys():
65 for alias in list(rhodecode.BACKENDS.keys()):
65 if alias not in settings['vcs.backends']:
66 if alias not in settings['vcs.backends']:
66 del rhodecode.BACKENDS[alias]
67 del rhodecode.BACKENDS[alias]
67
68
68 _sorted_backend = sorted(rhodecode.BACKENDS.items(),
69 _sorted_backend = sorted(rhodecode.BACKENDS.items(),
69 key=lambda item: settings['vcs.backends'].index(item[0]))
70 key=lambda item: settings['vcs.backends'].index(item[0]))
70 rhodecode.BACKENDS = rhodecode.OrderedDict(_sorted_backend)
71 rhodecode.BACKENDS = collections.OrderedDict(_sorted_backend)
71
72
72 log.info('Enabled VCS backends: %s', rhodecode.BACKENDS.keys())
73 log.info('Enabled VCS backends: %s', rhodecode.BACKENDS.keys())
73
74
74 # initialize vcs client and optionally run the server if enabled
75 # initialize vcs client and optionally run the server if enabled
75 vcs_server_uri = settings['vcs.server']
76 vcs_server_uri = settings['vcs.server']
76 vcs_server_enabled = settings['vcs.server.enable']
77 vcs_server_enabled = settings['vcs.server.enable']
77
78
78 utils.configure_vcs(settings)
79 utils.configure_vcs(settings)
79
80
80 # Store the settings to make them available to other modules.
81 # Store the settings to make them available to other modules.
81
82
82 rhodecode.PYRAMID_SETTINGS = settings_merged
83 rhodecode.PYRAMID_SETTINGS = settings_merged
83 rhodecode.CONFIG = settings_merged
84 rhodecode.CONFIG = settings_merged
84 rhodecode.CONFIG['default_user_id'] = utils.get_default_user_id()
85 rhodecode.CONFIG['default_user_id'] = utils.get_default_user_id()
85
86
86 if vcs_server_enabled:
87 if vcs_server_enabled:
87 connect_vcs(vcs_server_uri, utils.get_vcs_server_protocol(settings))
88 connect_vcs(vcs_server_uri, utils.get_vcs_server_protocol(settings))
88 else:
89 else:
89 log.warning('vcs-server not enabled, vcs connection unavailable')
90 log.warning('vcs-server not enabled, vcs connection unavailable')
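
# Standalone sketch of the backend limit-and-order logic above, with made-up
# data (no RhodeCode imports required). Note the list() copy of the keys,
# needed in Python 3 because we delete entries while iterating:
import collections

backends = collections.OrderedDict(
    [('hg', object()), ('git', object()), ('svn', object())])
enabled = ['git', 'hg']  # stands in for settings['vcs.backends']

for alias in list(backends.keys()):
    if alias not in enabled:
        del backends[alias]

backends = collections.OrderedDict(
    sorted(backends.items(), key=lambda item: enabled.index(item[0])))
print(list(backends))  # ['git', 'hg']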
@@ -1,282 +1,280 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2 """
2 """
3 Adapters
3 Adapters
4 --------
4 --------
5
5
6 .. contents::
6 .. contents::
7 :backlinks: none
7 :backlinks: none
8
8
9 The :func:`authomatic.login` function needs access to functionality like
9 The :func:`authomatic.login` function needs access to functionality like
10 getting the **URL** of the handler where it is being called, getting the
10 getting the **URL** of the handler where it is being called, getting the
11 **request params** and **cookies** and **writing the body**, **headers**
11 **request params** and **cookies** and **writing the body**, **headers**
12 and **status** to the response.
12 and **status** to the response.
13
13
14 Since implementation of these features varies across Python web frameworks,
14 Since implementation of these features varies across Python web frameworks,
15 the Authomatic library uses **adapters** to unify these differences into a
15 the Authomatic library uses **adapters** to unify these differences into a
16 single interface.
16 single interface.
17
17
18 Available Adapters
18 Available Adapters
19 ^^^^^^^^^^^^^^^^^^
19 ^^^^^^^^^^^^^^^^^^
20
20
21 If you are missing an adapter for the framework of your choice, please
21 If you are missing an adapter for the framework of your choice, please
22 open an `enhancement issue <https://github.com/authomatic/authomatic/issues>`_
22 open an `enhancement issue <https://github.com/authomatic/authomatic/issues>`_
23 or consider a contribution to this module by
23 or consider a contribution to this module by
24 :ref:`implementing <implement_adapters>` one yourself.
24 :ref:`implementing <implement_adapters>` one yourself.
25 It's very easy and shouldn't take you more than a few minutes.
25 It's very easy and shouldn't take you more than a few minutes.
26
26
27 .. autoclass:: DjangoAdapter
27 .. autoclass:: DjangoAdapter
28 :members:
28 :members:
29
29
30 .. autoclass:: Webapp2Adapter
30 .. autoclass:: Webapp2Adapter
31 :members:
31 :members:
32
32
33 .. autoclass:: WebObAdapter
33 .. autoclass:: WebObAdapter
34 :members:
34 :members:
35
35
36 .. autoclass:: WerkzeugAdapter
36 .. autoclass:: WerkzeugAdapter
37 :members:
37 :members:
38
38
39 .. _implement_adapters:
39 .. _implement_adapters:
40
40
41 Implementing an Adapter
41 Implementing an Adapter
42 ^^^^^^^^^^^^^^^^^^^^^^^
42 ^^^^^^^^^^^^^^^^^^^^^^^
43
43
44 Implementing an adapter for a Python web framework is pretty easy.
44 Implementing an adapter for a Python web framework is pretty easy.
45
45
46 Do it by subclassing the :class:`.BaseAdapter` abstract class.
46 Do it by subclassing the :class:`.BaseAdapter` abstract class.
47 There are only **six** members that you need to implement.
47 There are only **six** members that you need to implement.
48
48
49 Moreover, if your framework is based on the |webob|_ or |werkzeug|_ package
49 Moreover, if your framework is based on the |webob|_ or |werkzeug|_ package
50 you can subclass the :class:`.WebObAdapter` or :class:`.WerkzeugAdapter`
50 you can subclass the :class:`.WebObAdapter` or :class:`.WerkzeugAdapter`
51 respectively.
51 respectively.
52
52
53 .. autoclass:: BaseAdapter
53 .. autoclass:: BaseAdapter
54 :members:
54 :members:
55
55
56 """
56 """
57
57
58 import abc
58 import abc
59 from authomatic.core import Response
59 from authomatic.core import Response
60
60
61
61
62 class BaseAdapter(object):
62 class BaseAdapter(object, metaclass=abc.ABCMeta):
63 """
63 """
64 Base class for platform adapters.
64 Base class for platform adapters.
65
65
66 Defines common interface for WSGI framework specific functionality.
66 Defines common interface for WSGI framework specific functionality.
67
67
68 """
68 """
69
69
70 __metaclass__ = abc.ABCMeta
71
72 @abc.abstractproperty
70 @abc.abstractproperty
73 def params(self):
71 def params(self):
74 """
72 """
75 Must return a :class:`dict` of all request parameters of any HTTP
73 Must return a :class:`dict` of all request parameters of any HTTP
76 method.
74 method.
77
75
78 :returns:
76 :returns:
79 :class:`dict`
77 :class:`dict`
80
78
81 """
79 """
82
80
83 @abc.abstractproperty
81 @abc.abstractproperty
84 def url(self):
82 def url(self):
85 """
83 """
86 Must return the url of the actual request including path but without
84 Must return the url of the actual request including path but without
87 query and fragment.
85 query and fragment.
88
86
89 :returns:
87 :returns:
90 :class:`str`
88 :class:`str`
91
89
92 """
90 """
93
91
94 @abc.abstractproperty
92 @abc.abstractproperty
95 def cookies(self):
93 def cookies(self):
96 """
94 """
97 Must return cookies as a :class:`dict`.
95 Must return cookies as a :class:`dict`.
98
96
99 :returns:
97 :returns:
100 :class:`dict`
98 :class:`dict`
101
99
102 """
100 """
103
101
104 @abc.abstractmethod
102 @abc.abstractmethod
105 def write(self, value):
103 def write(self, value):
106 """
104 """
107 Must write specified value to response.
105 Must write specified value to response.
108
106
109 :param str value:
107 :param str value:
110 String to be written to response.
108 String to be written to response.
111
109
112 """
110 """
113
111
114 @abc.abstractmethod
112 @abc.abstractmethod
115 def set_header(self, key, value):
113 def set_header(self, key, value):
116 """
114 """
117 Must set response headers to ``Key: value``.
115 Must set response headers to ``Key: value``.
118
116
119 :param str key:
117 :param str key:
120 Header name.
118 Header name.
121
119
122 :param str value:
120 :param str value:
123 Header value.
121 Header value.
124
122
125 """
123 """
126
124
127 @abc.abstractmethod
125 @abc.abstractmethod
128 def set_status(self, status):
126 def set_status(self, status):
129 """
127 """
130 Must set the response status e.g. ``'302 Found'``.
128 Must set the response status e.g. ``'302 Found'``.
131
129
132 :param str status:
130 :param str status:
133 The HTTP response status.
131 The HTTP response status.
134
132
135 """
133 """
136
134
137
135
138 class DjangoAdapter(BaseAdapter):
136 class DjangoAdapter(BaseAdapter):
139 """
137 """
140 Adapter for the |django|_ framework.
138 Adapter for the |django|_ framework.
141 """
139 """
142
140
143 def __init__(self, request, response):
141 def __init__(self, request, response):
144 """
142 """
145 :param request:
143 :param request:
146 An instance of the :class:`django.http.HttpRequest` class.
144 An instance of the :class:`django.http.HttpRequest` class.
147
145
148 :param response:
146 :param response:
149 An instance of the :class:`django.http.HttpResponse` class.
147 An instance of the :class:`django.http.HttpResponse` class.
150 """
148 """
151 self.request = request
149 self.request = request
152 self.response = response
150 self.response = response
153
151
154 @property
152 @property
155 def params(self):
153 def params(self):
156 params = {}
154 params = {}
157 params.update(self.request.GET.dict())
155 params.update(self.request.GET.dict())
158 params.update(self.request.POST.dict())
156 params.update(self.request.POST.dict())
159 return params
157 return params
160
158
161 @property
159 @property
162 def url(self):
160 def url(self):
163 return self.request.build_absolute_uri(self.request.path)
161 return self.request.build_absolute_uri(self.request.path)
164
162
165 @property
163 @property
166 def cookies(self):
164 def cookies(self):
167 return dict(self.request.COOKIES)
165 return dict(self.request.COOKIES)
168
166
169 def write(self, value):
167 def write(self, value):
170 self.response.write(value)
168 self.response.write(value)
171
169
172 def set_header(self, key, value):
170 def set_header(self, key, value):
173 self.response[key] = value
171 self.response[key] = value
174
172
175 def set_status(self, status):
173 def set_status(self, status):
176 status_code, reason = status.split(' ', 1)
174 status_code, reason = status.split(' ', 1)
177 self.response.status_code = int(status_code)
175 self.response.status_code = int(status_code)
178
176
179
177
180 class WebObAdapter(BaseAdapter):
178 class WebObAdapter(BaseAdapter):
181 """
179 """
182 Adapter for the |webob|_ package.
180 Adapter for the |webob|_ package.
183 """
181 """
184
182
185 def __init__(self, request, response):
183 def __init__(self, request, response):
186 """
184 """
187 :param request:
185 :param request:
188 A |webob|_ :class:`Request` instance.
186 A |webob|_ :class:`Request` instance.
189
187
190 :param response:
188 :param response:
191 A |webob|_ :class:`Response` instance.
189 A |webob|_ :class:`Response` instance.
192 """
190 """
193 self.request = request
191 self.request = request
194 self.response = response
192 self.response = response
195
193
196 # =========================================================================
194 # =========================================================================
197 # Request
195 # Request
198 # =========================================================================
196 # =========================================================================
199
197
200 @property
198 @property
201 def url(self):
199 def url(self):
202 return self.request.path_url
200 return self.request.path_url
203
201
204 @property
202 @property
205 def params(self):
203 def params(self):
206 return dict(self.request.params)
204 return dict(self.request.params)
207
205
208 @property
206 @property
209 def cookies(self):
207 def cookies(self):
210 return dict(self.request.cookies)
208 return dict(self.request.cookies)
211
209
212 # =========================================================================
210 # =========================================================================
213 # Response
211 # Response
214 # =========================================================================
212 # =========================================================================
215
213
216 def write(self, value):
214 def write(self, value):
217 self.response.write(value)
215 self.response.write(value)
218
216
219 def set_header(self, key, value):
217 def set_header(self, key, value):
220 self.response.headers[key] = str(value)
218 self.response.headers[key] = str(value)
221
219
222 def set_status(self, status):
220 def set_status(self, status):
223 self.response.status = status
221 self.response.status = status
224
222
225
223
226 class Webapp2Adapter(WebObAdapter):
224 class Webapp2Adapter(WebObAdapter):
227 """
225 """
228 Adapter for the |webapp2|_ framework.
226 Adapter for the |webapp2|_ framework.
229
227
230 Inherits from the :class:`.WebObAdapter`.
228 Inherits from the :class:`.WebObAdapter`.
231
229
232 """
230 """
233
231
234 def __init__(self, handler):
232 def __init__(self, handler):
235 """
233 """
236 :param handler:
234 :param handler:
237 A :class:`webapp2.RequestHandler` instance.
235 A :class:`webapp2.RequestHandler` instance.
238 """
236 """
239 self.request = handler.request
237 self.request = handler.request
240 self.response = handler.response
238 self.response = handler.response
241
239
242
240
243 class WerkzeugAdapter(BaseAdapter):
241 class WerkzeugAdapter(BaseAdapter):
244 """
242 """
245 Adapter for |flask|_ and other |werkzeug|_ based frameworks.
243 Adapter for |flask|_ and other |werkzeug|_ based frameworks.
246
244
247 Thanks to `Mark Steve Samson <http://marksteve.com>`_.
245 Thanks to `Mark Steve Samson <http://marksteve.com>`_.
248
246
249 """
247 """
250
248
251 @property
249 @property
252 def params(self):
250 def params(self):
253 return self.request.args
251 return self.request.args
254
252
255 @property
253 @property
256 def url(self):
254 def url(self):
257 return self.request.base_url
255 return self.request.base_url
258
256
259 @property
257 @property
260 def cookies(self):
258 def cookies(self):
261 return self.request.cookies
259 return self.request.cookies
262
260
263 def __init__(self, request, response):
261 def __init__(self, request, response):
264 """
262 """
265 :param request:
263 :param request:
266 Instance of the :class:`werkzeug.wrappers.Request` class.
264 Instance of the :class:`werkzeug.wrappers.Request` class.
267
265
268 :param response:
266 :param response:
269 Instance of the :class:`werkzeug.wrappers.Response` class.
267 Instance of the :class:`werkzeug.wrappers.Response` class.
270 """
268 """
271
269
272 self.request = request
270 self.request = request
273 self.response = response
271 self.response = response
274
272
275 def write(self, value):
273 def write(self, value):
276 self.response.data = self.response.data + value
274 self.response.data = self.response.data + value
277
275
278 def set_header(self, key, value):
276 def set_header(self, key, value):
279 self.response.headers[key] = value
277 self.response.headers[key] = value
280
278
281 def set_status(self, status):
279 def set_status(self, status):
282 self.response.status = status
280 self.response.status = status
@@ -1,305 +1,305 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2
2
3 # Copyright (C) 2017-2020 RhodeCode GmbH
3 # Copyright (C) 2017-2020 RhodeCode GmbH
4 #
4 #
5 # This program is free software: you can redistribute it and/or modify
5 # This program is free software: you can redistribute it and/or modify
6 # it under the terms of the GNU Affero General Public License, version 3
6 # it under the terms of the GNU Affero General Public License, version 3
7 # (only), as published by the Free Software Foundation.
7 # (only), as published by the Free Software Foundation.
8 #
8 #
9 # This program is distributed in the hope that it will be useful,
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 # GNU General Public License for more details.
12 # GNU General Public License for more details.
13 #
13 #
14 # You should have received a copy of the GNU Affero General Public License
14 # You should have received a copy of the GNU Affero General Public License
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 #
16 #
17 # This program is dual-licensed. If you wish to learn more about the
17 # This program is dual-licensed. If you wish to learn more about the
18 # RhodeCode Enterprise Edition, including its added features, Support services,
18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20
20
21 import logging
21 import logging
22 import datetime
22 import datetime
23
23
24 from rhodecode.lib.jsonalchemy import JsonRaw
24 from rhodecode.lib.jsonalchemy import JsonRaw
25 from rhodecode.model import meta
25 from rhodecode.model import meta
26 from rhodecode.model.db import User, UserLog, Repository
26 from rhodecode.model.db import User, UserLog, Repository
27
27
28
28
29 log = logging.getLogger(__name__)
29 log = logging.getLogger(__name__)
30
30
31 # action as key, and expected action_data as value
31 # action as key, and expected action_data as value
32 ACTIONS_V1 = {
32 ACTIONS_V1 = {
33 'user.login.success': {'user_agent': ''},
33 'user.login.success': {'user_agent': ''},
34 'user.login.failure': {'user_agent': ''},
34 'user.login.failure': {'user_agent': ''},
35 'user.logout': {'user_agent': ''},
35 'user.logout': {'user_agent': ''},
36 'user.register': {},
36 'user.register': {},
37 'user.password.reset_request': {},
37 'user.password.reset_request': {},
38 'user.push': {'user_agent': '', 'commit_ids': []},
38 'user.push': {'user_agent': '', 'commit_ids': []},
39 'user.pull': {'user_agent': ''},
39 'user.pull': {'user_agent': ''},
40
40
41 'user.create': {'data': {}},
41 'user.create': {'data': {}},
42 'user.delete': {'old_data': {}},
42 'user.delete': {'old_data': {}},
43 'user.edit': {'old_data': {}},
43 'user.edit': {'old_data': {}},
44 'user.edit.permissions': {},
44 'user.edit.permissions': {},
45 'user.edit.ip.add': {'ip': {}, 'user': {}},
45 'user.edit.ip.add': {'ip': {}, 'user': {}},
46 'user.edit.ip.delete': {'ip': {}, 'user': {}},
46 'user.edit.ip.delete': {'ip': {}, 'user': {}},
47 'user.edit.token.add': {'token': {}, 'user': {}},
47 'user.edit.token.add': {'token': {}, 'user': {}},
48 'user.edit.token.delete': {'token': {}, 'user': {}},
48 'user.edit.token.delete': {'token': {}, 'user': {}},
49 'user.edit.email.add': {'email': ''},
49 'user.edit.email.add': {'email': ''},
50 'user.edit.email.delete': {'email': ''},
50 'user.edit.email.delete': {'email': ''},
51 'user.edit.ssh_key.add': {'token': {}, 'user': {}},
51 'user.edit.ssh_key.add': {'token': {}, 'user': {}},
52 'user.edit.ssh_key.delete': {'token': {}, 'user': {}},
52 'user.edit.ssh_key.delete': {'token': {}, 'user': {}},
53 'user.edit.password_reset.enabled': {},
53 'user.edit.password_reset.enabled': {},
54 'user.edit.password_reset.disabled': {},
54 'user.edit.password_reset.disabled': {},
55
55
56 'user_group.create': {'data': {}},
56 'user_group.create': {'data': {}},
57 'user_group.delete': {'old_data': {}},
57 'user_group.delete': {'old_data': {}},
58 'user_group.edit': {'old_data': {}},
58 'user_group.edit': {'old_data': {}},
59 'user_group.edit.permissions': {},
59 'user_group.edit.permissions': {},
60 'user_group.edit.member.add': {'user': {}},
60 'user_group.edit.member.add': {'user': {}},
61 'user_group.edit.member.delete': {'user': {}},
61 'user_group.edit.member.delete': {'user': {}},
62
62
63 'repo.create': {'data': {}},
63 'repo.create': {'data': {}},
64 'repo.fork': {'data': {}},
64 'repo.fork': {'data': {}},
65 'repo.edit': {'old_data': {}},
65 'repo.edit': {'old_data': {}},
66 'repo.edit.permissions': {},
66 'repo.edit.permissions': {},
67 'repo.edit.permissions.branch': {},
67 'repo.edit.permissions.branch': {},
68 'repo.archive': {'old_data': {}},
68 'repo.archive': {'old_data': {}},
69 'repo.delete': {'old_data': {}},
69 'repo.delete': {'old_data': {}},
70
70
71 'repo.archive.download': {'user_agent': '', 'archive_name': '',
71 'repo.archive.download': {'user_agent': '', 'archive_name': '',
72 'archive_spec': '', 'archive_cached': ''},
72 'archive_spec': '', 'archive_cached': ''},
73
73
74 'repo.permissions.branch_rule.create': {},
74 'repo.permissions.branch_rule.create': {},
75 'repo.permissions.branch_rule.edit': {},
75 'repo.permissions.branch_rule.edit': {},
76 'repo.permissions.branch_rule.delete': {},
76 'repo.permissions.branch_rule.delete': {},
77
77
78 'repo.pull_request.create': '',
78 'repo.pull_request.create': '',
79 'repo.pull_request.edit': '',
79 'repo.pull_request.edit': '',
80 'repo.pull_request.delete': '',
80 'repo.pull_request.delete': '',
81 'repo.pull_request.close': '',
81 'repo.pull_request.close': '',
82 'repo.pull_request.merge': '',
82 'repo.pull_request.merge': '',
83 'repo.pull_request.vote': '',
83 'repo.pull_request.vote': '',
84 'repo.pull_request.comment.create': '',
84 'repo.pull_request.comment.create': '',
85 'repo.pull_request.comment.edit': '',
85 'repo.pull_request.comment.edit': '',
86 'repo.pull_request.comment.delete': '',
86 'repo.pull_request.comment.delete': '',
87
87
88 'repo.pull_request.reviewer.add': '',
88 'repo.pull_request.reviewer.add': '',
89 'repo.pull_request.reviewer.delete': '',
89 'repo.pull_request.reviewer.delete': '',
90
90
91 'repo.pull_request.observer.add': '',
91 'repo.pull_request.observer.add': '',
92 'repo.pull_request.observer.delete': '',
92 'repo.pull_request.observer.delete': '',
93
93
94 'repo.commit.strip': {'commit_id': ''},
94 'repo.commit.strip': {'commit_id': ''},
95 'repo.commit.comment.create': {'data': {}},
95 'repo.commit.comment.create': {'data': {}},
96 'repo.commit.comment.delete': {'data': {}},
96 'repo.commit.comment.delete': {'data': {}},
97 'repo.commit.comment.edit': {'data': {}},
97 'repo.commit.comment.edit': {'data': {}},
98 'repo.commit.vote': '',
98 'repo.commit.vote': '',
99
99
100 'repo.artifact.add': '',
100 'repo.artifact.add': '',
101 'repo.artifact.delete': '',
101 'repo.artifact.delete': '',
102
102
103 'repo_group.create': {'data': {}},
103 'repo_group.create': {'data': {}},
104 'repo_group.edit': {'old_data': {}},
104 'repo_group.edit': {'old_data': {}},
105 'repo_group.edit.permissions': {},
105 'repo_group.edit.permissions': {},
106 'repo_group.delete': {'old_data': {}},
106 'repo_group.delete': {'old_data': {}},
107 }
107 }
108
108
109 ACTIONS = ACTIONS_V1
109 ACTIONS = ACTIONS_V1
110
110
111 SOURCE_WEB = 'source_web'
111 SOURCE_WEB = 'source_web'
112 SOURCE_API = 'source_api'
112 SOURCE_API = 'source_api'
113
113
114
114
115 class UserWrap(object):
115 class UserWrap(object):
116 """
116 """
117 Fake object used to imitate AuthUser
117 Fake object used to imitate AuthUser
118 """
118 """
119
119
120 def __init__(self, user_id=None, username=None, ip_addr=None):
120 def __init__(self, user_id=None, username=None, ip_addr=None):
121 self.user_id = user_id
121 self.user_id = user_id
122 self.username = username
122 self.username = username
123 self.ip_addr = ip_addr
123 self.ip_addr = ip_addr
124
124
125
125
126 class RepoWrap(object):
126 class RepoWrap(object):
127 """
127 """
128 Fake object used to imitate the RepoObject that the audit logger requires
128 Fake object used to imitate the RepoObject that the audit logger requires
129 """
129 """
130
130
131 def __init__(self, repo_id=None, repo_name=None):
131 def __init__(self, repo_id=None, repo_name=None):
132 self.repo_id = repo_id
132 self.repo_id = repo_id
133 self.repo_name = repo_name
133 self.repo_name = repo_name
134
134
135
135
136 def _store_log(action_name, action_data, user_id, username, user_data,
136 def _store_log(action_name, action_data, user_id, username, user_data,
137 ip_address, repository_id, repository_name):
137 ip_address, repository_id, repository_name):
138 user_log = UserLog()
138 user_log = UserLog()
139 user_log.version = UserLog.VERSION_2
139 user_log.version = UserLog.VERSION_2
140
140
141 user_log.action = action_name
141 user_log.action = action_name
142 user_log.action_data = action_data or JsonRaw(u'{}')
142 user_log.action_data = action_data or JsonRaw('{}')
143
143
144 user_log.user_ip = ip_address
144 user_log.user_ip = ip_address
145
145
146 user_log.user_id = user_id
146 user_log.user_id = user_id
147 user_log.username = username
147 user_log.username = username
148 user_log.user_data = user_data or JsonRaw(u'{}')
148 user_log.user_data = user_data or JsonRaw('{}')
149
149
150 user_log.repository_id = repository_id
150 user_log.repository_id = repository_id
151 user_log.repository_name = repository_name
151 user_log.repository_name = repository_name
152
152
153 user_log.action_date = datetime.datetime.now()
153 user_log.action_date = datetime.datetime.now()
154
154
155 return user_log
155 return user_log
156
156
157
157
158 def store_web(*args, **kwargs):
158 def store_web(*args, **kwargs):
159 action_data = {}
159 action_data = {}
160 org_action_data = kwargs.pop('action_data', {})
160 org_action_data = kwargs.pop('action_data', {})
161 action_data.update(org_action_data)
161 action_data.update(org_action_data)
162 action_data['source'] = SOURCE_WEB
162 action_data['source'] = SOURCE_WEB
163 kwargs['action_data'] = action_data
163 kwargs['action_data'] = action_data
164
164
165 return store(*args, **kwargs)
165 return store(*args, **kwargs)
166
166
167
167
168 def store_api(*args, **kwargs):
168 def store_api(*args, **kwargs):
169 action_data = {}
169 action_data = {}
170 org_action_data = kwargs.pop('action_data', {})
170 org_action_data = kwargs.pop('action_data', {})
171 action_data.update(org_action_data)
171 action_data.update(org_action_data)
172 action_data['source'] = SOURCE_API
172 action_data['source'] = SOURCE_API
173 kwargs['action_data'] = action_data
173 kwargs['action_data'] = action_data
174
174
175 return store(*args, **kwargs)
175 return store(*args, **kwargs)
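
# For instance (hypothetical values): store_api('user.delete', user=user,
# action_data={'old_data': {...}}) records the action_data with
# 'source': 'source_api' merged in, so consumers can tell API-originated
# audit entries apart from web ones.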
176
176
177
177
178 def store(action, user, action_data=None, user_data=None, ip_addr=None,
178 def store(action, user, action_data=None, user_data=None, ip_addr=None,
179 repo=None, sa_session=None, commit=False):
179 repo=None, sa_session=None, commit=False):
180 """
180 """
181 Audit logger for various actions made by users, typically this
181 Audit logger for various actions made by users; typically this
181 Audit logger for various actions made by users; typically this
182 results in a call such as::
182 results in a call such as::
183
184 from rhodecode.lib import audit_logger
184 from rhodecode.lib import audit_logger
185
185
186 audit_logger.store(
186 audit_logger.store(
187 'repo.edit', user=self._rhodecode_user)
187 'repo.edit', user=self._rhodecode_user)
188 audit_logger.store(
188 audit_logger.store(
189 'repo.delete', action_data={'data': repo_data},
189 'repo.delete', action_data={'data': repo_data},
190 user=audit_logger.UserWrap(username='itried-login', ip_addr='8.8.8.8'))
190 user=audit_logger.UserWrap(username='itried-login', ip_addr='8.8.8.8'))
191
191
192 # repo action
192 # repo action
193 audit_logger.store(
193 audit_logger.store(
194 'repo.delete',
194 'repo.delete',
195 user=audit_logger.UserWrap(username='itried-login', ip_addr='8.8.8.8'),
195 user=audit_logger.UserWrap(username='itried-login', ip_addr='8.8.8.8'),
196 repo=audit_logger.RepoWrap(repo_name='some-repo'))
196 repo=audit_logger.RepoWrap(repo_name='some-repo'))
197
197
198 # repo action, when we know and have the repository object already
198 # repo action, when we know and have the repository object already
199 audit_logger.store(
199 audit_logger.store(
200 'repo.delete', action_data={'source': audit_logger.SOURCE_WEB, },
200 'repo.delete', action_data={'source': audit_logger.SOURCE_WEB, },
201 user=self._rhodecode_user,
201 user=self._rhodecode_user,
202 repo=repo_object)
202 repo=repo_object)
203
203
204 # alternative wrapper to the above
204 # alternative wrapper to the above
205 audit_logger.store_web(
205 audit_logger.store_web(
206 'repo.delete', action_data={},
206 'repo.delete', action_data={},
207 user=self._rhodecode_user,
207 user=self._rhodecode_user,
208 repo=repo_object)
208 repo=repo_object)
209
209
210 # without a user?
210 # without a user?
211 audit_logger.store(
211 audit_logger.store(
212 'user.login.failure',
212 'user.login.failure',
213 user=audit_logger.UserWrap(
213 user=audit_logger.UserWrap(
214 username=self.request.params.get('username'),
214 username=self.request.params.get('username'),
215 ip_addr=self.request.remote_addr))
215 ip_addr=self.request.remote_addr))
216
216
217 """
217 """
218 from rhodecode.lib.utils2 import safe_unicode
218 from rhodecode.lib.utils2 import safe_unicode
219 from rhodecode.lib.auth import AuthUser
219 from rhodecode.lib.auth import AuthUser
220
220
221 action_spec = ACTIONS.get(action, None)
221 action_spec = ACTIONS.get(action, None)
222 if action_spec is None:
222 if action_spec is None:
223 raise ValueError('Action `{}` is not supported'.format(action))
223 raise ValueError('Action `{}` is not supported'.format(action))
224
224
225 if not sa_session:
225 if not sa_session:
226 sa_session = meta.Session()
226 sa_session = meta.Session()
227
227
228 try:
228 try:
229 username = getattr(user, 'username', None)
229 username = getattr(user, 'username', None)
230 if not username:
230 if not username:
231 pass
231 pass
232
232
233 user_id = getattr(user, 'user_id', None)
233 user_id = getattr(user, 'user_id', None)
234 if not user_id:
234 if not user_id:
235 # maybe we have a username? Try to figure out user_id from it
235 # maybe we have a username? Try to figure out user_id from it
236 if username:
236 if username:
237 user_id = getattr(
237 user_id = getattr(
238 User.get_by_username(username), 'user_id', None)
238 User.get_by_username(username), 'user_id', None)
239
239
240 ip_addr = ip_addr or getattr(user, 'ip_addr', None)
240 ip_addr = ip_addr or getattr(user, 'ip_addr', None)
241 if not ip_addr:
241 if not ip_addr:
242 pass
242 pass
243
243
244 if not user_data:
244 if not user_data:
245 # try to get this from the auth user
245 # try to get this from the auth user
246 if isinstance(user, AuthUser):
246 if isinstance(user, AuthUser):
247 user_data = {
247 user_data = {
248 'username': user.username,
248 'username': user.username,
249 'email': user.email,
249 'email': user.email,
250 }
250 }
251
251
252 repository_name = getattr(repo, 'repo_name', None)
252 repository_name = getattr(repo, 'repo_name', None)
253 repository_id = getattr(repo, 'repo_id', None)
253 repository_id = getattr(repo, 'repo_id', None)
254 if not repository_id:
254 if not repository_id:
255 # maybe we have a repo_name? Try to figure out repo_id from it
255 # maybe we have a repo_name? Try to figure out repo_id from it
256 if repository_name:
256 if repository_name:
257 repository_id = getattr(
257 repository_id = getattr(
258 Repository.get_by_repo_name(repository_name), 'repo_id', None)
258 Repository.get_by_repo_name(repository_name), 'repo_id', None)
259
259
260 action_name = safe_unicode(action)
260 action_name = safe_unicode(action)
261 ip_address = safe_unicode(ip_addr)
261 ip_address = safe_unicode(ip_addr)
262
262
263 with sa_session.no_autoflush:
263 with sa_session.no_autoflush:
264
264
265 user_log = _store_log(
265 user_log = _store_log(
266 action_name=action_name,
266 action_name=action_name,
267 action_data=action_data or {},
267 action_data=action_data or {},
268 user_id=user_id,
268 user_id=user_id,
269 username=username,
269 username=username,
270 user_data=user_data or {},
270 user_data=user_data or {},
271 ip_address=ip_address,
271 ip_address=ip_address,
272 repository_id=repository_id,
272 repository_id=repository_id,
273 repository_name=repository_name
273 repository_name=repository_name
274 )
274 )
275
275
276 sa_session.add(user_log)
276 sa_session.add(user_log)
277 if commit:
277 if commit:
278 sa_session.commit()
278 sa_session.commit()
279 entry_id = user_log.entry_id or ''
279 entry_id = user_log.entry_id or ''
280
280
281 update_user_last_activity(sa_session, user_id)
281 update_user_last_activity(sa_session, user_id)
282
282
283 if commit:
283 if commit:
284 sa_session.commit()
284 sa_session.commit()
285
285
286 log.info('AUDIT[%s]: Logging action: `%s` by user:id:%s[%s] ip:%s',
286 log.info('AUDIT[%s]: Logging action: `%s` by user:id:%s[%s] ip:%s',
287 entry_id, action_name, user_id, username, ip_address,
287 entry_id, action_name, user_id, username, ip_address,
288 extra={"entry_id": entry_id, "action": action_name,
288 extra={"entry_id": entry_id, "action": action_name,
289 "user_id": user_id, "ip": ip_address})
289 "user_id": user_id, "ip": ip_address})
290
290
291 except Exception:
291 except Exception:
292 log.exception('AUDIT: failed to store audit log')
292 log.exception('AUDIT: failed to store audit log')
293
293
294
294
295 def update_user_last_activity(sa_session, user_id):
295 def update_user_last_activity(sa_session, user_id):
296 _last_activity = datetime.datetime.now()
296 _last_activity = datetime.datetime.now()
297 try:
297 try:
298 sa_session.query(User).filter(User.user_id == user_id).update(
298 sa_session.query(User).filter(User.user_id == user_id).update(
299 {"last_activity": _last_activity})
299 {"last_activity": _last_activity})
300 log.debug(
300 log.debug(
301 'updated user `%s` last activity to:%s', user_id, _last_activity)
301 'updated user `%s` last activity to:%s', user_id, _last_activity)
302 except Exception:
302 except Exception:
303 log.exception("Failed last activity update for user_id: %s", user_id)
303 log.exception("Failed last activity update for user_id: %s", user_id)
304 sa_session.rollback()
304 sa_session.rollback()
305
305
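# A minimal end-to-end sketch for this module, assuming a configured
# SQLAlchemy session is reachable via `meta.Session()` (the request object
# is a placeholder):
#
#   store(
#       'user.login.failure',
#       user=UserWrap(username=request.params.get('username'),
#                     ip_addr=request.remote_addr),
#       commit=True)  # persists the UserLog row and the last-activity update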
@@ -1,371 +1,371 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2
2
3 # Copyright (C) 2016-2020 RhodeCode GmbH
3 # Copyright (C) 2016-2020 RhodeCode GmbH
4 #
4 #
5 # This program is free software: you can redistribute it and/or modify
5 # This program is free software: you can redistribute it and/or modify
6 # it under the terms of the GNU Affero General Public License, version 3
6 # it under the terms of the GNU Affero General Public License, version 3
7 # (only), as published by the Free Software Foundation.
7 # (only), as published by the Free Software Foundation.
8 #
8 #
9 # This program is distributed in the hope that it will be useful,
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 # GNU General Public License for more details.
12 # GNU General Public License for more details.
13 #
13 #
14 # You should have received a copy of the GNU Affero General Public License
14 # You should have received a copy of the GNU Affero General Public License
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 #
16 #
17 # This program is dual-licensed. If you wish to learn more about the
17 # This program is dual-licensed. If you wish to learn more about the
18 # RhodeCode Enterprise Edition, including its added features, Support services,
18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20
20
21 import os
21 import os
22 import hashlib
22 import hashlib
23 import itsdangerous
23 import itsdangerous
24 import logging
24 import logging
25 import requests
25 import requests
26 import datetime
26 import datetime
27
27
28 from dogpile.util.readwrite_lock import ReadWriteMutex
28 from dogpile.util.readwrite_lock import ReadWriteMutex
29 from pyramid.threadlocal import get_current_registry
29 from pyramid.threadlocal import get_current_registry
30
30
31 import rhodecode.lib.helpers as h
31 import rhodecode.lib.helpers as h
32 from rhodecode.lib.auth import HasRepoPermissionAny
32 from rhodecode.lib.auth import HasRepoPermissionAny
33 from rhodecode.lib.ext_json import json
33 from rhodecode.lib.ext_json import json
34 from rhodecode.model.db import User
34 from rhodecode.model.db import User
35
35
36 log = logging.getLogger(__name__)
36 log = logging.getLogger(__name__)
37
37
38 LOCK = ReadWriteMutex()
38 LOCK = ReadWriteMutex()
39
39
40 USER_STATE_PUBLIC_KEYS = [
40 USER_STATE_PUBLIC_KEYS = [
41 'id', 'username', 'first_name', 'last_name',
41 'id', 'username', 'first_name', 'last_name',
42 'icon_link', 'display_name', 'display_link']
42 'icon_link', 'display_name', 'display_link']
43
43
44
44
45 class ChannelstreamException(Exception):
45 class ChannelstreamException(Exception):
46 pass
46 pass
47
47
48
48
49 class ChannelstreamConnectionException(ChannelstreamException):
49 class ChannelstreamConnectionException(ChannelstreamException):
50 pass
50 pass
51
51
52
52
53 class ChannelstreamPermissionException(ChannelstreamException):
53 class ChannelstreamPermissionException(ChannelstreamException):
54 pass
54 pass
55
55
56
56
57 def get_channelstream_server_url(config, endpoint):
57 def get_channelstream_server_url(config, endpoint):
58 return 'http://{}{}'.format(config['server'], endpoint)
58 return 'http://{}{}'.format(config['server'], endpoint)
59
59
60
60
61 def channelstream_request(config, payload, endpoint, raise_exc=True):
61 def channelstream_request(config, payload, endpoint, raise_exc=True):
62 signer = itsdangerous.TimestampSigner(config['secret'])
62 signer = itsdangerous.TimestampSigner(config['secret'])
63 sig_for_server = signer.sign(endpoint)
63 sig_for_server = signer.sign(endpoint)
64 secret_headers = {'x-channelstream-secret': sig_for_server,
64 secret_headers = {'x-channelstream-secret': sig_for_server,
65 'x-channelstream-endpoint': endpoint,
65 'x-channelstream-endpoint': endpoint,
66 'Content-Type': 'application/json'}
66 'Content-Type': 'application/json'}
67 req_url = get_channelstream_server_url(config, endpoint)
67 req_url = get_channelstream_server_url(config, endpoint)
68
68
69 log.debug('Sending a channelstream request to endpoint: `%s`', req_url)
69 log.debug('Sending a channelstream request to endpoint: `%s`', req_url)
70 response = None
70 response = None
71 try:
71 try:
72 response = requests.post(req_url, data=json.dumps(payload),
72 response = requests.post(req_url, data=json.dumps(payload),
73 headers=secret_headers).json()
73 headers=secret_headers).json()
74 except requests.ConnectionError:
74 except requests.ConnectionError:
75 log.exception('ConnectionError occurred for endpoint %s', req_url)
75 log.exception('ConnectionError occurred for endpoint %s', req_url)
76 if raise_exc:
76 if raise_exc:
77 raise ChannelstreamConnectionException(req_url)
77 raise ChannelstreamConnectionException(req_url)
78 except Exception:
78 except Exception:
79 log.exception('Exception related to Channelstream happened')
79 log.exception('Exception related to Channelstream happened')
80 if raise_exc:
80 if raise_exc:
81 raise ChannelstreamConnectionException()
81 raise ChannelstreamConnectionException()
82 log.debug('Got channelstream response: %s', response)
82 log.debug('Got channelstream response: %s', response)
83 return response
83 return response
84
84
85
85
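# For reference, a sketch of how the receiving server could validate the
# `x-channelstream-secret` header produced above; the 60s max_age is an
# assumption, not something this client mandates:
#
#   import itsdangerous
#
#   def _verify(secret, endpoint, header_value):
#       signer = itsdangerous.TimestampSigner(secret)
#       try:
#           # unsign() checks both the signature and the embedded timestamp
#           return signer.unsign(header_value, max_age=60).decode('utf8') == endpoint
#       except itsdangerous.BadSignature:
#           return False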
86 def get_user_data(user_id):
86 def get_user_data(user_id):
87 user = User.get(user_id)
87 user = User.get(user_id)
88 return {
88 return {
89 'id': user.user_id,
89 'id': user.user_id,
90 'username': user.username,
90 'username': user.username,
91 'first_name': user.first_name,
91 'first_name': user.first_name,
92 'last_name': user.last_name,
92 'last_name': user.last_name,
93 'icon_link': h.gravatar_url(user.email, 60),
93 'icon_link': h.gravatar_url(user.email, 60),
94 'display_name': h.person(user, 'username_or_name_or_email'),
94 'display_name': h.person(user, 'username_or_name_or_email'),
95 'display_link': h.link_to_user(user),
95 'display_link': h.link_to_user(user),
96 'notifications': user.user_data.get('notification_status', True)
96 'notifications': user.user_data.get('notification_status', True)
97 }
97 }
98
98
99
99
100 def broadcast_validator(channel_name):
100 def broadcast_validator(channel_name):
101 """ checks if user can access the broadcast channel """
101 """ checks if user can access the broadcast channel """
102 if channel_name == 'broadcast':
102 if channel_name == 'broadcast':
103 return True
103 return True
104
104
105
105
106 def repo_validator(channel_name):
106 def repo_validator(channel_name):
107 """ checks if user can access the broadcast channel """
107 """ checks if user can access the broadcast channel """
108 channel_prefix = '/repo$'
108 channel_prefix = '/repo$'
109 if channel_name.startswith(channel_prefix):
109 if channel_name.startswith(channel_prefix):
110 elements = channel_name[len(channel_prefix):].split('$')
110 elements = channel_name[len(channel_prefix):].split('$')
111 repo_name = elements[0]
111 repo_name = elements[0]
112 can_access = HasRepoPermissionAny(
112 can_access = HasRepoPermissionAny(
113 'repository.read',
113 'repository.read',
114 'repository.write',
114 'repository.write',
115 'repository.admin')(repo_name)
115 'repository.admin')(repo_name)
116 log.debug(
116 log.debug(
117 'permission check for %s channel resulted in %s',
117 'permission check for %s channel resulted in %s',
118 repo_name, can_access)
118 repo_name, can_access)
119 if can_access:
119 if can_access:
120 return True
120 return True
121 return False
121 return False
122
122
123
123
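# Channel names use a `$`-delimited scheme; for illustration (the repo name
# is a placeholder):
#
#   repo_validator('/repo$some-repo$/pr/5')
#   # -> True only if the current user holds at least repository.read on
#   #    'some-repo'; the segment between the two `$` is the repo name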
124 def check_channel_permissions(channels, plugin_validators, should_raise=True):
124 def check_channel_permissions(channels, plugin_validators, should_raise=True):
125 valid_channels = []
125 valid_channels = []
126
126
127 validators = [broadcast_validator, repo_validator]
127 validators = [broadcast_validator, repo_validator]
128 if plugin_validators:
128 if plugin_validators:
129 validators.extend(plugin_validators)
129 validators.extend(plugin_validators)
130 for channel_name in channels:
130 for channel_name in channels:
131 is_valid = False
131 is_valid = False
132 for validator in validators:
132 for validator in validators:
133 if validator(channel_name):
133 if validator(channel_name):
134 is_valid = True
134 is_valid = True
135 break
135 break
136 if is_valid:
136 if is_valid:
137 valid_channels.append(channel_name)
137 valid_channels.append(channel_name)
138 else:
138 else:
139 if should_raise:
139 if should_raise:
140 raise ChannelstreamPermissionException()
140 raise ChannelstreamPermissionException()
141 return valid_channels
141 return valid_channels
142
142
143
143
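# A minimal usage sketch (channel names are placeholders):
#
#   valid = check_channel_permissions(
#       ['broadcast', '/repo$some-repo$/pr/5', '/bogus-channel'],
#       plugin_validators=get_connection_validators(get_current_registry()),
#       should_raise=False)
#   # -> only the channels accepted by at least one validator are returned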
144 def get_channels_info(config, channels):
144 def get_channels_info(config, channels):
145 payload = {'channels': channels}
145 payload = {'channels': channels}
146 # gather persistence info
146 # gather persistence info
147 return channelstream_request(config, payload, '/info')
147 return channelstream_request(config, payload, '/info')
148
148
149
149
150 def parse_channels_info(info_result, include_channel_info=None):
150 def parse_channels_info(info_result, include_channel_info=None):
151 """
151 """
152 Returns data that contains only non-sensitive information that can be
152 Returns data that contains only non-sensitive information that can be
153 safely presented to clients
153 safely presented to clients
154 """
154 """
155 include_channel_info = include_channel_info or []
155 include_channel_info = include_channel_info or []
156
156
157 user_state_dict = {}
157 user_state_dict = {}
158 for userinfo in info_result['users']:
158 for userinfo in info_result['users']:
159 user_state_dict[userinfo['user']] = {
159 user_state_dict[userinfo['user']] = {
160 k: v for k, v in userinfo['state'].items()
160 k: v for k, v in userinfo['state'].items()
161 if k in USER_STATE_PUBLIC_KEYS
161 if k in USER_STATE_PUBLIC_KEYS
162 }
162 }
163
163
164 channels_info = {}
164 channels_info = {}
165
165
166 for c_name, c_info in info_result['channels'].items():
166 for c_name, c_info in info_result['channels'].items():
167 if c_name not in include_channel_info:
167 if c_name not in include_channel_info:
168 continue
168 continue
169 connected_list = []
169 connected_list = []
170 for username in c_info['users']:
170 for username in c_info['users']:
171 connected_list.append({
171 connected_list.append({
172 'user': username,
172 'user': username,
173 'state': user_state_dict[username]
173 'state': user_state_dict[username]
174 })
174 })
175 channels_info[c_name] = {'users': connected_list,
175 channels_info[c_name] = {'users': connected_list,
176 'history': c_info['history']}
176 'history': c_info['history']}
177
177
178 return channels_info
178 return channels_info
179
179
180
180
181 def log_filepath(history_location, channel_name):
181 def log_filepath(history_location, channel_name):
182 hasher = hashlib.sha256()
182 hasher = hashlib.sha256()
183 hasher.update(channel_name.encode('utf8'))
183 hasher.update(channel_name.encode('utf8'))
184 filename = '{}.log'.format(hasher.hexdigest())
184 filename = '{}.log'.format(hasher.hexdigest())
185 filepath = os.path.join(history_location, filename)
185 filepath = os.path.join(history_location, filename)
186 return filepath
186 return filepath
187
187
188
188
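# For illustration, history files are addressed purely by the sha256 hex
# digest of the channel name (the directory below is a placeholder):
#
#   log_filepath('/var/history', 'broadcast')
#   # -> '/var/history/<sha256 hexdigest of "broadcast">.log'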
189 def read_history(history_location, channel_name):
189 def read_history(history_location, channel_name):
190 filepath = log_filepath(history_location, channel_name)
190 filepath = log_filepath(history_location, channel_name)
191 if not os.path.exists(filepath):
191 if not os.path.exists(filepath):
192 return []
192 return []
193 history_lines_limit = -100
193 history_lines_limit = -100
194 history = []
194 history = []
195 with open(filepath, 'rb') as f:
195 with open(filepath, 'rb') as f:
196 for line in f.readlines()[history_lines_limit:]:
196 for line in f.readlines()[history_lines_limit:]:
197 try:
197 try:
198 history.append(json.loads(line))
198 history.append(json.loads(line))
199 except Exception:
199 except Exception:
200 log.exception('Failed to load history')
200 log.exception('Failed to load history')
201 return history
201 return history
202
202
203
203
204 def update_history_from_logs(config, channels, payload):
204 def update_history_from_logs(config, channels, payload):
205 history_location = config.get('history.location')
205 history_location = config.get('history.location')
206 for channel in channels:
206 for channel in channels:
207 history = read_history(history_location, channel)
207 history = read_history(history_location, channel)
208 payload['channels_info'][channel]['history'] = history
208 payload['channels_info'][channel]['history'] = history
209
209
210
210
211 def write_history(config, message):
211 def write_history(config, message):
212 """ writes a message to a base64encoded filename """
212 """ writes a message to a base64encoded filename """
213 history_location = config.get('history.location')
213 history_location = config.get('history.location')
214 if not os.path.exists(history_location):
214 if not os.path.exists(history_location):
215 return
215 return
216 try:
216 try:
217 LOCK.acquire_write_lock()
217 LOCK.acquire_write_lock()
218 filepath = log_filepath(history_location, message['channel'])
218 filepath = log_filepath(history_location, message['channel'])
219 json_message = json.dumps(message)
219 json_message = json.dumps(message)
220 with open(filepath, 'ab') as f:
220 with open(filepath, 'ab') as f:
221 f.write(json_message.encode('utf8'))
221 f.write(json_message.encode('utf8'))
222 f.write(b'\n')
222 f.write(b'\n')
223 finally:
223 finally:
224 LOCK.release_write_lock()
224 LOCK.release_write_lock()
225
225
226
226
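# A round-trip sketch, assuming `history.location` points at an existing
# directory (the path is a placeholder):
#
#   config = {'history.location': '/var/history'}
#   write_history(config, {'channel': 'broadcast', 'message': 'hi'})
#   read_history(config['history.location'], 'broadcast')
#   # -> [{'channel': 'broadcast', 'message': 'hi'}]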
227 def get_connection_validators(registry):
227 def get_connection_validators(registry):
228 validators = []
228 validators = []
229 for k, config in registry.rhodecode_plugins.items():
229 for k, config in registry.rhodecode_plugins.items():
230 validator = config.get('channelstream', {}).get('connect_validator')
230 validator = config.get('channelstream', {}).get('connect_validator')
231 if validator:
231 if validator:
232 validators.append(validator)
232 validators.append(validator)
233 return validators
233 return validators
234
234
235
235
236 def get_channelstream_config(registry=None):
236 def get_channelstream_config(registry=None):
237 if not registry:
237 if not registry:
238 registry = get_current_registry()
238 registry = get_current_registry()
239
239
240 rhodecode_plugins = getattr(registry, 'rhodecode_plugins', {})
240 rhodecode_plugins = getattr(registry, 'rhodecode_plugins', {})
241 channelstream_config = rhodecode_plugins.get('channelstream', {})
241 channelstream_config = rhodecode_plugins.get('channelstream', {})
242 return channelstream_config
242 return channelstream_config
243
243
244
244
245 def post_message(channel, message, username, registry=None):
245 def post_message(channel, message, username, registry=None):
246 channelstream_config = get_channelstream_config(registry)
246 channelstream_config = get_channelstream_config(registry)
247 if not channelstream_config.get('enabled'):
247 if not channelstream_config.get('enabled'):
248 return
248 return
249
249
250 message_obj = message
250 message_obj = message
251 if isinstance(message, str):
251 if isinstance(message, str):
252 message_obj = {
252 message_obj = {
253 'message': message,
253 'message': message,
254 'level': 'success',
254 'level': 'success',
255 'topic': '/notifications'
255 'topic': '/notifications'
256 }
256 }
257
257
258 log.debug('Channelstream: sending notification to channel %s', channel)
258 log.debug('Channelstream: sending notification to channel %s', channel)
259 payload = {
259 payload = {
260 'type': 'message',
260 'type': 'message',
261 'timestamp': datetime.datetime.utcnow(),
261 'timestamp': datetime.datetime.utcnow(),
262 'user': 'system',
262 'user': 'system',
263 'exclude_users': [username],
263 'exclude_users': [username],
264 'channel': channel,
264 'channel': channel,
265 'message': message_obj
265 'message': message_obj
266 }
266 }
267
267
268 try:
268 try:
269 return channelstream_request(
269 return channelstream_request(
270 channelstream_config, [payload], '/message',
270 channelstream_config, [payload], '/message',
271 raise_exc=False)
271 raise_exc=False)
272 except ChannelstreamException:
272 except ChannelstreamException:
273 log.exception('Failed to send channelstream data')
273 log.exception('Failed to send channelstream data')
274 raise
274 raise
275
275
276
276
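# A minimal usage sketch (channel and username are placeholders); note that
# a plain string is wrapped into a success-level notification message_obj:
#
#   post_message('/repo$some-repo$/pr/7', 'PR was updated', 'some-user')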
277 def _reload_link(label):
277 def _reload_link(label):
278 return (
278 return (
279 '<a onclick="window.location.reload()">'
279 '<a onclick="window.location.reload()">'
280 '<strong>{}</strong>'
280 '<strong>{}</strong>'
281 '</a>'.format(label)
281 '</a>'.format(label)
282 )
282 )
283
283
284
284
285 def pr_channel(pull_request):
285 def pr_channel(pull_request):
286 repo_name = pull_request.target_repo.repo_name
286 repo_name = pull_request.target_repo.repo_name
287 pull_request_id = pull_request.pull_request_id
287 pull_request_id = pull_request.pull_request_id
288 channel = '/repo${}$/pr/{}'.format(repo_name, pull_request_id)
288 channel = '/repo${}$/pr/{}'.format(repo_name, pull_request_id)
289 log.debug('Getting pull-request channelstream broadcast channel: %s', channel)
289 log.debug('Getting pull-request channelstream broadcast channel: %s', channel)
290 return channel
290 return channel
291
291
292
292
293 def comment_channel(repo_name, commit_obj=None, pull_request_obj=None):
293 def comment_channel(repo_name, commit_obj=None, pull_request_obj=None):
294 channel = None
294 channel = None
295 if commit_obj:
295 if commit_obj:
296 channel = u'/repo${}$/commit/{}'.format(
296 channel = '/repo${}$/commit/{}'.format(
297 repo_name, commit_obj.raw_id
297 repo_name, commit_obj.raw_id
298 )
298 )
299 elif pull_request_obj:
299 elif pull_request_obj:
300 channel = u'/repo${}$/pr/{}'.format(
300 channel = '/repo${}$/pr/{}'.format(
301 repo_name, pull_request_obj.pull_request_id
301 repo_name, pull_request_obj.pull_request_id
302 )
302 )
303 log.debug('Getting comment channelstream broadcast channel: %s', channel)
303 log.debug('Getting comment channelstream broadcast channel: %s', channel)
304
304
305 return channel
305 return channel
306
306
307
307
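# For illustration, the channel strings produced by the two helpers above
# (repo name, commit id and pull-request id are placeholders):
#
#   pr_channel(pull_request)                    # '/repo$some-repo$/pr/7'
#   comment_channel('some-repo', commit_obj=c)  # '/repo$some-repo$/commit/<raw_id>'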
308 def pr_update_channelstream_push(request, pr_broadcast_channel, user, msg, **kwargs):
308 def pr_update_channelstream_push(request, pr_broadcast_channel, user, msg, **kwargs):
309 """
309 """
310 Channel push on pull request update
310 Channel push on pull request update
311 """
311 """
312 if not pr_broadcast_channel:
312 if not pr_broadcast_channel:
313 return
313 return
314
314
315 _ = request.translate
315 _ = request.translate
316
316
317 message = '{} {}'.format(
317 message = '{} {}'.format(
318 msg,
318 msg,
319 _reload_link(_(' Reload page to load changes')))
319 _reload_link(_(' Reload page to load changes')))
320
320
321 message_obj = {
321 message_obj = {
322 'message': message,
322 'message': message,
323 'level': 'success',
323 'level': 'success',
324 'topic': '/notifications'
324 'topic': '/notifications'
325 }
325 }
326
326
327 post_message(
327 post_message(
328 pr_broadcast_channel, message_obj, user.username,
328 pr_broadcast_channel, message_obj, user.username,
329 registry=request.registry)
329 registry=request.registry)
330
330
331
331
332 def comment_channelstream_push(request, comment_broadcast_channel, user, msg, **kwargs):
332 def comment_channelstream_push(request, comment_broadcast_channel, user, msg, **kwargs):
333 """
333 """
334 Channelstream push on comment action, on commit, or pull-request
334 Channelstream push on comment action, on commit, or pull-request
335 """
335 """
336 if not comment_broadcast_channel:
336 if not comment_broadcast_channel:
337 return
337 return
338
338
339 _ = request.translate
339 _ = request.translate
340
340
341 comment_data = kwargs.pop('comment_data', {})
341 comment_data = kwargs.pop('comment_data', {})
342 user_data = kwargs.pop('user_data', {})
342 user_data = kwargs.pop('user_data', {})
343 comment_id = list(comment_data.keys())[0] if comment_data else ''
343 comment_id = list(comment_data.keys())[0] if comment_data else ''
344
344
345 message = '<strong>{}</strong> {} #{}'.format(
345 message = '<strong>{}</strong> {} #{}'.format(
346 user.username,
346 user.username,
347 msg,
347 msg,
348 comment_id,
348 comment_id,
349 )
349 )
350
350
351 message_obj = {
351 message_obj = {
352 'message': message,
352 'message': message,
353 'level': 'success',
353 'level': 'success',
354 'topic': '/notifications'
354 'topic': '/notifications'
355 }
355 }
356
356
357 post_message(
357 post_message(
358 comment_broadcast_channel, message_obj, user.username,
358 comment_broadcast_channel, message_obj, user.username,
359 registry=request.registry)
359 registry=request.registry)
360
360
361 message_obj = {
361 message_obj = {
362 'message': None,
362 'message': None,
363 'user': user.username,
363 'user': user.username,
364 'comment_id': comment_id,
364 'comment_id': comment_id,
365 'comment_data': comment_data,
365 'comment_data': comment_data,
366 'user_data': user_data,
366 'user_data': user_data,
367 'topic': '/comment'
367 'topic': '/comment'
368 }
368 }
369 post_message(
369 post_message(
370 comment_broadcast_channel, message_obj, user.username,
370 comment_broadcast_channel, message_obj, user.username,
371 registry=request.registry)
371 registry=request.registry)
@@ -1,797 +1,797 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2
2
3 # Copyright (C) 2011-2020 RhodeCode GmbH
3 # Copyright (C) 2011-2020 RhodeCode GmbH
4 #
4 #
5 # This program is free software: you can redistribute it and/or modify
5 # This program is free software: you can redistribute it and/or modify
6 # it under the terms of the GNU Affero General Public License, version 3
6 # it under the terms of the GNU Affero General Public License, version 3
7 # (only), as published by the Free Software Foundation.
7 # (only), as published by the Free Software Foundation.
8 #
8 #
9 # This program is distributed in the hope that it will be useful,
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 # GNU General Public License for more details.
12 # GNU General Public License for more details.
13 #
13 #
14 # You should have received a copy of the GNU Affero General Public License
14 # You should have received a copy of the GNU Affero General Public License
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 #
16 #
17 # This program is dual-licensed. If you wish to learn more about the
17 # This program is dual-licensed. If you wish to learn more about the
18 # RhodeCode Enterprise Edition, including its added features, Support services,
18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20
20
21 import logging
21 import logging
22 import difflib
22 import difflib
23 from itertools import groupby
23 from itertools import groupby
24
24
25 from pygments import lex
25 from pygments import lex
26 from pygments.formatters.html import _get_ttype_class as pygment_token_class
26 from pygments.formatters.html import _get_ttype_class as pygment_token_class
27 from pygments.lexers.special import TextLexer, Token
27 from pygments.lexers.special import TextLexer, Token
28 from pygments.lexers import get_lexer_by_name
28 from pygments.lexers import get_lexer_by_name
29
29
30 from rhodecode.lib.helpers import (
30 from rhodecode.lib.helpers import (
31 get_lexer_for_filenode, html_escape, get_custom_lexer)
31 get_lexer_for_filenode, html_escape, get_custom_lexer)
32 from rhodecode.lib.utils2 import AttributeDict, StrictAttributeDict, safe_unicode
32 from rhodecode.lib.utils2 import AttributeDict, StrictAttributeDict, safe_unicode
33 from rhodecode.lib.vcs.nodes import FileNode
33 from rhodecode.lib.vcs.nodes import FileNode
34 from rhodecode.lib.vcs.exceptions import VCSError, NodeDoesNotExistError
34 from rhodecode.lib.vcs.exceptions import VCSError, NodeDoesNotExistError
35 from rhodecode.lib.diff_match_patch import diff_match_patch
35 from rhodecode.lib.diff_match_patch import diff_match_patch
36 from rhodecode.lib.diffs import LimitedDiffContainer, DEL_FILENODE, BIN_FILENODE
36 from rhodecode.lib.diffs import LimitedDiffContainer, DEL_FILENODE, BIN_FILENODE
37
37
38
38
39 plain_text_lexer = get_lexer_by_name(
39 plain_text_lexer = get_lexer_by_name(
40 'text', stripall=False, stripnl=False, ensurenl=False)
40 'text', stripall=False, stripnl=False, ensurenl=False)
41
41
42
42
43 log = logging.getLogger(__name__)
43 log = logging.getLogger(__name__)
44
44
45
45
46 def filenode_as_lines_tokens(filenode, lexer=None):
46 def filenode_as_lines_tokens(filenode, lexer=None):
47 org_lexer = lexer
47 org_lexer = lexer
48 lexer = lexer or get_lexer_for_filenode(filenode)
48 lexer = lexer or get_lexer_for_filenode(filenode)
49 log.debug('Generating file node pygment tokens for %s, %s, org_lexer:%s',
49 log.debug('Generating file node pygment tokens for %s, %s, org_lexer:%s',
50 lexer, filenode, org_lexer)
50 lexer, filenode, org_lexer)
51 content = filenode.content
51 content = filenode.content
52 tokens = tokenize_string(content, lexer)
52 tokens = tokenize_string(content, lexer)
53 lines = split_token_stream(tokens, content)
53 lines = split_token_stream(tokens, content)
54 rv = list(lines)
54 rv = list(lines)
55 return rv
55 return rv
56
56
57
57
58 def tokenize_string(content, lexer):
58 def tokenize_string(content, lexer):
59 """
59 """
60 Use pygments to tokenize some content based on a lexer
60 Use pygments to tokenize some content based on a lexer
61 ensuring all original newlines and whitespace are preserved
61 ensuring all original newlines and whitespace are preserved
62 """
62 """
63
63
64 lexer.stripall = False
64 lexer.stripall = False
65 lexer.stripnl = False
65 lexer.stripnl = False
66 lexer.ensurenl = False
66 lexer.ensurenl = False
67
67
68 if isinstance(lexer, TextLexer):
68 if isinstance(lexer, TextLexer):
69 lexed = [(Token.Text, content)]
69 lexed = [(Token.Text, content)]
70 else:
70 else:
71 lexed = lex(content, lexer)
71 lexed = lex(content, lexer)
72
72
73 for token_type, token_text in lexed:
73 for token_type, token_text in lexed:
74 yield pygment_token_class(token_type), token_text
74 yield pygment_token_class(token_type), token_text
75
75
76
76
77 def split_token_stream(tokens, content):
77 def split_token_stream(tokens, content):
78 """
78 """
79 Take a list of (TokenType, text) tuples and split them by a string
79 Take a list of (TokenType, text) tuples and split them by a string
80
80
81 split_token_stream([(TEXT, 'some\ntext'), (TEXT, 'more\n')])
81 split_token_stream([(TEXT, 'some\ntext'), (TEXT, 'more\n')])
82 [[(TEXT, 'some')], [(TEXT, 'text'),
82 [[(TEXT, 'some')], [(TEXT, 'text'),
83 (TEXT, 'more')], [(TEXT, '')]]
83 (TEXT, 'more')], [(TEXT, '')]]
84 """
84 """
85
85
86 token_buffer = []
86 token_buffer = []
87 for token_class, token_text in tokens:
87 for token_class, token_text in tokens:
88 parts = token_text.split('\n')
88 parts = token_text.split('\n')
89 for part in parts[:-1]:
89 for part in parts[:-1]:
90 token_buffer.append((token_class, part))
90 token_buffer.append((token_class, part))
91 yield token_buffer
91 yield token_buffer
92 token_buffer = []
92 token_buffer = []
93
93
94 token_buffer.append((token_class, parts[-1]))
94 token_buffer.append((token_class, parts[-1]))
95
95
96 if token_buffer:
96 if token_buffer:
97 yield token_buffer
97 yield token_buffer
98 elif content:
98 elif content:
99 # this is a special case, we have the content, but tokenization didn't produce
99 # this is a special case, we have the content, but tokenization didn't produce
100 # any results. This can happen if known file extensions like .css have some bogus
100 # any results. This can happen if known file extensions like .css have some bogus
101 # unicode content without any newline characters
101 # unicode content without any newline characters
102 yield [(pygment_token_class(Token.Text), content)]
102 yield [(pygment_token_class(Token.Text), content)]
103
103
104
104
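# Chaining the two helpers above yields one token list per source line,
# e.g. (content is arbitrary):
#
#   content = 'a = 1\nb = 2\n'
#   tokens = tokenize_string(content, plain_text_lexer)
#   list(split_token_stream(tokens, content))
#   # -> a list of (token_class, token_text) tuples for every line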
105 def filenode_as_annotated_lines_tokens(filenode):
105 def filenode_as_annotated_lines_tokens(filenode):
106 """
106 """
107 Take a file node and return a list of annotation => lines pairs; if no
107 Take a file node and return a list of annotation => lines pairs; if no
108 annotation is found, it will be None.
108 annotation is found, it will be None.
109
109
110 eg:
110 eg:
111
111
112 [
112 [
113 (annotation1, [
113 (annotation1, [
114 (1, line1_tokens_list),
114 (1, line1_tokens_list),
115 (2, line2_tokens_list),
115 (2, line2_tokens_list),
116 ]),
116 ]),
117 (annotation2, [
117 (annotation2, [
118 (3, line1_tokens_list),
118 (3, line1_tokens_list),
119 ]),
119 ]),
120 (None, [
120 (None, [
121 (4, line1_tokens_list),
121 (4, line1_tokens_list),
122 ]),
122 ]),
123 (annotation1, [
123 (annotation1, [
124 (5, line1_tokens_list),
124 (5, line1_tokens_list),
125 (6, line2_tokens_list),
125 (6, line2_tokens_list),
126 ])
126 ])
127 ]
127 ]
128 """
128 """
129
129
130 commit_cache = {} # cache commit_getter lookups
130 commit_cache = {} # cache commit_getter lookups
131
131
132 def _get_annotation(commit_id, commit_getter):
132 def _get_annotation(commit_id, commit_getter):
133 if commit_id not in commit_cache:
133 if commit_id not in commit_cache:
134 commit_cache[commit_id] = commit_getter()
134 commit_cache[commit_id] = commit_getter()
135 return commit_cache[commit_id]
135 return commit_cache[commit_id]
136
136
137 annotation_lookup = {
137 annotation_lookup = {
138 line_no: _get_annotation(commit_id, commit_getter)
138 line_no: _get_annotation(commit_id, commit_getter)
139 for line_no, commit_id, commit_getter, line_content
139 for line_no, commit_id, commit_getter, line_content
140 in filenode.annotate
140 in filenode.annotate
141 }
141 }
142
142
143 annotations_lines = ((annotation_lookup.get(line_no), line_no, tokens)
143 annotations_lines = ((annotation_lookup.get(line_no), line_no, tokens)
144 for line_no, tokens
144 for line_no, tokens
145 in enumerate(filenode_as_lines_tokens(filenode), 1))
145 in enumerate(filenode_as_lines_tokens(filenode), 1))
146
146
147 grouped_annotations_lines = groupby(annotations_lines, lambda x: x[0])
147 grouped_annotations_lines = groupby(annotations_lines, lambda x: x[0])
148
148
149 for annotation, group in grouped_annotations_lines:
149 for annotation, group in grouped_annotations_lines:
150 yield (
150 yield (
151 annotation, [(line_no, tokens)
151 annotation, [(line_no, tokens)
152 for (_, line_no, tokens) in group]
152 for (_, line_no, tokens) in group]
153 )
153 )
154
154
155
155
156 def render_tokenstream(tokenstream):
156 def render_tokenstream(tokenstream):
157 result = []
157 result = []
158 for token_class, token_ops_texts in rollup_tokenstream(tokenstream):
158 for token_class, token_ops_texts in rollup_tokenstream(tokenstream):
159
159
160 if token_class:
160 if token_class:
161 result.append(u'<span class="%s">' % token_class)
161 result.append('<span class="%s">' % token_class)
162 else:
162 else:
163 result.append(u'<span>')
163 result.append('<span>')
164
164
165 for op_tag, token_text in token_ops_texts:
165 for op_tag, token_text in token_ops_texts:
166
166
167 if op_tag:
167 if op_tag:
168 result.append(u'<%s>' % op_tag)
168 result.append('<%s>' % op_tag)
169
169
170 # NOTE(marcink): in some cases of mixed encodings, we might run into
170 # NOTE(marcink): in some cases of mixed encodings, we might run into
171 # troubles in the html_escape; in this case we force token_text to unicode,
171 # troubles in the html_escape; in this case we force token_text to unicode,
172 # which ensures "correct" data even at the cost of how it renders
172 # which ensures "correct" data even at the cost of how it renders
173 try:
173 try:
174 escaped_text = html_escape(token_text)
174 escaped_text = html_escape(token_text)
175 except TypeError:
175 except TypeError:
176 escaped_text = html_escape(safe_unicode(token_text))
176 escaped_text = html_escape(safe_unicode(token_text))
177
177
178 # TODO: dan: investigate showing hidden characters like space/nl/tab
178 # TODO: dan: investigate showing hidden characters like space/nl/tab
179 # escaped_text = escaped_text.replace(' ', '<sp> </sp>')
179 # escaped_text = escaped_text.replace(' ', '<sp> </sp>')
180 # escaped_text = escaped_text.replace('\n', '<nl>\n</nl>')
180 # escaped_text = escaped_text.replace('\n', '<nl>\n</nl>')
181 # escaped_text = escaped_text.replace('\t', '<tab>\t</tab>')
181 # escaped_text = escaped_text.replace('\t', '<tab>\t</tab>')
182
182
183 result.append(escaped_text)
183 result.append(escaped_text)
184
184
185 if op_tag:
185 if op_tag:
186 result.append(u'</%s>' % op_tag)
186 result.append('</%s>' % op_tag)
187
187
188 result.append(u'</span>')
188 result.append('</span>')
189
189
190 html = ''.join(result)
190 html = ''.join(result)
191 return html
191 return html
192
192
193
193
194 def rollup_tokenstream(tokenstream):
194 def rollup_tokenstream(tokenstream):
195 """
195 """
196 Group a token stream of the format:
196 Group a token stream of the format:
197
197
198 ('class', 'op', 'text')
198 ('class', 'op', 'text')
199 or
199 or
200 ('class', 'text')
200 ('class', 'text')
201
201
202 into
202 into
203
203
204 [('class1',
204 [('class1',
205 [('op1', 'text'),
205 [('op1', 'text'),
206 ('op2', 'text')]),
206 ('op2', 'text')]),
207 ('class2',
207 ('class2',
208 [('op3', 'text')])]
208 [('op3', 'text')])]
209
209
210 This is used to get the minimal tags necessary when
210 This is used to get the minimal tags necessary when
211 rendering to html, e.g. for a token stream:
211 rendering to html, e.g. for a token stream:
212
212
213 <span class="A"><ins>he</ins>llo</span>
213 <span class="A"><ins>he</ins>llo</span>
214 vs
214 vs
215 <span class="A"><ins>he</ins></span><span class="A">llo</span>
215 <span class="A"><ins>he</ins></span><span class="A">llo</span>
216
216
217 If a 2 tuple is passed in, the output op will be an empty string.
217 If a 2 tuple is passed in, the output op will be an empty string.
218
218
219 eg:
219 eg:
220
220
221 >>> rollup_tokenstream([('classA', '', 'h'),
221 >>> rollup_tokenstream([('classA', '', 'h'),
222 ('classA', 'del', 'ell'),
222 ('classA', 'del', 'ell'),
223 ('classA', '', 'o'),
223 ('classA', '', 'o'),
224 ('classB', '', ' '),
224 ('classB', '', ' '),
225 ('classA', '', 'the'),
225 ('classA', '', 'the'),
226 ('classA', '', 're'),
226 ('classA', '', 're'),
227 ])
227 ])
228
228
229 [('classA', [('', 'h'), ('del', 'ell'), ('', 'o')]),
229 [('classA', [('', 'h'), ('del', 'ell'), ('', 'o')]),
230 ('classB', [('', ' ')]),
230 ('classB', [('', ' ')]),
231 ('classA', [('', 'there')])]
231 ('classA', [('', 'there')])]
232
232
233 """
233 """
234 if tokenstream and len(tokenstream[0]) == 2:
234 if tokenstream and len(tokenstream[0]) == 2:
235 tokenstream = ((t[0], '', t[1]) for t in tokenstream)
235 tokenstream = ((t[0], '', t[1]) for t in tokenstream)
236
236
237 result = []
237 result = []
238 for token_class, op_list in groupby(tokenstream, lambda t: t[0]):
238 for token_class, op_list in groupby(tokenstream, lambda t: t[0]):
239 ops = []
239 ops = []
240 for token_op, token_text_list in groupby(op_list, lambda o: o[1]):
240 for token_op, token_text_list in groupby(op_list, lambda o: o[1]):
241 text_buffer = []
241 text_buffer = []
242 for t_class, t_op, t_text in token_text_list:
242 for t_class, t_op, t_text in token_text_list:
243 text_buffer.append(t_text)
243 text_buffer.append(t_text)
244 ops.append((token_op, ''.join(text_buffer)))
244 ops.append((token_op, ''.join(text_buffer)))
245 result.append((token_class, ops))
245 result.append((token_class, ops))
246 return result
246 return result
247
247
248
248
249 def tokens_diff(old_tokens, new_tokens, use_diff_match_patch=True):
249 def tokens_diff(old_tokens, new_tokens, use_diff_match_patch=True):
250 """
250 """
251 Converts a list of (token_class, token_text) tuples to a list of
251 Converts a list of (token_class, token_text) tuples to a list of
252 (token_class, token_op, token_text) tuples where token_op is one of
252 (token_class, token_op, token_text) tuples where token_op is one of
253 ('ins', 'del', '')
253 ('ins', 'del', '')
254
254
255 :param old_tokens: list of (token_class, token_text) tuples of old line
255 :param old_tokens: list of (token_class, token_text) tuples of old line
256 :param new_tokens: list of (token_class, token_text) tuples of new line
256 :param new_tokens: list of (token_class, token_text) tuples of new line
257 :param use_diff_match_patch: boolean, will use Google's diff-match-patch
257 :param use_diff_match_patch: boolean, will use Google's diff-match-patch
258 library, which has options to 'smooth' out the character-by-character
258 library, which has options to 'smooth' out the character-by-character
259 differences, making nicer ins/del blocks
259 differences, making nicer ins/del blocks
260 """
260 """
261
261
262 old_tokens_result = []
262 old_tokens_result = []
263 new_tokens_result = []
263 new_tokens_result = []
264
264
265 similarity = difflib.SequenceMatcher(None,
265 similarity = difflib.SequenceMatcher(None,
266 ''.join(token_text for token_class, token_text in old_tokens),
266 ''.join(token_text for token_class, token_text in old_tokens),
267 ''.join(token_text for token_class, token_text in new_tokens)
267 ''.join(token_text for token_class, token_text in new_tokens)
268 ).ratio()
268 ).ratio()
269
269
270 if similarity < 0.6: # return, the blocks are too different
270 if similarity < 0.6: # return, the blocks are too different
271 for token_class, token_text in old_tokens:
271 for token_class, token_text in old_tokens:
272 old_tokens_result.append((token_class, '', token_text))
272 old_tokens_result.append((token_class, '', token_text))
273 for token_class, token_text in new_tokens:
273 for token_class, token_text in new_tokens:
274 new_tokens_result.append((token_class, '', token_text))
274 new_tokens_result.append((token_class, '', token_text))
275 return old_tokens_result, new_tokens_result, similarity
275 return old_tokens_result, new_tokens_result, similarity
276
276
277 token_sequence_matcher = difflib.SequenceMatcher(None,
277 token_sequence_matcher = difflib.SequenceMatcher(None,
278 [x[1] for x in old_tokens],
278 [x[1] for x in old_tokens],
279 [x[1] for x in new_tokens])
279 [x[1] for x in new_tokens])
280
280
281 for tag, o1, o2, n1, n2 in token_sequence_matcher.get_opcodes():
281 for tag, o1, o2, n1, n2 in token_sequence_matcher.get_opcodes():
282 # check the differences by token block types first to give a
282 # check the differences by token block types first to give a
283 # nicer "block" level replacement vs character diffs
283 # nicer "block" level replacement vs character diffs
284
284
285 if tag == 'equal':
285 if tag == 'equal':
286 for token_class, token_text in old_tokens[o1:o2]:
286 for token_class, token_text in old_tokens[o1:o2]:
287 old_tokens_result.append((token_class, '', token_text))
287 old_tokens_result.append((token_class, '', token_text))
288 for token_class, token_text in new_tokens[n1:n2]:
288 for token_class, token_text in new_tokens[n1:n2]:
289 new_tokens_result.append((token_class, '', token_text))
289 new_tokens_result.append((token_class, '', token_text))
290 elif tag == 'delete':
290 elif tag == 'delete':
291 for token_class, token_text in old_tokens[o1:o2]:
291 for token_class, token_text in old_tokens[o1:o2]:
292 old_tokens_result.append((token_class, 'del', token_text))
292 old_tokens_result.append((token_class, 'del', token_text))
293 elif tag == 'insert':
293 elif tag == 'insert':
294 for token_class, token_text in new_tokens[n1:n2]:
294 for token_class, token_text in new_tokens[n1:n2]:
295 new_tokens_result.append((token_class, 'ins', token_text))
295 new_tokens_result.append((token_class, 'ins', token_text))
296 elif tag == 'replace':
296 elif tag == 'replace':
297 # if same type token blocks must be replaced, do a diff on the
297 # if same type token blocks must be replaced, do a diff on the
298 # characters in the token blocks to show individual changes
298 # characters in the token blocks to show individual changes
299
299
300 old_char_tokens = []
300 old_char_tokens = []
301 new_char_tokens = []
301 new_char_tokens = []
302 for token_class, token_text in old_tokens[o1:o2]:
302 for token_class, token_text in old_tokens[o1:o2]:
303 for char in token_text:
303 for char in token_text:
304 old_char_tokens.append((token_class, char))
304 old_char_tokens.append((token_class, char))
305
305
306 for token_class, token_text in new_tokens[n1:n2]:
306 for token_class, token_text in new_tokens[n1:n2]:
307 for char in token_text:
307 for char in token_text:
308 new_char_tokens.append((token_class, char))
308 new_char_tokens.append((token_class, char))
309
309
310 old_string = ''.join([token_text for
310 old_string = ''.join([token_text for
311 token_class, token_text in old_char_tokens])
311 token_class, token_text in old_char_tokens])
312 new_string = ''.join([token_text for
312 new_string = ''.join([token_text for
313 token_class, token_text in new_char_tokens])
313 token_class, token_text in new_char_tokens])
314
314
315 char_sequence = difflib.SequenceMatcher(
315 char_sequence = difflib.SequenceMatcher(
316 None, old_string, new_string)
316 None, old_string, new_string)
317 copcodes = char_sequence.get_opcodes()
317 copcodes = char_sequence.get_opcodes()
318 obuffer, nbuffer = [], []
318 obuffer, nbuffer = [], []
319
319
320 if use_diff_match_patch:
320 if use_diff_match_patch:
321 dmp = diff_match_patch()
321 dmp = diff_match_patch()
322 dmp.Diff_EditCost = 11 # TODO: dan: extract this to a setting
322 dmp.Diff_EditCost = 11 # TODO: dan: extract this to a setting
323 reps = dmp.diff_main(old_string, new_string)
323 reps = dmp.diff_main(old_string, new_string)
324 dmp.diff_cleanupEfficiency(reps)
324 dmp.diff_cleanupEfficiency(reps)
325
325
326 a, b = 0, 0
326 a, b = 0, 0
327 for op, rep in reps:
327 for op, rep in reps:
328 l = len(rep)
328 l = len(rep)
329 if op == 0:
329 if op == 0:
330 for i, c in enumerate(rep):
330 for i, c in enumerate(rep):
331 obuffer.append((old_char_tokens[a+i][0], '', c))
331 obuffer.append((old_char_tokens[a+i][0], '', c))
332 nbuffer.append((new_char_tokens[b+i][0], '', c))
332 nbuffer.append((new_char_tokens[b+i][0], '', c))
333 a += l
333 a += l
334 b += l
334 b += l
335 elif op == -1:
335 elif op == -1:
336 for i, c in enumerate(rep):
336 for i, c in enumerate(rep):
337 obuffer.append((old_char_tokens[a+i][0], 'del', c))
337 obuffer.append((old_char_tokens[a+i][0], 'del', c))
338 a += l
338 a += l
339 elif op == 1:
339 elif op == 1:
340 for i, c in enumerate(rep):
340 for i, c in enumerate(rep):
341 nbuffer.append((new_char_tokens[b+i][0], 'ins', c))
341 nbuffer.append((new_char_tokens[b+i][0], 'ins', c))
342 b += l
342 b += l
343 else:
343 else:
344 for ctag, co1, co2, cn1, cn2 in copcodes:
344 for ctag, co1, co2, cn1, cn2 in copcodes:
345 if ctag == 'equal':
345 if ctag == 'equal':
346 for token_class, token_text in old_char_tokens[co1:co2]:
346 for token_class, token_text in old_char_tokens[co1:co2]:
347 obuffer.append((token_class, '', token_text))
347 obuffer.append((token_class, '', token_text))
348 for token_class, token_text in new_char_tokens[cn1:cn2]:
348 for token_class, token_text in new_char_tokens[cn1:cn2]:
349 nbuffer.append((token_class, '', token_text))
349 nbuffer.append((token_class, '', token_text))
350 elif ctag == 'delete':
350 elif ctag == 'delete':
351 for token_class, token_text in old_char_tokens[co1:co2]:
351 for token_class, token_text in old_char_tokens[co1:co2]:
352 obuffer.append((token_class, 'del', token_text))
352 obuffer.append((token_class, 'del', token_text))
353 elif ctag == 'insert':
353 elif ctag == 'insert':
354 for token_class, token_text in new_char_tokens[cn1:cn2]:
354 for token_class, token_text in new_char_tokens[cn1:cn2]:
355 nbuffer.append((token_class, 'ins', token_text))
355 nbuffer.append((token_class, 'ins', token_text))
356 elif ctag == 'replace':
356 elif ctag == 'replace':
357 for token_class, token_text in old_char_tokens[co1:co2]:
357 for token_class, token_text in old_char_tokens[co1:co2]:
358 obuffer.append((token_class, 'del', token_text))
358 obuffer.append((token_class, 'del', token_text))
359 for token_class, token_text in new_char_tokens[cn1:cn2]:
359 for token_class, token_text in new_char_tokens[cn1:cn2]:
360 nbuffer.append((token_class, 'ins', token_text))
360 nbuffer.append((token_class, 'ins', token_text))
361
361
362 old_tokens_result.extend(obuffer)
362 old_tokens_result.extend(obuffer)
363 new_tokens_result.extend(nbuffer)
363 new_tokens_result.extend(nbuffer)
364
364
365 return old_tokens_result, new_tokens_result, similarity
365 return old_tokens_result, new_tokens_result, similarity
366
366
367
367
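# An illustrative call (token classes and texts are arbitrary):
#
#   old, new, sim = tokens_diff(
#       [('', 'hello world')], [('', 'hello there')],
#       use_diff_match_patch=False)
#   # old -> (class, 'del', char) entries for text only in the old line
#   # new -> (class, 'ins', char) entries for text only in the new line
#   # sim -> similarity ratio in [0, 1]; below 0.6 no char-level diff is done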
368 def diffset_node_getter(commit):
368 def diffset_node_getter(commit):
369 def get_node(fname):
369 def get_node(fname):
370 try:
370 try:
371 return commit.get_node(fname)
371 return commit.get_node(fname)
372 except NodeDoesNotExistError:
372 except NodeDoesNotExistError:
373 return None
373 return None
374
374
375 return get_node
375 return get_node
376
376
377
377
378 class DiffSet(object):
378 class DiffSet(object):
379 """
379 """
380 An object for parsing the diff result from diffs.DiffProcessor and
380 An object for parsing the diff result from diffs.DiffProcessor and
381 adding highlighting, side by side/unified renderings and line diffs
381 adding highlighting, side by side/unified renderings and line diffs
382 """
382 """
383
383
384 HL_REAL = 'REAL' # highlights using original file, slow
384 HL_REAL = 'REAL' # highlights using original file, slow
385 HL_FAST = 'FAST' # highlights using just the line, fast but not correct
385 HL_FAST = 'FAST' # highlights using just the line, fast but not correct
386 # in the case of multiline code
386 # in the case of multiline code
387 HL_NONE = 'NONE' # no highlighting, fastest
387 HL_NONE = 'NONE' # no highlighting, fastest
388
388
389 def __init__(self, highlight_mode=HL_REAL, repo_name=None,
389 def __init__(self, highlight_mode=HL_REAL, repo_name=None,
390 source_repo_name=None,
390 source_repo_name=None,
391 source_node_getter=lambda filename: None,
391 source_node_getter=lambda filename: None,
392 target_repo_name=None,
392 target_repo_name=None,
393 target_node_getter=lambda filename: None,
393 target_node_getter=lambda filename: None,
394 source_nodes=None, target_nodes=None,
394 source_nodes=None, target_nodes=None,
395 # files over this size will use fast highlighting
395 # files over this size will use fast highlighting
396 max_file_size_limit=150 * 1024,
396 max_file_size_limit=150 * 1024,
397 ):
397 ):
398
398
399 self.highlight_mode = highlight_mode
399 self.highlight_mode = highlight_mode
400 self.highlighted_filenodes = {
400 self.highlighted_filenodes = {
401 'before': {},
401 'before': {},
402 'after': {}
402 'after': {}
403 }
403 }
404 self.source_node_getter = source_node_getter
404 self.source_node_getter = source_node_getter
405 self.target_node_getter = target_node_getter
405 self.target_node_getter = target_node_getter
406 self.source_nodes = source_nodes or {}
406 self.source_nodes = source_nodes or {}
407 self.target_nodes = target_nodes or {}
407 self.target_nodes = target_nodes or {}
408 self.repo_name = repo_name
408 self.repo_name = repo_name
409 self.target_repo_name = target_repo_name or repo_name
409 self.target_repo_name = target_repo_name or repo_name
410 self.source_repo_name = source_repo_name or repo_name
410 self.source_repo_name = source_repo_name or repo_name
411 self.max_file_size_limit = max_file_size_limit
411 self.max_file_size_limit = max_file_size_limit
412
412
413 def render_patchset(self, patchset, source_ref=None, target_ref=None):
413 def render_patchset(self, patchset, source_ref=None, target_ref=None):
414 diffset = AttributeDict(dict(
414 diffset = AttributeDict(dict(
415 lines_added=0,
415 lines_added=0,
416 lines_deleted=0,
416 lines_deleted=0,
417 changed_files=0,
417 changed_files=0,
418 files=[],
418 files=[],
419 file_stats={},
419 file_stats={},
420 limited_diff=isinstance(patchset, LimitedDiffContainer),
420 limited_diff=isinstance(patchset, LimitedDiffContainer),
421 repo_name=self.repo_name,
421 repo_name=self.repo_name,
422 target_repo_name=self.target_repo_name,
422 target_repo_name=self.target_repo_name,
423 source_repo_name=self.source_repo_name,
423 source_repo_name=self.source_repo_name,
424 source_ref=source_ref,
424 source_ref=source_ref,
425 target_ref=target_ref,
425 target_ref=target_ref,
426 ))
426 ))
427 for patch in patchset:
427 for patch in patchset:
428 diffset.file_stats[patch['filename']] = patch['stats']
428 diffset.file_stats[patch['filename']] = patch['stats']
429 filediff = self.render_patch(patch)
429 filediff = self.render_patch(patch)
430 filediff.diffset = StrictAttributeDict(dict(
430 filediff.diffset = StrictAttributeDict(dict(
431 source_ref=diffset.source_ref,
431 source_ref=diffset.source_ref,
432 target_ref=diffset.target_ref,
432 target_ref=diffset.target_ref,
433 repo_name=diffset.repo_name,
433 repo_name=diffset.repo_name,
434 source_repo_name=diffset.source_repo_name,
434 source_repo_name=diffset.source_repo_name,
435 target_repo_name=diffset.target_repo_name,
435 target_repo_name=diffset.target_repo_name,
436 ))
436 ))
437 diffset.files.append(filediff)
437 diffset.files.append(filediff)
438 diffset.changed_files += 1
438 diffset.changed_files += 1
439 if not patch['stats']['binary']:
439 if not patch['stats']['binary']:
440 diffset.lines_added += patch['stats']['added']
440 diffset.lines_added += patch['stats']['added']
441 diffset.lines_deleted += patch['stats']['deleted']
441 diffset.lines_deleted += patch['stats']['deleted']
442
442
443 return diffset
443 return diffset
444
444
445 _lexer_cache = {}
445 _lexer_cache = {}
446
446
447 def _get_lexer_for_filename(self, filename, filenode=None):
447 def _get_lexer_for_filename(self, filename, filenode=None):
448 # cached because we might need to call it twice for source/target
448 # cached because we might need to call it twice for source/target
449 if filename not in self._lexer_cache:
449 if filename not in self._lexer_cache:
450 if filenode:
450 if filenode:
451 lexer = filenode.lexer
451 lexer = filenode.lexer
452 extension = filenode.extension
452 extension = filenode.extension
453 else:
453 else:
454 lexer = FileNode.get_lexer(filename=filename)
454 lexer = FileNode.get_lexer(filename=filename)
455 extension = filename.split('.')[-1]
455 extension = filename.split('.')[-1]
456
456
457 lexer = get_custom_lexer(extension) or lexer
457 lexer = get_custom_lexer(extension) or lexer
458 self._lexer_cache[filename] = lexer
458 self._lexer_cache[filename] = lexer
459 return self._lexer_cache[filename]
459 return self._lexer_cache[filename]
460
460
461 def render_patch(self, patch):
461 def render_patch(self, patch):
462 log.debug('rendering diff for %r', patch['filename'])
462 log.debug('rendering diff for %r', patch['filename'])
463
463
464 source_filename = patch['original_filename']
464 source_filename = patch['original_filename']
465 target_filename = patch['filename']
465 target_filename = patch['filename']
466
466
467 source_lexer = plain_text_lexer
467 source_lexer = plain_text_lexer
468 target_lexer = plain_text_lexer
468 target_lexer = plain_text_lexer
469
469
470 if not patch['stats']['binary']:
470 if not patch['stats']['binary']:
471 node_hl_mode = self.HL_NONE if patch['chunks'] == [] else None
471 node_hl_mode = self.HL_NONE if patch['chunks'] == [] else None
472 hl_mode = node_hl_mode or self.highlight_mode
472 hl_mode = node_hl_mode or self.highlight_mode
473
473
474 if hl_mode == self.HL_REAL:
474 if hl_mode == self.HL_REAL:
475 if (source_filename and patch['operation'] in ('D', 'M')
475 if (source_filename and patch['operation'] in ('D', 'M')
476 and source_filename not in self.source_nodes):
476 and source_filename not in self.source_nodes):
477 self.source_nodes[source_filename] = (
477 self.source_nodes[source_filename] = (
478 self.source_node_getter(source_filename))
478 self.source_node_getter(source_filename))
479
479
480 if (target_filename and patch['operation'] in ('A', 'M')
480 if (target_filename and patch['operation'] in ('A', 'M')
481 and target_filename not in self.target_nodes):
481 and target_filename not in self.target_nodes):
482 self.target_nodes[target_filename] = (
482 self.target_nodes[target_filename] = (
483 self.target_node_getter(target_filename))
483 self.target_node_getter(target_filename))
484
484
485 elif hl_mode == self.HL_FAST:
485 elif hl_mode == self.HL_FAST:
486 source_lexer = self._get_lexer_for_filename(source_filename)
486 source_lexer = self._get_lexer_for_filename(source_filename)
487 target_lexer = self._get_lexer_for_filename(target_filename)
487 target_lexer = self._get_lexer_for_filename(target_filename)
488
488
489 source_file = self.source_nodes.get(source_filename, source_filename)
489 source_file = self.source_nodes.get(source_filename, source_filename)
490 target_file = self.target_nodes.get(target_filename, target_filename)
490 target_file = self.target_nodes.get(target_filename, target_filename)
491 raw_id_uid = ''
491 raw_id_uid = ''
492 if self.source_nodes.get(source_filename):
492 if self.source_nodes.get(source_filename):
493 raw_id_uid = self.source_nodes[source_filename].commit.raw_id
493 raw_id_uid = self.source_nodes[source_filename].commit.raw_id
494
494
495 if not raw_id_uid and self.target_nodes.get(target_filename):
495 if not raw_id_uid and self.target_nodes.get(target_filename):
496 # in case this is a new file we only have it in target
496 # in case this is a new file we only have it in target
497 raw_id_uid = self.target_nodes[target_filename].commit.raw_id
497 raw_id_uid = self.target_nodes[target_filename].commit.raw_id
498
498
499 source_filenode, target_filenode = None, None
499 source_filenode, target_filenode = None, None
500
500
501 # TODO: dan: FileNode.lexer works on the content of the file - which
501 # TODO: dan: FileNode.lexer works on the content of the file - which
502 # can be slow - issue #4289 explains a lexer clean up - which once
502 # can be slow - issue #4289 explains a lexer clean up - which once
503 # done can allow caching a lexer for a filenode to avoid the file lookup
503 # done can allow caching a lexer for a filenode to avoid the file lookup
504 if isinstance(source_file, FileNode):
504 if isinstance(source_file, FileNode):
505 source_filenode = source_file
505 source_filenode = source_file
506 #source_lexer = source_file.lexer
506 #source_lexer = source_file.lexer
507 source_lexer = self._get_lexer_for_filename(source_filename)
507 source_lexer = self._get_lexer_for_filename(source_filename)
508 source_file.lexer = source_lexer
508 source_file.lexer = source_lexer
509
509
510 if isinstance(target_file, FileNode):
510 if isinstance(target_file, FileNode):
511 target_filenode = target_file
511 target_filenode = target_file
512 #target_lexer = target_file.lexer
512 #target_lexer = target_file.lexer
513 target_lexer = self._get_lexer_for_filename(target_filename)
513 target_lexer = self._get_lexer_for_filename(target_filename)
514 target_file.lexer = target_lexer
514 target_file.lexer = target_lexer
515
515
516 source_file_path, target_file_path = None, None
516 source_file_path, target_file_path = None, None
517
517
518 if source_filename != '/dev/null':
518 if source_filename != '/dev/null':
519 source_file_path = source_filename
519 source_file_path = source_filename
520 if target_filename != '/dev/null':
520 if target_filename != '/dev/null':
521 target_file_path = target_filename
521 target_file_path = target_filename
522
522
523 source_file_type = source_lexer.name
523 source_file_type = source_lexer.name
524 target_file_type = target_lexer.name
524 target_file_type = target_lexer.name
525
525
526 filediff = AttributeDict({
526 filediff = AttributeDict({
527 'source_file_path': source_file_path,
527 'source_file_path': source_file_path,
528 'target_file_path': target_file_path,
528 'target_file_path': target_file_path,
529 'source_filenode': source_filenode,
529 'source_filenode': source_filenode,
530 'target_filenode': target_filenode,
530 'target_filenode': target_filenode,
531 'source_file_type': source_file_type,
531 'source_file_type': source_file_type,
532 'target_file_type': target_file_type,
532 'target_file_type': target_file_type,
533 'patch': {'filename': patch['filename'], 'stats': patch['stats']},
533 'patch': {'filename': patch['filename'], 'stats': patch['stats']},
534 'operation': patch['operation'],
534 'operation': patch['operation'],
535 'source_mode': patch['stats']['old_mode'],
535 'source_mode': patch['stats']['old_mode'],
536 'target_mode': patch['stats']['new_mode'],
536 'target_mode': patch['stats']['new_mode'],
537 'limited_diff': patch['is_limited_diff'],
537 'limited_diff': patch['is_limited_diff'],
538 'hunks': [],
538 'hunks': [],
539 'hunk_ops': None,
539 'hunk_ops': None,
540 'diffset': self,
540 'diffset': self,
541 'raw_id': raw_id_uid,
541 'raw_id': raw_id_uid,
542 })
542 })
543
543
544 file_chunks = patch['chunks'][1:]
544 file_chunks = patch['chunks'][1:]
545 for i, hunk in enumerate(file_chunks, 1):
545 for i, hunk in enumerate(file_chunks, 1):
546 hunkbit = self.parse_hunk(hunk, source_file, target_file)
546 hunkbit = self.parse_hunk(hunk, source_file, target_file)
547 hunkbit.source_file_path = source_file_path
547 hunkbit.source_file_path = source_file_path
548 hunkbit.target_file_path = target_file_path
548 hunkbit.target_file_path = target_file_path
549 hunkbit.index = i
549 hunkbit.index = i
550 filediff.hunks.append(hunkbit)
550 filediff.hunks.append(hunkbit)
551
551
552 # Simulate a hunk for OPS-type lines, which don't really contain any diff;
552 # Simulate a hunk for OPS-type lines, which don't really contain any diff;
553 # this allows commenting on them
553 # this allows commenting on them
554 if not file_chunks:
554 if not file_chunks:
555 actions = []
555 actions = []
556 for op_id, op_text in filediff.patch['stats']['ops'].items():
556 for op_id, op_text in filediff.patch['stats']['ops'].items():
557 if op_id == DEL_FILENODE:
557 if op_id == DEL_FILENODE:
558 actions.append(u'file was removed')
558 actions.append('file was removed')
559 elif op_id == BIN_FILENODE:
559 elif op_id == BIN_FILENODE:
560 actions.append(u'binary diff hidden')
560 actions.append('binary diff hidden')
561 else:
561 else:
562 actions.append(safe_unicode(op_text))
562 actions.append(safe_unicode(op_text))
563 action_line = u'NO CONTENT: ' + \
563 action_line = 'NO CONTENT: ' + \
564 u', '.join(actions) or u'UNDEFINED_ACTION'
564 (', '.join(actions) or 'UNDEFINED_ACTION')
565
565
566 hunk_ops = {'source_length': 0, 'source_start': 0,
566 hunk_ops = {'source_length': 0, 'source_start': 0,
567 'lines': [
567 'lines': [
568 {'new_lineno': 0, 'old_lineno': 1,
568 {'new_lineno': 0, 'old_lineno': 1,
569 'action': 'unmod-no-hl', 'line': action_line}
569 'action': 'unmod-no-hl', 'line': action_line}
570 ],
570 ],
571 'section_header': u'', 'target_start': 1, 'target_length': 1}
571 'section_header': '', 'target_start': 1, 'target_length': 1}
572
572
573 hunkbit = self.parse_hunk(hunk_ops, source_file, target_file)
573 hunkbit = self.parse_hunk(hunk_ops, source_file, target_file)
574 hunkbit.source_file_path = source_file_path
574 hunkbit.source_file_path = source_file_path
575 hunkbit.target_file_path = target_file_path
575 hunkbit.target_file_path = target_file_path
576 filediff.hunk_ops = hunkbit
576 filediff.hunk_ops = hunkbit
577 return filediff
577 return filediff
578
578
579 def parse_hunk(self, hunk, source_file, target_file):
579 def parse_hunk(self, hunk, source_file, target_file):
580 result = AttributeDict(dict(
580 result = AttributeDict(dict(
581 source_start=hunk['source_start'],
581 source_start=hunk['source_start'],
582 source_length=hunk['source_length'],
582 source_length=hunk['source_length'],
583 target_start=hunk['target_start'],
583 target_start=hunk['target_start'],
584 target_length=hunk['target_length'],
584 target_length=hunk['target_length'],
585 section_header=hunk['section_header'],
585 section_header=hunk['section_header'],
586 lines=[],
586 lines=[],
587 ))
587 ))
588 before, after = [], []
588 before, after = [], []
589
589
590 for line in hunk['lines']:
590 for line in hunk['lines']:
591 if line['action'] in ['unmod', 'unmod-no-hl']:
591 if line['action'] in ['unmod', 'unmod-no-hl']:
592 no_hl = line['action'] == 'unmod-no-hl'
592 no_hl = line['action'] == 'unmod-no-hl'
593 result.lines.extend(
593 result.lines.extend(
594 self.parse_lines(before, after, source_file, target_file, no_hl=no_hl))
594 self.parse_lines(before, after, source_file, target_file, no_hl=no_hl))
595 after.append(line)
595 after.append(line)
596 before.append(line)
596 before.append(line)
597 elif line['action'] == 'add':
597 elif line['action'] == 'add':
598 after.append(line)
598 after.append(line)
599 elif line['action'] == 'del':
599 elif line['action'] == 'del':
600 before.append(line)
600 before.append(line)
601 elif line['action'] == 'old-no-nl':
601 elif line['action'] == 'old-no-nl':
602 before.append(line)
602 before.append(line)
603 elif line['action'] == 'new-no-nl':
603 elif line['action'] == 'new-no-nl':
604 after.append(line)
604 after.append(line)
605
605
606 all_actions = [x['action'] for x in after] + [x['action'] for x in before]
606 all_actions = [x['action'] for x in after] + [x['action'] for x in before]
607 no_hl = set(all_actions) == {'unmod-no-hl'}
607 no_hl = set(all_actions) == {'unmod-no-hl'}
608 result.lines.extend(
608 result.lines.extend(
609 self.parse_lines(before, after, source_file, target_file, no_hl=no_hl))
609 self.parse_lines(before, after, source_file, target_file, no_hl=no_hl))
610 # NOTE(marcink): we must keep list() call here so we can cache the result...
610 # NOTE(marcink): we must keep list() call here so we can cache the result...
611 result.unified = list(self.as_unified(result.lines))
611 result.unified = list(self.as_unified(result.lines))
612 result.sideside = result.lines
612 result.sideside = result.lines
613
613
614 return result
614 return result
615
615
616 def parse_lines(self, before_lines, after_lines, source_file, target_file,
616 def parse_lines(self, before_lines, after_lines, source_file, target_file,
617 no_hl=False):
617 no_hl=False):
618 # TODO: dan: investigate doing the diff comparison and fast highlighting
618 # TODO: dan: investigate doing the diff comparison and fast highlighting
619 # on the entire before and after buffered block lines rather than by
619 # on the entire before and after buffered block lines rather than by
620 # line, this means we can get better 'fast' highlighting if the context
620 # line, this means we can get better 'fast' highlighting if the context
621 # allows it - eg.
621 # allows it - eg.
622 # line 4: """
622 # line 4: """
623 # line 5: this gets highlighted as a string
623 # line 5: this gets highlighted as a string
624 # line 6: """
624 # line 6: """
625
625
626 lines = []
626 lines = []
627
627
628 before_newline = AttributeDict()
628 before_newline = AttributeDict()
629 after_newline = AttributeDict()
629 after_newline = AttributeDict()
630 if before_lines and before_lines[-1]['action'] == 'old-no-nl':
630 if before_lines and before_lines[-1]['action'] == 'old-no-nl':
631 before_newline_line = before_lines.pop(-1)
631 before_newline_line = before_lines.pop(-1)
632 before_newline.content = '\n {}'.format(
632 before_newline.content = '\n {}'.format(
633 render_tokenstream(
633 render_tokenstream(
634 [(x[0], '', x[1])
634 [(x[0], '', x[1])
635 for x in [('nonl', before_newline_line['line'])]]))
635 for x in [('nonl', before_newline_line['line'])]]))
636
636
637 if after_lines and after_lines[-1]['action'] == 'new-no-nl':
637 if after_lines and after_lines[-1]['action'] == 'new-no-nl':
638 after_newline_line = after_lines.pop(-1)
638 after_newline_line = after_lines.pop(-1)
639 after_newline.content = '\n {}'.format(
639 after_newline.content = '\n {}'.format(
640 render_tokenstream(
640 render_tokenstream(
641 [(x[0], '', x[1])
641 [(x[0], '', x[1])
642 for x in [('nonl', after_newline_line['line'])]]))
642 for x in [('nonl', after_newline_line['line'])]]))
643
643
644 while before_lines or after_lines:
644 while before_lines or after_lines:
645 before, after = None, None
645 before, after = None, None
646 before_tokens, after_tokens = None, None
646 before_tokens, after_tokens = None, None
647
647
648 if before_lines:
648 if before_lines:
649 before = before_lines.pop(0)
649 before = before_lines.pop(0)
650 if after_lines:
650 if after_lines:
651 after = after_lines.pop(0)
651 after = after_lines.pop(0)
652
652
653 original = AttributeDict()
653 original = AttributeDict()
654 modified = AttributeDict()
654 modified = AttributeDict()
655
655
656 if before:
656 if before:
657 if before['action'] == 'old-no-nl':
657 if before['action'] == 'old-no-nl':
658 before_tokens = [('nonl', before['line'])]
658 before_tokens = [('nonl', before['line'])]
659 else:
659 else:
660 before_tokens = self.get_line_tokens(
660 before_tokens = self.get_line_tokens(
661 line_text=before['line'], line_number=before['old_lineno'],
661 line_text=before['line'], line_number=before['old_lineno'],
662 input_file=source_file, no_hl=no_hl, source='before')
662 input_file=source_file, no_hl=no_hl, source='before')
663 original.lineno = before['old_lineno']
663 original.lineno = before['old_lineno']
664 original.content = before['line']
664 original.content = before['line']
665 original.action = self.action_to_op(before['action'])
665 original.action = self.action_to_op(before['action'])
666
666
667 original.get_comment_args = (
667 original.get_comment_args = (
668 source_file, 'o', before['old_lineno'])
668 source_file, 'o', before['old_lineno'])
669
669
670 if after:
670 if after:
671 if after['action'] == 'new-no-nl':
671 if after['action'] == 'new-no-nl':
672 after_tokens = [('nonl', after['line'])]
672 after_tokens = [('nonl', after['line'])]
673 else:
673 else:
674 after_tokens = self.get_line_tokens(
674 after_tokens = self.get_line_tokens(
675 line_text=after['line'], line_number=after['new_lineno'],
675 line_text=after['line'], line_number=after['new_lineno'],
676 input_file=target_file, no_hl=no_hl, source='after')
676 input_file=target_file, no_hl=no_hl, source='after')
677 modified.lineno = after['new_lineno']
677 modified.lineno = after['new_lineno']
678 modified.content = after['line']
678 modified.content = after['line']
679 modified.action = self.action_to_op(after['action'])
679 modified.action = self.action_to_op(after['action'])
680
680
681 modified.get_comment_args = (target_file, 'n', after['new_lineno'])
681 modified.get_comment_args = (target_file, 'n', after['new_lineno'])
682
682
683 # diff the lines
683 # diff the lines
684 if before_tokens and after_tokens:
684 if before_tokens and after_tokens:
685 o_tokens, m_tokens, similarity = tokens_diff(
685 o_tokens, m_tokens, similarity = tokens_diff(
686 before_tokens, after_tokens)
686 before_tokens, after_tokens)
687 original.content = render_tokenstream(o_tokens)
687 original.content = render_tokenstream(o_tokens)
688 modified.content = render_tokenstream(m_tokens)
688 modified.content = render_tokenstream(m_tokens)
689 elif before_tokens:
689 elif before_tokens:
690 original.content = render_tokenstream(
690 original.content = render_tokenstream(
691 [(x[0], '', x[1]) for x in before_tokens])
691 [(x[0], '', x[1]) for x in before_tokens])
692 elif after_tokens:
692 elif after_tokens:
693 modified.content = render_tokenstream(
693 modified.content = render_tokenstream(
694 [(x[0], '', x[1]) for x in after_tokens])
694 [(x[0], '', x[1]) for x in after_tokens])
695
695
696 if not before_lines and before_newline:
696 if not before_lines and before_newline:
697 original.content += before_newline.content
697 original.content += before_newline.content
698 before_newline = None
698 before_newline = None
699 if not after_lines and after_newline:
699 if not after_lines and after_newline:
700 modified.content += after_newline.content
700 modified.content += after_newline.content
701 after_newline = None
701 after_newline = None
702
702
703 lines.append(AttributeDict({
703 lines.append(AttributeDict({
704 'original': original,
704 'original': original,
705 'modified': modified,
705 'modified': modified,
706 }))
706 }))
707
707
708 return lines
708 return lines
709
709
710 def get_line_tokens(self, line_text, line_number, input_file=None, no_hl=False, source=''):
710 def get_line_tokens(self, line_text, line_number, input_file=None, no_hl=False, source=''):
711 filenode = None
711 filenode = None
712 filename = None
712 filename = None
713
713
714 if isinstance(input_file, str):
714 if isinstance(input_file, str):
715 filename = input_file
715 filename = input_file
716 elif isinstance(input_file, FileNode):
716 elif isinstance(input_file, FileNode):
717 filenode = input_file
717 filenode = input_file
718 filename = input_file.unicode_path
718 filename = input_file.unicode_path
719
719
720 hl_mode = self.HL_NONE if no_hl else self.highlight_mode
720 hl_mode = self.HL_NONE if no_hl else self.highlight_mode
721 if hl_mode == self.HL_REAL and filenode:
721 if hl_mode == self.HL_REAL and filenode:
722 lexer = self._get_lexer_for_filename(filename)
722 lexer = self._get_lexer_for_filename(filename)
723 file_size_allowed = input_file.size < self.max_file_size_limit
723 file_size_allowed = input_file.size < self.max_file_size_limit
724 if line_number and file_size_allowed:
724 if line_number and file_size_allowed:
725 return self.get_tokenized_filenode_line(input_file, line_number, lexer, source)
725 return self.get_tokenized_filenode_line(input_file, line_number, lexer, source)
726
726
727 if hl_mode in (self.HL_REAL, self.HL_FAST) and filename:
727 if hl_mode in (self.HL_REAL, self.HL_FAST) and filename:
728 lexer = self._get_lexer_for_filename(filename)
728 lexer = self._get_lexer_for_filename(filename)
729 return list(tokenize_string(line_text, lexer))
729 return list(tokenize_string(line_text, lexer))
730
730
731 return list(tokenize_string(line_text, plain_text_lexer))
731 return list(tokenize_string(line_text, plain_text_lexer))
732
732
733 def get_tokenized_filenode_line(self, filenode, line_number, lexer=None, source=''):
733 def get_tokenized_filenode_line(self, filenode, line_number, lexer=None, source=''):
734
734
735 def tokenize(_filenode):
735 def tokenize(_filenode):
736 self.highlighted_filenodes[source][_filenode] = filenode_as_lines_tokens(_filenode, lexer)
736 self.highlighted_filenodes[source][_filenode] = filenode_as_lines_tokens(_filenode, lexer)
737
737
738 if filenode not in self.highlighted_filenodes[source]:
738 if filenode not in self.highlighted_filenodes[source]:
739 tokenize(filenode)
739 tokenize(filenode)
740
740
741 try:
741 try:
742 return self.highlighted_filenodes[source][filenode][line_number - 1]
742 return self.highlighted_filenodes[source][filenode][line_number - 1]
743 except Exception:
743 except Exception:
744 log.exception('diff rendering error')
744 log.exception('diff rendering error')
745 return [('', u'L{}: rhodecode diff rendering error'.format(line_number))]
745 return [('', 'L{}: rhodecode diff rendering error'.format(line_number))]
746
746
747 def action_to_op(self, action):
747 def action_to_op(self, action):
748 return {
748 return {
749 'add': '+',
749 'add': '+',
750 'del': '-',
750 'del': '-',
751 'unmod': ' ',
751 'unmod': ' ',
752 'unmod-no-hl': ' ',
752 'unmod-no-hl': ' ',
753 'old-no-nl': ' ',
753 'old-no-nl': ' ',
754 'new-no-nl': ' ',
754 'new-no-nl': ' ',
755 }.get(action, action)
755 }.get(action, action)
756
756
757 def as_unified(self, lines):
757 def as_unified(self, lines):
758 """
758 """
759 Return a generator that yields the lines of a diff in unified order
759 Return a generator that yields the lines of a diff in unified order
760 """
760 """
761 def generator():
761 def generator():
762 buf = []
762 buf = []
763 for line in lines:
763 for line in lines:
764
764
765 if buf and (not line.original or line.original.action == ' '):
765 if buf and (not line.original or line.original.action == ' '):
766 for b in buf:
766 for b in buf:
767 yield b
767 yield b
768 buf = []
768 buf = []
769
769
770 if line.original:
770 if line.original:
771 if line.original.action == ' ':
771 if line.original.action == ' ':
772 yield (line.original.lineno, line.modified.lineno,
772 yield (line.original.lineno, line.modified.lineno,
773 line.original.action, line.original.content,
773 line.original.action, line.original.content,
774 line.original.get_comment_args)
774 line.original.get_comment_args)
775 continue
775 continue
776
776
777 if line.original.action == '-':
777 if line.original.action == '-':
778 yield (line.original.lineno, None,
778 yield (line.original.lineno, None,
779 line.original.action, line.original.content,
779 line.original.action, line.original.content,
780 line.original.get_comment_args)
780 line.original.get_comment_args)
781
781
782 if line.modified.action == '+':
782 if line.modified.action == '+':
783 buf.append((
783 buf.append((
784 None, line.modified.lineno,
784 None, line.modified.lineno,
785 line.modified.action, line.modified.content,
785 line.modified.action, line.modified.content,
786 line.modified.get_comment_args))
786 line.modified.get_comment_args))
787 continue
787 continue
788
788
789 if line.modified:
789 if line.modified:
790 yield (None, line.modified.lineno,
790 yield (None, line.modified.lineno,
791 line.modified.action, line.modified.content,
791 line.modified.action, line.modified.content,
792 line.modified.get_comment_args)
792 line.modified.get_comment_args)
793
793
794 for b in buf:
794 for b in buf:
795 yield b
795 yield b
796
796
797 return generator()
797 return generator()
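
For orientation, here is a minimal sketch of how diffset_node_getter and DiffSet fit
together to render the diff of a single commit. It is illustrative only: `repo` (a vcs
Repository), the 'tip' commit id, and the 'newdiff' format flag are assumptions based on
how this module is used elsewhere, not part of the change above.

    from rhodecode.lib import diffs

    commit = repo.get_commit(commit_id='tip')
    parent = commit.parents[0]  # sketch assumes a non-root commit

    # raw vcs diff between the parent and the commit
    vcs_diff = repo.get_diff(parent, commit, context=3)
    diff_processor = diffs.DiffProcessor(
        vcs_diff, format='newdiff', diff_limit=None, file_limit=None)
    patchset = diff_processor.prepare()

    diffset = diffs.DiffSet(
        repo_name=repo.name,
        # node getters let DiffSet lazily fetch file nodes for HL_REAL highlighting
        source_node_getter=diffs.diffset_node_getter(parent),
        target_node_getter=diffs.diffset_node_getter(commit),
    ).render_patchset(patchset, parent.raw_id, commit.raw_id)

    for filediff in diffset.files:
        print(filediff.target_file_path, len(filediff.hunks))
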
@@ -1,680 +1,680 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2
2
3 # Copyright (C) 2010-2020 RhodeCode GmbH
3 # Copyright (C) 2010-2020 RhodeCode GmbH
4 #
4 #
5 # This program is free software: you can redistribute it and/or modify
5 # This program is free software: you can redistribute it and/or modify
6 # it under the terms of the GNU Affero General Public License, version 3
6 # it under the terms of the GNU Affero General Public License, version 3
7 # (only), as published by the Free Software Foundation.
7 # (only), as published by the Free Software Foundation.
8 #
8 #
9 # This program is distributed in the hope that it will be useful,
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 # GNU General Public License for more details.
12 # GNU General Public License for more details.
13 #
13 #
14 # You should have received a copy of the GNU Affero General Public License
14 # You should have received a copy of the GNU Affero General Public License
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 #
16 #
17 # This program is dual-licensed. If you wish to learn more about the
17 # This program is dual-licensed. If you wish to learn more about the
18 # RhodeCode Enterprise Edition, including its added features, Support services,
18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20
20
21 """
21 """
22 Database creation and setup module for RhodeCode Enterprise. Used for creation
22 Database creation and setup module for RhodeCode Enterprise. Used for creation
23 of the database as well as for migration operations
23 of the database as well as for migration operations
24 """
24 """
25
25
26 import os
26 import os
27 import sys
27 import sys
28 import time
28 import time
29 import uuid
29 import uuid
30 import logging
30 import logging
31 import getpass
31 import getpass
32 from os.path import dirname as dn, join as jn
32 from os.path import dirname as dn, join as jn
33
33
34 from sqlalchemy.engine import create_engine
34 from sqlalchemy.engine import create_engine
35
35
36 from rhodecode import __dbversion__
36 from rhodecode import __dbversion__
37 from rhodecode.model import init_model
37 from rhodecode.model import init_model
38 from rhodecode.model.user import UserModel
38 from rhodecode.model.user import UserModel
39 from rhodecode.model.db import (
39 from rhodecode.model.db import (
40 User, Permission, RhodeCodeUi, RhodeCodeSetting, UserToPerm,
40 User, Permission, RhodeCodeUi, RhodeCodeSetting, UserToPerm,
41 DbMigrateVersion, RepoGroup, UserRepoGroupToPerm, CacheKey, Repository)
41 DbMigrateVersion, RepoGroup, UserRepoGroupToPerm, CacheKey, Repository)
42 from rhodecode.model.meta import Session, Base
42 from rhodecode.model.meta import Session, Base
43 from rhodecode.model.permission import PermissionModel
43 from rhodecode.model.permission import PermissionModel
44 from rhodecode.model.repo import RepoModel
44 from rhodecode.model.repo import RepoModel
45 from rhodecode.model.repo_group import RepoGroupModel
45 from rhodecode.model.repo_group import RepoGroupModel
46 from rhodecode.model.settings import SettingsModel
46 from rhodecode.model.settings import SettingsModel
47
47
48
48
49 log = logging.getLogger(__name__)
49 log = logging.getLogger(__name__)
50
50
51
51
52 def notify(msg):
52 def notify(msg):
53 """
53 """
54 Print a banner notification for migration messages
54 Print a banner notification for migration messages
55 """
55 """
56 ml = len(msg) + (4 * 2)
56 ml = len(msg) + (4 * 2)
57 print(('\n%s\n*** %s ***\n%s' % ('*' * ml, msg, '*' * ml)).upper())
57 print(('\n%s\n*** %s ***\n%s' % ('*' * ml, msg, '*' * ml)).upper())
58
58
59
59
60 class DbManage(object):
60 class DbManage(object):
61
61
62 def __init__(self, log_sql, dbconf, root, tests=False,
62 def __init__(self, log_sql, dbconf, root, tests=False,
63 SESSION=None, cli_args=None):
63 SESSION=None, cli_args=None):
64 self.dbname = dbconf.split('/')[-1]
64 self.dbname = dbconf.split('/')[-1]
65 self.tests = tests
65 self.tests = tests
66 self.root = root
66 self.root = root
67 self.dburi = dbconf
67 self.dburi = dbconf
68 self.log_sql = log_sql
68 self.log_sql = log_sql
69 self.cli_args = cli_args or {}
69 self.cli_args = cli_args or {}
70 self.init_db(SESSION=SESSION)
70 self.init_db(SESSION=SESSION)
71 self.ask_ok = self.get_ask_ok_func(self.cli_args.get('force_ask'))
71 self.ask_ok = self.get_ask_ok_func(self.cli_args.get('force_ask'))
72
72
73 def db_exists(self):
73 def db_exists(self):
74 if not self.sa:
74 if not self.sa:
75 self.init_db()
75 self.init_db()
76 try:
76 try:
77 self.sa.query(RhodeCodeUi)\
77 self.sa.query(RhodeCodeUi)\
78 .filter(RhodeCodeUi.ui_key == '/')\
78 .filter(RhodeCodeUi.ui_key == '/')\
79 .scalar()
79 .scalar()
80 return True
80 return True
81 except Exception:
81 except Exception:
82 return False
82 return False
83 finally:
83 finally:
84 self.sa.rollback()
84 self.sa.rollback()
85
85
86 def get_ask_ok_func(self, param):
86 def get_ask_ok_func(self, param):
87 if param is not None:
87 if param is not None:
88 # return a lambda that always answers with the preset param
88 # return a lambda that always answers with the preset param
89 return lambda *args, **kwargs: param
89 return lambda *args, **kwargs: param
90 else:
90 else:
91 from rhodecode.lib.utils import ask_ok
91 from rhodecode.lib.utils import ask_ok
92 return ask_ok
92 return ask_ok
93
93
94 def init_db(self, SESSION=None):
94 def init_db(self, SESSION=None):
95 if SESSION:
95 if SESSION:
96 self.sa = SESSION
96 self.sa = SESSION
97 else:
97 else:
98 # init new sessions
98 # init new sessions
99 engine = create_engine(self.dburi, echo=self.log_sql)
99 engine = create_engine(self.dburi, echo=self.log_sql)
100 init_model(engine)
100 init_model(engine)
101 self.sa = Session()
101 self.sa = Session()
102
102
103 def create_tables(self, override=False):
103 def create_tables(self, override=False):
104 """
104 """
105 Create an auth database
105 Create an auth database
106 """
106 """
107
107
108 log.info("Existing database with the same name is going to be destroyed.")
108 log.info("Existing database with the same name is going to be destroyed.")
109 log.info("Setup command will run DROP ALL command on that database.")
109 log.info("Setup command will run DROP ALL command on that database.")
110 if self.tests:
110 if self.tests:
111 destroy = True
111 destroy = True
112 else:
112 else:
113 destroy = self.ask_ok('Are you sure that you want to destroy the old database? [y/n]')
113 destroy = self.ask_ok('Are you sure that you want to destroy the old database? [y/n]')
114 if not destroy:
114 if not destroy:
115 log.info('db tables bootstrap: Nothing done.')
115 log.info('db tables bootstrap: Nothing done.')
116 sys.exit(0)
116 sys.exit(0)
117 if destroy:
117 if destroy:
118 Base.metadata.drop_all()
118 Base.metadata.drop_all()
119
119
120 checkfirst = not override
120 checkfirst = not override
121 Base.metadata.create_all(checkfirst=checkfirst)
121 Base.metadata.create_all(checkfirst=checkfirst)
122 log.info('Created tables for %s', self.dbname)
122 log.info('Created tables for %s', self.dbname)
123
123
124 def set_db_version(self):
124 def set_db_version(self):
125 ver = DbMigrateVersion()
125 ver = DbMigrateVersion()
126 ver.version = __dbversion__
126 ver.version = __dbversion__
127 ver.repository_id = 'rhodecode_db_migrations'
127 ver.repository_id = 'rhodecode_db_migrations'
128 ver.repository_path = 'versions'
128 ver.repository_path = 'versions'
129 self.sa.add(ver)
129 self.sa.add(ver)
130 log.info('db version set to: %s', __dbversion__)
130 log.info('db version set to: %s', __dbversion__)
131
131
132 def run_post_migration_tasks(self):
132 def run_post_migration_tasks(self):
133 """
133 """
134 Run various tasks before actually doing migrations
134 Run various tasks before actually doing migrations
135 """
135 """
136 # delete cache keys on each upgrade
136 # delete cache keys on each upgrade
137 total = CacheKey.query().count()
137 total = CacheKey.query().count()
138 log.info("Deleting (%s) cache keys now...", total)
138 log.info("Deleting (%s) cache keys now...", total)
139 CacheKey.delete_all_cache()
139 CacheKey.delete_all_cache()
140
140
141 def upgrade(self, version=None):
141 def upgrade(self, version=None):
142 """
142 """
143 Upgrades the given database schema to the given revision,
143 Upgrades the given database schema to the given revision,
144 following all steps needed to perform the upgrade
144 following all steps needed to perform the upgrade
145
145
146 """
146 """
147
147
148 from rhodecode.lib.dbmigrate.migrate.versioning import api
148 from rhodecode.lib.dbmigrate.migrate.versioning import api
149 from rhodecode.lib.dbmigrate.migrate.exceptions import \
149 from rhodecode.lib.dbmigrate.migrate.exceptions import \
150 DatabaseNotControlledError
150 DatabaseNotControlledError
151
151
152 if 'sqlite' in self.dburi:
152 if 'sqlite' in self.dburi:
153 print(
153 print(
154 '********************** WARNING **********************\n'
154 '********************** WARNING **********************\n'
155 'Make sure your version of sqlite is at least 3.7.X. \n'
155 'Make sure your version of sqlite is at least 3.7.X. \n'
156 'Earlier versions are known to fail on some migrations\n'
156 'Earlier versions are known to fail on some migrations\n'
157 '*****************************************************\n')
157 '*****************************************************\n')
158
158
159 upgrade = self.ask_ok(
159 upgrade = self.ask_ok(
160 'You are about to perform a database upgrade. Make '
160 'You are about to perform a database upgrade. Make '
161 'sure you have backed up your database. '
161 'sure you have backed up your database. '
162 'Continue ? [y/n]')
162 'Continue ? [y/n]')
163 if not upgrade:
163 if not upgrade:
164 log.info('No upgrade performed')
164 log.info('No upgrade performed')
165 sys.exit(0)
165 sys.exit(0)
166
166
167 repository_path = jn(dn(dn(dn(os.path.realpath(__file__)))),
167 repository_path = jn(dn(dn(dn(os.path.realpath(__file__)))),
168 'rhodecode/lib/dbmigrate')
168 'rhodecode/lib/dbmigrate')
169 db_uri = self.dburi
169 db_uri = self.dburi
170
170
171 if version:
171 if version:
172 DbMigrateVersion.set_version(version)
172 DbMigrateVersion.set_version(version)
173
173
174 try:
174 try:
175 curr_version = api.db_version(db_uri, repository_path)
175 curr_version = api.db_version(db_uri, repository_path)
176 msg = ('Found current database under version '
176 msg = ('Found current database under version '
177 'control with version {}'.format(curr_version))
177 'control with version {}'.format(curr_version))
178
178
179 except (RuntimeError, DatabaseNotControlledError):
179 except (RuntimeError, DatabaseNotControlledError):
180 curr_version = 1
180 curr_version = 1
181 msg = ('Current database is not under version control. Setting '
181 msg = ('Current database is not under version control. Setting '
182 'as version %s' % curr_version)
182 'as version %s' % curr_version)
183 api.version_control(db_uri, repository_path, curr_version)
183 api.version_control(db_uri, repository_path, curr_version)
184
184
185 notify(msg)
185 notify(msg)
186
186
187
187
188 if curr_version == __dbversion__:
188 if curr_version == __dbversion__:
189 log.info('This database is already at the newest version')
189 log.info('This database is already at the newest version')
190 sys.exit(0)
190 sys.exit(0)
191
191
192 upgrade_steps = range(curr_version + 1, __dbversion__ + 1)
192 upgrade_steps = range(curr_version + 1, __dbversion__ + 1)
193 notify('attempting to upgrade database from '
193 notify('attempting to upgrade database from '
194 'version %s to version %s' % (curr_version, __dbversion__))
194 'version %s to version %s' % (curr_version, __dbversion__))
195
195
196 # CALL THE PROPER ORDER OF STEPS TO PERFORM FULL UPGRADE
196 # CALL THE PROPER ORDER OF STEPS TO PERFORM FULL UPGRADE
197 _step = None
197 _step = None
198 for step in upgrade_steps:
198 for step in upgrade_steps:
199 notify('performing upgrade step %s' % step)
199 notify('performing upgrade step %s' % step)
200 time.sleep(0.5)
200 time.sleep(0.5)
201
201
202 api.upgrade(db_uri, repository_path, step)
202 api.upgrade(db_uri, repository_path, step)
203 self.sa.rollback()
203 self.sa.rollback()
204 notify('schema upgrade for step %s completed' % (step,))
204 notify('schema upgrade for step %s completed' % (step,))
205
205
206 _step = step
206 _step = step
207
207
208 self.run_post_migration_tasks()
208 self.run_post_migration_tasks()
209 notify('upgrade to version %s successful' % _step)
209 notify('upgrade to version %s successful' % _step)
210
210
211 def fix_repo_paths(self):
211 def fix_repo_paths(self):
212 """
212 """
213 Fixes an old RhodeCode version path into a new one without a '*'
213 Fixes an old RhodeCode version path into a new one without a '*'
214 """
214 """
215
215
216 paths = self.sa.query(RhodeCodeUi)\
216 paths = self.sa.query(RhodeCodeUi)\
217 .filter(RhodeCodeUi.ui_key == '/')\
217 .filter(RhodeCodeUi.ui_key == '/')\
218 .scalar()
218 .scalar()
219
219
220 paths.ui_value = paths.ui_value.replace('*', '')
220 paths.ui_value = paths.ui_value.replace('*', '')
221
221
222 try:
222 try:
223 self.sa.add(paths)
223 self.sa.add(paths)
224 self.sa.commit()
224 self.sa.commit()
225 except Exception:
225 except Exception:
226 self.sa.rollback()
226 self.sa.rollback()
227 raise
227 raise
228
228
229 def fix_default_user(self):
229 def fix_default_user(self):
230 """
230 """
231 Fixes an old default user with some 'nicer' default values,
231 Fixes an old default user with some 'nicer' default values,
232 used mostly for anonymous access
232 used mostly for anonymous access
233 """
233 """
234 def_user = self.sa.query(User)\
234 def_user = self.sa.query(User)\
235 .filter(User.username == User.DEFAULT_USER)\
235 .filter(User.username == User.DEFAULT_USER)\
236 .one()
236 .one()
237
237
238 def_user.name = 'Anonymous'
238 def_user.name = 'Anonymous'
239 def_user.lastname = 'User'
239 def_user.lastname = 'User'
240 def_user.email = User.DEFAULT_USER_EMAIL
240 def_user.email = User.DEFAULT_USER_EMAIL
241
241
242 try:
242 try:
243 self.sa.add(def_user)
243 self.sa.add(def_user)
244 self.sa.commit()
244 self.sa.commit()
245 except Exception:
245 except Exception:
246 self.sa.rollback()
246 self.sa.rollback()
247 raise
247 raise
248
248
249 def fix_settings(self):
249 def fix_settings(self):
250 """
250 """
251 Fixes rhodecode settings and adds ga_code key for google analytics
251 Fixes rhodecode settings and adds ga_code key for google analytics
252 """
252 """
253
253
254 hgsettings3 = RhodeCodeSetting('ga_code', '')
254 hgsettings3 = RhodeCodeSetting('ga_code', '')
255
255
256 try:
256 try:
257 self.sa.add(hgsettings3)
257 self.sa.add(hgsettings3)
258 self.sa.commit()
258 self.sa.commit()
259 except Exception:
259 except Exception:
260 self.sa.rollback()
260 self.sa.rollback()
261 raise
261 raise
262
262
263 def create_admin_and_prompt(self):
263 def create_admin_and_prompt(self):
264
264
265 # defaults
265 # defaults
266 defaults = self.cli_args
266 defaults = self.cli_args
267 username = defaults.get('username')
267 username = defaults.get('username')
268 password = defaults.get('password')
268 password = defaults.get('password')
269 email = defaults.get('email')
269 email = defaults.get('email')
270
270
271 if username is None:
271 if username is None:
272 username = input('Specify admin username:')
272 username = input('Specify admin username:')
273 if password is None:
273 if password is None:
274 password = self._get_admin_password()
274 password = self._get_admin_password()
275 if not password:
275 if not password:
276 # second try
276 # second try
277 password = self._get_admin_password()
277 password = self._get_admin_password()
278 if not password:
278 if not password:
279 sys.exit()
279 sys.exit()
280 if email is None:
280 if email is None:
281 email = input('Specify admin email:')
281 email = input('Specify admin email:')
282 api_key = self.cli_args.get('api_key')
282 api_key = self.cli_args.get('api_key')
283 self.create_user(username, password, email, True,
283 self.create_user(username, password, email, True,
284 strict_creation_check=False,
284 strict_creation_check=False,
285 api_key=api_key)
285 api_key=api_key)
286
286
287 def _get_admin_password(self):
287 def _get_admin_password(self):
288 password = getpass.getpass('Specify admin password '
288 password = getpass.getpass('Specify admin password '
289 '(min 6 chars):')
289 '(min 6 chars):')
290 confirm = getpass.getpass('Confirm password:')
290 confirm = getpass.getpass('Confirm password:')
291
291
292 if password != confirm:
292 if password != confirm:
293 log.error('passwords mismatch')
293 log.error('passwords mismatch')
294 return False
294 return False
295 if len(password) < 6:
295 if len(password) < 6:
296 log.error('password is too short - use at least 6 characters')
296 log.error('password is too short - use at least 6 characters')
297 return False
297 return False
298
298
299 return password
299 return password
300
300
301 def create_test_admin_and_users(self):
301 def create_test_admin_and_users(self):
302 log.info('creating admin and regular test users')
302 log.info('creating admin and regular test users')
303 from rhodecode.tests import TEST_USER_ADMIN_LOGIN, \
303 from rhodecode.tests import TEST_USER_ADMIN_LOGIN, \
304 TEST_USER_ADMIN_PASS, TEST_USER_ADMIN_EMAIL, \
304 TEST_USER_ADMIN_PASS, TEST_USER_ADMIN_EMAIL, \
305 TEST_USER_REGULAR_LOGIN, TEST_USER_REGULAR_PASS, \
305 TEST_USER_REGULAR_LOGIN, TEST_USER_REGULAR_PASS, \
306 TEST_USER_REGULAR_EMAIL, TEST_USER_REGULAR2_LOGIN, \
306 TEST_USER_REGULAR_EMAIL, TEST_USER_REGULAR2_LOGIN, \
307 TEST_USER_REGULAR2_PASS, TEST_USER_REGULAR2_EMAIL
307 TEST_USER_REGULAR2_PASS, TEST_USER_REGULAR2_EMAIL
308
308
309 self.create_user(TEST_USER_ADMIN_LOGIN, TEST_USER_ADMIN_PASS,
309 self.create_user(TEST_USER_ADMIN_LOGIN, TEST_USER_ADMIN_PASS,
310 TEST_USER_ADMIN_EMAIL, True, api_key=True)
310 TEST_USER_ADMIN_EMAIL, True, api_key=True)
311
311
312 self.create_user(TEST_USER_REGULAR_LOGIN, TEST_USER_REGULAR_PASS,
312 self.create_user(TEST_USER_REGULAR_LOGIN, TEST_USER_REGULAR_PASS,
313 TEST_USER_REGULAR_EMAIL, False, api_key=True)
313 TEST_USER_REGULAR_EMAIL, False, api_key=True)
314
314
315 self.create_user(TEST_USER_REGULAR2_LOGIN, TEST_USER_REGULAR2_PASS,
315 self.create_user(TEST_USER_REGULAR2_LOGIN, TEST_USER_REGULAR2_PASS,
316 TEST_USER_REGULAR2_EMAIL, False, api_key=True)
316 TEST_USER_REGULAR2_EMAIL, False, api_key=True)
317
317
318 def create_ui_settings(self, repo_store_path):
318 def create_ui_settings(self, repo_store_path):
319 """
319 """
320 Creates ui settings, fills out hooks
320 Creates ui settings, fills out hooks
321 and disables dotencode
321 and disables dotencode
322 """
322 """
323 settings_model = SettingsModel(sa=self.sa)
323 settings_model = SettingsModel(sa=self.sa)
324 from rhodecode.lib.vcs.backends.hg import largefiles_store
324 from rhodecode.lib.vcs.backends.hg import largefiles_store
325 from rhodecode.lib.vcs.backends.git import lfs_store
325 from rhodecode.lib.vcs.backends.git import lfs_store
326
326
327 # Build HOOKS
327 # Build HOOKS
328 hooks = [
328 hooks = [
329 (RhodeCodeUi.HOOK_REPO_SIZE, 'python:vcsserver.hooks.repo_size'),
329 (RhodeCodeUi.HOOK_REPO_SIZE, 'python:vcsserver.hooks.repo_size'),
330
330
331 # HG
331 # HG
332 (RhodeCodeUi.HOOK_PRE_PULL, 'python:vcsserver.hooks.pre_pull'),
332 (RhodeCodeUi.HOOK_PRE_PULL, 'python:vcsserver.hooks.pre_pull'),
333 (RhodeCodeUi.HOOK_PULL, 'python:vcsserver.hooks.log_pull_action'),
333 (RhodeCodeUi.HOOK_PULL, 'python:vcsserver.hooks.log_pull_action'),
334 (RhodeCodeUi.HOOK_PRE_PUSH, 'python:vcsserver.hooks.pre_push'),
334 (RhodeCodeUi.HOOK_PRE_PUSH, 'python:vcsserver.hooks.pre_push'),
335 (RhodeCodeUi.HOOK_PRETX_PUSH, 'python:vcsserver.hooks.pre_push'),
335 (RhodeCodeUi.HOOK_PRETX_PUSH, 'python:vcsserver.hooks.pre_push'),
336 (RhodeCodeUi.HOOK_PUSH, 'python:vcsserver.hooks.log_push_action'),
336 (RhodeCodeUi.HOOK_PUSH, 'python:vcsserver.hooks.log_push_action'),
337 (RhodeCodeUi.HOOK_PUSH_KEY, 'python:vcsserver.hooks.key_push'),
337 (RhodeCodeUi.HOOK_PUSH_KEY, 'python:vcsserver.hooks.key_push'),
338
338
339 ]
339 ]
340
340
341 for key, value in hooks:
341 for key, value in hooks:
342 hook_obj = settings_model.get_ui_by_key(key)
342 hook_obj = settings_model.get_ui_by_key(key)
343 hooks2 = hook_obj if hook_obj else RhodeCodeUi()
343 hooks2 = hook_obj if hook_obj else RhodeCodeUi()
344 hooks2.ui_section = 'hooks'
344 hooks2.ui_section = 'hooks'
345 hooks2.ui_key = key
345 hooks2.ui_key = key
346 hooks2.ui_value = value
346 hooks2.ui_value = value
347 self.sa.add(hooks2)
347 self.sa.add(hooks2)
348
348
349 # enable largefiles
349 # enable largefiles
350 largefiles = RhodeCodeUi()
350 largefiles = RhodeCodeUi()
351 largefiles.ui_section = 'extensions'
351 largefiles.ui_section = 'extensions'
352 largefiles.ui_key = 'largefiles'
352 largefiles.ui_key = 'largefiles'
353 largefiles.ui_value = ''
353 largefiles.ui_value = ''
354 self.sa.add(largefiles)
354 self.sa.add(largefiles)
355
355
356 # set default largefiles cache dir, defaults to
356 # set default largefiles cache dir, defaults to
357 # /repo_store_location/.cache/largefiles
357 # /repo_store_location/.cache/largefiles
358 largefiles = RhodeCodeUi()
358 largefiles = RhodeCodeUi()
359 largefiles.ui_section = 'largefiles'
359 largefiles.ui_section = 'largefiles'
360 largefiles.ui_key = 'usercache'
360 largefiles.ui_key = 'usercache'
361 largefiles.ui_value = largefiles_store(repo_store_path)
361 largefiles.ui_value = largefiles_store(repo_store_path)
362
362
363 self.sa.add(largefiles)
363 self.sa.add(largefiles)
364
364
365 # set default lfs cache dir, defaults to
365 # set default lfs cache dir, defaults to
366 # /repo_store_location/.cache/lfs_store
366 # /repo_store_location/.cache/lfs_store
367 lfsstore = RhodeCodeUi()
367 lfsstore = RhodeCodeUi()
368 lfsstore.ui_section = 'vcs_git_lfs'
368 lfsstore.ui_section = 'vcs_git_lfs'
369 lfsstore.ui_key = 'store_location'
369 lfsstore.ui_key = 'store_location'
370 lfsstore.ui_value = lfs_store(repo_store_path)
370 lfsstore.ui_value = lfs_store(repo_store_path)
371
371
372 self.sa.add(lfsstore)
372 self.sa.add(lfsstore)
373
373
374 # register hgsubversion extension, disabled by default
374 # register hgsubversion extension, disabled by default
375 hgsubversion = RhodeCodeUi()
375 hgsubversion = RhodeCodeUi()
376 hgsubversion.ui_section = 'extensions'
376 hgsubversion.ui_section = 'extensions'
377 hgsubversion.ui_key = 'hgsubversion'
377 hgsubversion.ui_key = 'hgsubversion'
378 hgsubversion.ui_value = ''
378 hgsubversion.ui_value = ''
379 hgsubversion.ui_active = False
379 hgsubversion.ui_active = False
380 self.sa.add(hgsubversion)
380 self.sa.add(hgsubversion)
381
381
382 # register hgevolve extension, disabled by default
382 # register hgevolve extension, disabled by default
383 hgevolve = RhodeCodeUi()
383 hgevolve = RhodeCodeUi()
384 hgevolve.ui_section = 'extensions'
384 hgevolve.ui_section = 'extensions'
385 hgevolve.ui_key = 'evolve'
385 hgevolve.ui_key = 'evolve'
386 hgevolve.ui_value = ''
386 hgevolve.ui_value = ''
387 hgevolve.ui_active = False
387 hgevolve.ui_active = False
388 self.sa.add(hgevolve)
388 self.sa.add(hgevolve)
389
389
390 hgevolve = RhodeCodeUi()
390 hgevolve = RhodeCodeUi()
391 hgevolve.ui_section = 'experimental'
391 hgevolve.ui_section = 'experimental'
392 hgevolve.ui_key = 'evolution'
392 hgevolve.ui_key = 'evolution'
393 hgevolve.ui_value = ''
393 hgevolve.ui_value = ''
394 hgevolve.ui_active = False
394 hgevolve.ui_active = False
395 self.sa.add(hgevolve)
395 self.sa.add(hgevolve)
396
396
397 hgevolve = RhodeCodeUi()
397 hgevolve = RhodeCodeUi()
398 hgevolve.ui_section = 'experimental'
398 hgevolve.ui_section = 'experimental'
399 hgevolve.ui_key = 'evolution.exchange'
399 hgevolve.ui_key = 'evolution.exchange'
400 hgevolve.ui_value = ''
400 hgevolve.ui_value = ''
401 hgevolve.ui_active = False
401 hgevolve.ui_active = False
402 self.sa.add(hgevolve)
402 self.sa.add(hgevolve)
403
403
404 hgevolve = RhodeCodeUi()
404 hgevolve = RhodeCodeUi()
405 hgevolve.ui_section = 'extensions'
405 hgevolve.ui_section = 'extensions'
406 hgevolve.ui_key = 'topic'
406 hgevolve.ui_key = 'topic'
407 hgevolve.ui_value = ''
407 hgevolve.ui_value = ''
408 hgevolve.ui_active = False
408 hgevolve.ui_active = False
409 self.sa.add(hgevolve)
409 self.sa.add(hgevolve)
410
410
411 # register hggit extension, disabled by default
411 # register hggit extension, disabled by default
412 hggit = RhodeCodeUi()
412 hggit = RhodeCodeUi()
413 hggit.ui_section = 'extensions'
413 hggit.ui_section = 'extensions'
414 hggit.ui_key = 'hggit'
414 hggit.ui_key = 'hggit'
415 hggit.ui_value = ''
415 hggit.ui_value = ''
416 hggit.ui_active = False
416 hggit.ui_active = False
417 self.sa.add(hggit)
417 self.sa.add(hggit)
418
418
419 # set svn branch defaults
419 # set svn branch defaults
420 branches = ["/branches/*", "/trunk"]
420 branches = ["/branches/*", "/trunk"]
421 tags = ["/tags/*"]
421 tags = ["/tags/*"]
422
422
423 for branch in branches:
423 for branch in branches:
424 settings_model.create_ui_section_value(
424 settings_model.create_ui_section_value(
425 RhodeCodeUi.SVN_BRANCH_ID, branch)
425 RhodeCodeUi.SVN_BRANCH_ID, branch)
426
426
427 for tag in tags:
427 for tag in tags:
428 settings_model.create_ui_section_value(RhodeCodeUi.SVN_TAG_ID, tag)
428 settings_model.create_ui_section_value(RhodeCodeUi.SVN_TAG_ID, tag)
429
429
430 def create_auth_plugin_options(self, skip_existing=False):
430 def create_auth_plugin_options(self, skip_existing=False):
431 """
431 """
432 Create default auth plugin settings, and make it active
432 Create default auth plugin settings, and make it active
433
433
434 :param skip_existing:
434 :param skip_existing:
435 """
435 """
436 defaults = [
436 defaults = [
437 ('auth_plugins',
437 ('auth_plugins',
438 'egg:rhodecode-enterprise-ce#token,egg:rhodecode-enterprise-ce#rhodecode',
438 'egg:rhodecode-enterprise-ce#token,egg:rhodecode-enterprise-ce#rhodecode',
439 'list'),
439 'list'),
440
440
441 ('auth_authtoken_enabled',
441 ('auth_authtoken_enabled',
442 'True',
442 'True',
443 'bool'),
443 'bool'),
444
444
445 ('auth_rhodecode_enabled',
445 ('auth_rhodecode_enabled',
446 'True',
446 'True',
447 'bool'),
447 'bool'),
448 ]
448 ]
449 for k, v, t in defaults:
449 for k, v, t in defaults:
450 if (skip_existing and
450 if (skip_existing and
451 SettingsModel().get_setting_by_name(k) is not None):
451 SettingsModel().get_setting_by_name(k) is not None):
452 log.debug('Skipping option %s', k)
452 log.debug('Skipping option %s', k)
453 continue
453 continue
454 setting = RhodeCodeSetting(k, v, t)
454 setting = RhodeCodeSetting(k, v, t)
455 self.sa.add(setting)
455 self.sa.add(setting)
456
456
457 def create_default_options(self, skip_existing=False):
457 def create_default_options(self, skip_existing=False):
458 """Creates default settings"""
458 """Creates default settings"""
459
459
460 for k, v, t in [
460 for k, v, t in [
461 ('default_repo_enable_locking', False, 'bool'),
461 ('default_repo_enable_locking', False, 'bool'),
462 ('default_repo_enable_downloads', False, 'bool'),
462 ('default_repo_enable_downloads', False, 'bool'),
463 ('default_repo_enable_statistics', False, 'bool'),
463 ('default_repo_enable_statistics', False, 'bool'),
464 ('default_repo_private', False, 'bool'),
464 ('default_repo_private', False, 'bool'),
465 ('default_repo_type', 'hg', 'unicode')]:
465 ('default_repo_type', 'hg', 'unicode')]:
466
466
467 if (skip_existing and
467 if (skip_existing and
468 SettingsModel().get_setting_by_name(k) is not None):
468 SettingsModel().get_setting_by_name(k) is not None):
469 log.debug('Skipping option %s', k)
469 log.debug('Skipping option %s', k)
470 continue
470 continue
471 setting = RhodeCodeSetting(k, v, t)
471 setting = RhodeCodeSetting(k, v, t)
472 self.sa.add(setting)
472 self.sa.add(setting)
473
473
474 def fixup_groups(self):
474 def fixup_groups(self):
475 def_usr = User.get_default_user()
475 def_usr = User.get_default_user()
476 for g in RepoGroup.query().all():
476 for g in RepoGroup.query().all():
477 g.group_name = g.get_new_name(g.name)
477 g.group_name = g.get_new_name(g.name)
478 self.sa.add(g)
478 self.sa.add(g)
479 # get default perm
479 # get default perm
480 default = UserRepoGroupToPerm.query()\
480 default = UserRepoGroupToPerm.query()\
481 .filter(UserRepoGroupToPerm.group == g)\
481 .filter(UserRepoGroupToPerm.group == g)\
482 .filter(UserRepoGroupToPerm.user == def_usr)\
482 .filter(UserRepoGroupToPerm.user == def_usr)\
483 .scalar()
483 .scalar()
484
484
485 if default is None:
485 if default is None:
486 log.debug('missing default permission for group %s adding', g)
486 log.debug('missing default permission for group %s adding', g)
487 perm_obj = RepoGroupModel()._create_default_perms(g)
487 perm_obj = RepoGroupModel()._create_default_perms(g)
488 self.sa.add(perm_obj)
488 self.sa.add(perm_obj)
489
489
490 def reset_permissions(self, username):
490 def reset_permissions(self, username):
491 """
491 """
492 Resets permissions to the default state; useful when old systems had
492 Resets permissions to the default state; useful when old systems had
493 bad permissions that we must clean up
493 bad permissions that we must clean up
494
494
495 :param username:
495 :param username:
496 """
496 """
497 default_user = User.get_by_username(username)
497 default_user = User.get_by_username(username)
498 if not default_user:
498 if not default_user:
499 return
499 return
500
500
501 u2p = UserToPerm.query()\
501 u2p = UserToPerm.query()\
502 .filter(UserToPerm.user == default_user).all()
502 .filter(UserToPerm.user == default_user).all()
503 fixed = False
503 fixed = False
504 if len(u2p) != len(Permission.DEFAULT_USER_PERMISSIONS):
504 if len(u2p) != len(Permission.DEFAULT_USER_PERMISSIONS):
505 for p in u2p:
505 for p in u2p:
506 Session().delete(p)
506 Session().delete(p)
507 fixed = True
507 fixed = True
508 self.populate_default_permissions()
508 self.populate_default_permissions()
509 return fixed
509 return fixed
510
510
511 def config_prompt(self, test_repo_path='', retries=3):
511 def config_prompt(self, test_repo_path='', retries=3):
512 defaults = self.cli_args
512 defaults = self.cli_args
513 _path = defaults.get('repos_location')
513 _path = defaults.get('repos_location')
514 if retries == 3:
514 if retries == 3:
515 log.info('Setting up repositories config')
515 log.info('Setting up repositories config')
516
516
517 if _path is not None:
517 if _path is not None:
518 path = _path
518 path = _path
519 elif not self.tests and not test_repo_path:
519 elif not self.tests and not test_repo_path:
520 path = input(
520 path = input(
521 'Enter a valid absolute path to store repositories. '
521 'Enter a valid absolute path to store repositories. '
522 'All repositories in that path will be added automatically:'
522 'All repositories in that path will be added automatically:'
523 )
523 )
524 else:
524 else:
525 path = test_repo_path
525 path = test_repo_path
526 path_ok = True
526 path_ok = True
527
527
528 # check proper dir
528 # check proper dir
529 if not os.path.isdir(path):
529 if not os.path.isdir(path):
530 path_ok = False
530 path_ok = False
531 log.error('Given path %s is not a valid directory', path)
531 log.error('Given path %s is not a valid directory', path)
532
532
533 elif not os.path.isabs(path):
533 elif not os.path.isabs(path):
534 path_ok = False
534 path_ok = False
535 log.error('Given path %s is not an absolute path', path)
535 log.error('Given path %s is not an absolute path', path)
536
536
537 # check if path is at least readable.
537 # check if path is at least readable.
538 if not os.access(path, os.R_OK):
538 if not os.access(path, os.R_OK):
539 path_ok = False
539 path_ok = False
540 log.error('Given path %s is not readable', path)
540 log.error('Given path %s is not readable', path)
541
541
542 # check write access, warn user about non-writeable paths
542 # check write access, warn user about non-writeable paths
543 elif not os.access(path, os.W_OK) and path_ok:
543 elif not os.access(path, os.W_OK) and path_ok:
544 log.warning('No write permission to given path %s', path)
544 log.warning('No write permission to given path %s', path)
545
545
546 q = ('Given path %s is not writeable, do you want to '
546 q = ('Given path %s is not writeable, do you want to '
547 'continue in read-only mode? [y/n]' % (path,))
547 'continue in read-only mode? [y/n]' % (path,))
548 if not self.ask_ok(q):
548 if not self.ask_ok(q):
549 log.error('Canceled by user')
549 log.error('Canceled by user')
550 sys.exit(-1)
550 sys.exit(-1)
551
551
552 if retries == 0:
552 if retries == 0:
553 sys.exit('max retries reached')
553 sys.exit('max retries reached')
554 if not path_ok:
554 if not path_ok:
555 retries -= 1
555 retries -= 1
556 return self.config_prompt(test_repo_path, retries)
556 return self.config_prompt(test_repo_path, retries)
557
557
558 real_path = os.path.normpath(os.path.realpath(path))
558 real_path = os.path.normpath(os.path.realpath(path))
559
559
560 if real_path != os.path.normpath(path):
560 if real_path != os.path.normpath(path):
561 q = ('Path looks like a symlink, RhodeCode Enterprise will store '
561 q = ('Path looks like a symlink, RhodeCode Enterprise will store '
562 'the given path as %s, continue? [y/n]') % (real_path,)
562 'the given path as %s, continue? [y/n]') % (real_path,)
563 if not self.ask_ok(q):
563 if not self.ask_ok(q):
564 log.error('Canceled by user')
564 log.error('Canceled by user')
565 sys.exit(-1)
565 sys.exit(-1)
566
566
567 return real_path
567 return real_path
568
568
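# A condensed, standalone sketch of the prompt-and-validate loop above
# (illustrative; the function name is hypothetical):
import os
import sys

def prompt_repos_path(retries=3):
    while retries > 0:
        path = input('Enter a valid absolute path to store repositories: ')
        if os.path.isabs(path) and os.path.isdir(path) and os.access(path, os.R_OK):
            return os.path.normpath(os.path.realpath(path))
        retries -= 1
    sys.exit('max retries reached')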
569 def create_settings(self, path):
569 def create_settings(self, path):
570
570
571 self.create_ui_settings(path)
571 self.create_ui_settings(path)
572
572
573 ui_config = [
573 ui_config = [
574 ('web', 'push_ssl', 'False'),
574 ('web', 'push_ssl', 'False'),
575 ('web', 'allow_archive', 'gz zip bz2'),
575 ('web', 'allow_archive', 'gz zip bz2'),
576 ('web', 'allow_push', '*'),
576 ('web', 'allow_push', '*'),
577 ('web', 'baseurl', '/'),
577 ('web', 'baseurl', '/'),
578 ('paths', '/', path),
578 ('paths', '/', path),
579 ('phases', 'publish', 'True')
579 ('phases', 'publish', 'True')
580 ]
580 ]
581 for section, key, value in ui_config:
581 for section, key, value in ui_config:
582 ui_conf = RhodeCodeUi()
582 ui_conf = RhodeCodeUi()
583 setattr(ui_conf, 'ui_section', section)
583 setattr(ui_conf, 'ui_section', section)
584 setattr(ui_conf, 'ui_key', key)
584 setattr(ui_conf, 'ui_key', key)
585 setattr(ui_conf, 'ui_value', value)
585 setattr(ui_conf, 'ui_value', value)
586 self.sa.add(ui_conf)
586 self.sa.add(ui_conf)
587
587
588 # rhodecode app settings
588 # rhodecode app settings
589 settings = [
589 settings = [
590 ('realm', 'RhodeCode', 'unicode'),
590 ('realm', 'RhodeCode', 'unicode'),
591 ('title', '', 'unicode'),
591 ('title', '', 'unicode'),
592 ('pre_code', '', 'unicode'),
592 ('pre_code', '', 'unicode'),
593 ('post_code', '', 'unicode'),
593 ('post_code', '', 'unicode'),
594
594
595 # Visual
595 # Visual
596 ('show_public_icon', True, 'bool'),
596 ('show_public_icon', True, 'bool'),
597 ('show_private_icon', True, 'bool'),
597 ('show_private_icon', True, 'bool'),
598 ('stylify_metatags', True, 'bool'),
598 ('stylify_metatags', True, 'bool'),
599 ('dashboard_items', 100, 'int'),
599 ('dashboard_items', 100, 'int'),
600 ('admin_grid_items', 25, 'int'),
600 ('admin_grid_items', 25, 'int'),
601
601
602 ('markup_renderer', 'markdown', 'unicode'),
602 ('markup_renderer', 'markdown', 'unicode'),
603
603
604 ('repository_fields', True, 'bool'),
604 ('repository_fields', True, 'bool'),
605 ('show_version', True, 'bool'),
605 ('show_version', True, 'bool'),
606 ('show_revision_number', True, 'bool'),
606 ('show_revision_number', True, 'bool'),
607 ('show_sha_length', 12, 'int'),
607 ('show_sha_length', 12, 'int'),
608
608
609 ('use_gravatar', False, 'bool'),
609 ('use_gravatar', False, 'bool'),
610 ('gravatar_url', User.DEFAULT_GRAVATAR_URL, 'unicode'),
610 ('gravatar_url', User.DEFAULT_GRAVATAR_URL, 'unicode'),
611
611
612 ('clone_uri_tmpl', Repository.DEFAULT_CLONE_URI, 'unicode'),
612 ('clone_uri_tmpl', Repository.DEFAULT_CLONE_URI, 'unicode'),
613 ('clone_uri_id_tmpl', Repository.DEFAULT_CLONE_URI_ID, 'unicode'),
613 ('clone_uri_id_tmpl', Repository.DEFAULT_CLONE_URI_ID, 'unicode'),
614 ('clone_uri_ssh_tmpl', Repository.DEFAULT_CLONE_URI_SSH, 'unicode'),
614 ('clone_uri_ssh_tmpl', Repository.DEFAULT_CLONE_URI_SSH, 'unicode'),
615 ('support_url', '', 'unicode'),
615 ('support_url', '', 'unicode'),
616 ('update_url', RhodeCodeSetting.DEFAULT_UPDATE_URL, 'unicode'),
616 ('update_url', RhodeCodeSetting.DEFAULT_UPDATE_URL, 'unicode'),
617
617
618 # VCS Settings
618 # VCS Settings
619 ('pr_merge_enabled', True, 'bool'),
619 ('pr_merge_enabled', True, 'bool'),
620 ('use_outdated_comments', True, 'bool'),
620 ('use_outdated_comments', True, 'bool'),
621 ('diff_cache', True, 'bool'),
621 ('diff_cache', True, 'bool'),
622 ]
622 ]
623
623
624 for key, val, type_ in settings:
624 for key, val, type_ in settings:
625 sett = RhodeCodeSetting(key, val, type_)
625 sett = RhodeCodeSetting(key, val, type_)
626 self.sa.add(sett)
626 self.sa.add(sett)
627
627
628 self.create_auth_plugin_options()
628 self.create_auth_plugin_options()
629 self.create_default_options()
629 self.create_default_options()
630
630
631 log.info('created ui config')
631 log.info('created ui config')
632
632
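# Sketch: each settings entry above is a (name, default, type) triple; the
# type tag ('bool', 'int', 'unicode') controls coercion when the value is
# read back. Adding one more setting follows the same pattern (illustrative,
# inside DbManage):
#
#   sett = RhodeCodeSetting('show_sha_length', 12, 'int')
#   self.sa.add(sett)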
633 def create_user(self, username, password, email='', admin=False,
633 def create_user(self, username, password, email='', admin=False,
634 strict_creation_check=True, api_key=None):
634 strict_creation_check=True, api_key=None):
635 log.info('creating user `%s`', username)
635 log.info('creating user `%s`', username)
636 user = UserModel().create_or_update(
636 user = UserModel().create_or_update(
637 username, password, email, firstname=u'RhodeCode', lastname=u'Admin',
637 username, password, email, firstname='RhodeCode', lastname='Admin',
638 active=True, admin=admin, extern_type="rhodecode",
638 active=True, admin=admin, extern_type="rhodecode",
639 strict_creation_check=strict_creation_check)
639 strict_creation_check=strict_creation_check)
640
640
641 if api_key:
641 if api_key:
642 log.info('setting a new default auth token for user `%s`', username)
642 log.info('setting a new default auth token for user `%s`', username)
643 UserModel().add_auth_token(
643 UserModel().add_auth_token(
644 user=user, lifetime_minutes=-1,
644 user=user, lifetime_minutes=-1,
645 role=UserModel.auth_token_role.ROLE_ALL,
645 role=UserModel.auth_token_role.ROLE_ALL,
646 description=u'BUILTIN TOKEN')
646 description='BUILTIN TOKEN')
647
647
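# Usage sketch (illustrative; `db_manage` is an assumed DbManage instance):
#
#   db_manage.create_user('admin', 'secret', email='admin@example.com',
#                         admin=True, api_key=True)
#
# Passing a truthy api_key also provisions a non-expiring ROLE_ALL auth token.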
648 def create_default_user(self):
648 def create_default_user(self):
649 log.info('creating default user')
649 log.info('creating default user')
650 # create default user for handling default permissions.
650 # create default user for handling default permissions.
651 user = UserModel().create_or_update(username=User.DEFAULT_USER,
651 user = UserModel().create_or_update(username=User.DEFAULT_USER,
652 password=str(uuid.uuid1())[:20],
652 password=str(uuid.uuid1())[:20],
653 email=User.DEFAULT_USER_EMAIL,
653 email=User.DEFAULT_USER_EMAIL,
654 firstname=u'Anonymous',
654 firstname='Anonymous',
655 lastname=u'User',
655 lastname='User',
656 strict_creation_check=False)
656 strict_creation_check=False)
657 # based on configuration options activate/de-activate this user which
657 # based on configuration options activate/de-activate this user which
658 # controlls anonymous access
658 # controls anonymous access
659 if self.cli_args.get('public_access') is False:
659 if self.cli_args.get('public_access') is False:
660 log.info('Public access disabled')
660 log.info('Public access disabled')
661 user.active = False
661 user.active = False
662 Session().add(user)
662 Session().add(user)
663 Session().commit()
663 Session().commit()
664
664
665 def create_permissions(self):
665 def create_permissions(self):
666 """
666 """
667 Creates all permissions defined in the system
667 Creates all permissions defined in the system
668 """
668 """
669 # module.(access|create|change|delete)_[name]
669 # module.(access|create|change|delete)_[name]
670 # module.(none|read|write|admin)
670 # module.(none|read|write|admin)
671 log.info('creating permissions')
671 log.info('creating permissions')
672 PermissionModel(self.sa).create_permissions()
672 PermissionModel(self.sa).create_permissions()
673
673
674 def populate_default_permissions(self):
674 def populate_default_permissions(self):
675 """
675 """
676 Populate default permissions. It will create only the default
676 Populate default permissions. It will create only the default
677 permissions that are missing, and not alter already defined ones
677 permissions that are missing, and not alter already defined ones
678 """
678 """
679 log.info('creating default user permissions')
679 log.info('creating default user permissions')
680 PermissionModel(self.sa).create_default_user_permissions(user=User.DEFAULT_USER)
680 PermissionModel(self.sa).create_default_user_permissions(user=User.DEFAULT_USER)
@@ -1,2031 +1,2031 @@
1 """Diff Match and Patch
1 """Diff Match and Patch
2
2
3 Copyright 2006 Google Inc.
3 Copyright 2006 Google Inc.
4 http://code.google.com/p/google-diff-match-patch/
4 http://code.google.com/p/google-diff-match-patch/
5
5
6 Licensed under the Apache License, Version 2.0 (the "License");
6 Licensed under the Apache License, Version 2.0 (the "License");
7 you may not use this file except in compliance with the License.
7 you may not use this file except in compliance with the License.
8 You may obtain a copy of the License at
8 You may obtain a copy of the License at
9
9
10 http://www.apache.org/licenses/LICENSE-2.0
10 http://www.apache.org/licenses/LICENSE-2.0
11
11
12 Unless required by applicable law or agreed to in writing, software
12 Unless required by applicable law or agreed to in writing, software
13 distributed under the License is distributed on an "AS IS" BASIS,
13 distributed under the License is distributed on an "AS IS" BASIS,
14 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 See the License for the specific language governing permissions and
15 See the License for the specific language governing permissions and
16 limitations under the License.
16 limitations under the License.
17 """
17 """
18
18
19 """Functions for diff, match and patch.
19 """Functions for diff, match and patch.
20
20
21 Computes the difference between two texts to create a patch.
21 Computes the difference between two texts to create a patch.
22 Applies the patch onto another text, allowing for errors.
22 Applies the patch onto another text, allowing for errors.
23 """
23 """
24
24
25 __author__ = "fraser@google.com (Neil Fraser)"
25 __author__ = "fraser@google.com (Neil Fraser)"
26
26
27 import math
27 import math
28 import re
28 import re
29 import sys
29 import sys
30 import time
30 import time
31 import urllib.request, urllib.parse, urllib.error
31 import urllib.request, urllib.parse, urllib.error
32
32
33
33
34 class diff_match_patch:
34 class diff_match_patch:
35 """Class containing the diff, match and patch methods.
35 """Class containing the diff, match and patch methods.
36
36
37 Also contains the behaviour settings.
37 Also contains the behaviour settings.
38 """
38 """
39
39
40 def __init__(self):
40 def __init__(self):
41 """Inits a diff_match_patch object with default settings.
41 """Inits a diff_match_patch object with default settings.
42 Redefine these in your program to override the defaults.
42 Redefine these in your program to override the defaults.
43 """
43 """
44
44
45 # Number of seconds to map a diff before giving up (0 for infinity).
45 # Number of seconds to map a diff before giving up (0 for infinity).
46 self.Diff_Timeout = 1.0
46 self.Diff_Timeout = 1.0
47 # Cost of an empty edit operation in terms of edit characters.
47 # Cost of an empty edit operation in terms of edit characters.
48 self.Diff_EditCost = 4
48 self.Diff_EditCost = 4
49 # At what point is no match declared (0.0 = perfection, 1.0 = very loose).
49 # At what point is no match declared (0.0 = perfection, 1.0 = very loose).
50 self.Match_Threshold = 0.5
50 self.Match_Threshold = 0.5
51 # How far to search for a match (0 = exact location, 1000+ = broad match).
51 # How far to search for a match (0 = exact location, 1000+ = broad match).
52 # A match this many characters away from the expected location will add
52 # A match this many characters away from the expected location will add
53 # 1.0 to the score (0.0 is a perfect match).
53 # 1.0 to the score (0.0 is a perfect match).
54 self.Match_Distance = 1000
54 self.Match_Distance = 1000
55 # When deleting a large block of text (over ~64 characters), how close do
55 # When deleting a large block of text (over ~64 characters), how close do
56 # the contents have to be to match the expected contents. (0.0 = perfection,
56 # the contents have to be to match the expected contents. (0.0 = perfection,
57 # 1.0 = very loose). Note that Match_Threshold controls how closely the
57 # 1.0 = very loose). Note that Match_Threshold controls how closely the
58 # end points of a delete need to match.
58 # end points of a delete need to match.
59 self.Patch_DeleteThreshold = 0.5
59 self.Patch_DeleteThreshold = 0.5
60 # Chunk size for context length.
60 # Chunk size for context length.
61 self.Patch_Margin = 4
61 self.Patch_Margin = 4
62
62
63 # The number of bits in an int.
63 # The number of bits in an int.
64 # Python has no maximum, thus to disable patch splitting set to 0.
64 # Python has no maximum, thus to disable patch splitting set to 0.
65 # However to avoid long patches in certain pathological cases, use 32.
65 # However to avoid long patches in certain pathological cases, use 32.
66 # Multiple short patches (using native ints) are much faster than long ones.
66 # Multiple short patches (using native ints) are much faster than long ones.
67 self.Match_MaxBits = 32
67 self.Match_MaxBits = 32
68
68
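# Tuning sketch: the attributes above are plain instance fields, so callers
# adjust them after construction (illustrative; import path assumed):
from diff_match_patch import diff_match_patch

dmp = diff_match_patch()
dmp.Diff_Timeout = 0.2      # give up mapping a diff sooner
dmp.Match_Threshold = 0.25  # demand closer matches (0.0 would be exact)
dmp.Match_Distance = 500    # a match 500 chars from the expected spot scores 1.0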
69 # DIFF FUNCTIONS
69 # DIFF FUNCTIONS
70
70
71 # The data structure representing a diff is an array of tuples:
71 # The data structure representing a diff is an array of tuples:
72 # [(DIFF_DELETE, "Hello"), (DIFF_INSERT, "Goodbye"), (DIFF_EQUAL, " world.")]
72 # [(DIFF_DELETE, "Hello"), (DIFF_INSERT, "Goodbye"), (DIFF_EQUAL, " world.")]
73 # which means: delete "Hello", add "Goodbye" and keep " world."
73 # which means: delete "Hello", add "Goodbye" and keep " world."
74 DIFF_DELETE = -1
74 DIFF_DELETE = -1
75 DIFF_INSERT = 1
75 DIFF_INSERT = 1
76 DIFF_EQUAL = 0
76 DIFF_EQUAL = 0
77
77
78 def diff_main(self, text1, text2, checklines=True, deadline=None):
78 def diff_main(self, text1, text2, checklines=True, deadline=None):
79 """Find the differences between two texts. Simplifies the problem by
79 """Find the differences between two texts. Simplifies the problem by
80 stripping any common prefix or suffix off the texts before diffing.
80 stripping any common prefix or suffix off the texts before diffing.
81
81
82 Args:
82 Args:
83 text1: Old string to be diffed.
83 text1: Old string to be diffed.
84 text2: New string to be diffed.
84 text2: New string to be diffed.
85 checklines: Optional speedup flag. If present and false, then don't run
85 checklines: Optional speedup flag. If present and false, then don't run
86 a line-level diff first to identify the changed areas.
86 a line-level diff first to identify the changed areas.
87 Defaults to true, which does a faster, slightly less optimal diff.
87 Defaults to true, which does a faster, slightly less optimal diff.
88 deadline: Optional time when the diff should be complete by. Used
88 deadline: Optional time when the diff should be complete by. Used
89 internally for recursive calls. Users should set Diff_Timeout instead.
89 internally for recursive calls. Users should set Diff_Timeout instead.
90
90
91 Returns:
91 Returns:
92 Array of changes.
92 Array of changes.
93 """
93 """
94 # Set a deadline by which time the diff must be complete.
94 # Set a deadline by which time the diff must be complete.
95 if deadline is None:
95 if deadline is None:
96 # Unlike in most languages, Python counts time in seconds.
96 # Unlike in most languages, Python counts time in seconds.
97 if self.Diff_Timeout <= 0:
97 if self.Diff_Timeout <= 0:
98 deadline = sys.maxsize
98 deadline = sys.maxsize
99 else:
99 else:
100 deadline = time.time() + self.Diff_Timeout
100 deadline = time.time() + self.Diff_Timeout
101
101
102 # Check for null inputs.
102 # Check for null inputs.
103 if text1 is None or text2 is None:
103 if text1 is None or text2 is None:
104 raise ValueError("Null inputs. (diff_main)")
104 raise ValueError("Null inputs. (diff_main)")
105
105
106 # Check for equality (speedup).
106 # Check for equality (speedup).
107 if text1 == text2:
107 if text1 == text2:
108 if text1:
108 if text1:
109 return [(self.DIFF_EQUAL, text1)]
109 return [(self.DIFF_EQUAL, text1)]
110 return []
110 return []
111
111
112 # Trim off common prefix (speedup).
112 # Trim off common prefix (speedup).
113 commonlength = self.diff_commonPrefix(text1, text2)
113 commonlength = self.diff_commonPrefix(text1, text2)
114 commonprefix = text1[:commonlength]
114 commonprefix = text1[:commonlength]
115 text1 = text1[commonlength:]
115 text1 = text1[commonlength:]
116 text2 = text2[commonlength:]
116 text2 = text2[commonlength:]
117
117
118 # Trim off common suffix (speedup).
118 # Trim off common suffix (speedup).
119 commonlength = self.diff_commonSuffix(text1, text2)
119 commonlength = self.diff_commonSuffix(text1, text2)
120 if commonlength == 0:
120 if commonlength == 0:
121 commonsuffix = ""
121 commonsuffix = ""
122 else:
122 else:
123 commonsuffix = text1[-commonlength:]
123 commonsuffix = text1[-commonlength:]
124 text1 = text1[:-commonlength]
124 text1 = text1[:-commonlength]
125 text2 = text2[:-commonlength]
125 text2 = text2[:-commonlength]
126
126
127 # Compute the diff on the middle block.
127 # Compute the diff on the middle block.
128 diffs = self.diff_compute(text1, text2, checklines, deadline)
128 diffs = self.diff_compute(text1, text2, checklines, deadline)
129
129
130 # Restore the prefix and suffix.
130 # Restore the prefix and suffix.
131 if commonprefix:
131 if commonprefix:
132 diffs[:0] = [(self.DIFF_EQUAL, commonprefix)]
132 diffs[:0] = [(self.DIFF_EQUAL, commonprefix)]
133 if commonsuffix:
133 if commonsuffix:
134 diffs.append((self.DIFF_EQUAL, commonsuffix))
134 diffs.append((self.DIFF_EQUAL, commonsuffix))
135 self.diff_cleanupMerge(diffs)
135 self.diff_cleanupMerge(diffs)
136 return diffs
136 return diffs
137
137
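# Usage sketch for diff_main (illustrative; import path assumed):
from diff_match_patch import diff_match_patch

dmp = diff_match_patch()
diffs = dmp.diff_main("The quick brown fox.", "The slow brown fox.")
dmp.diff_cleanupSemantic(diffs)
# roughly: [(0, "The "), (-1, "quick"), (1, "slow"), (0, " brown fox.")]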
138 def diff_compute(self, text1, text2, checklines, deadline):
138 def diff_compute(self, text1, text2, checklines, deadline):
139 """Find the differences between two texts. Assumes that the texts do not
139 """Find the differences between two texts. Assumes that the texts do not
140 have any common prefix or suffix.
140 have any common prefix or suffix.
141
141
142 Args:
142 Args:
143 text1: Old string to be diffed.
143 text1: Old string to be diffed.
144 text2: New string to be diffed.
144 text2: New string to be diffed.
145 checklines: Speedup flag. If false, then don't run a line-level diff
145 checklines: Speedup flag. If false, then don't run a line-level diff
146 first to identify the changed areas.
146 first to identify the changed areas.
147 If true, then run a faster, slightly less optimal diff.
147 If true, then run a faster, slightly less optimal diff.
148 deadline: Time when the diff should be complete by.
148 deadline: Time when the diff should be complete by.
149
149
150 Returns:
150 Returns:
151 Array of changes.
151 Array of changes.
152 """
152 """
153 if not text1:
153 if not text1:
154 # Just add some text (speedup).
154 # Just add some text (speedup).
155 return [(self.DIFF_INSERT, text2)]
155 return [(self.DIFF_INSERT, text2)]
156
156
157 if not text2:
157 if not text2:
158 # Just delete some text (speedup).
158 # Just delete some text (speedup).
159 return [(self.DIFF_DELETE, text1)]
159 return [(self.DIFF_DELETE, text1)]
160
160
161 if len(text1) > len(text2):
161 if len(text1) > len(text2):
162 (longtext, shorttext) = (text1, text2)
162 (longtext, shorttext) = (text1, text2)
163 else:
163 else:
164 (shorttext, longtext) = (text1, text2)
164 (shorttext, longtext) = (text1, text2)
165 i = longtext.find(shorttext)
165 i = longtext.find(shorttext)
166 if i != -1:
166 if i != -1:
167 # Shorter text is inside the longer text (speedup).
167 # Shorter text is inside the longer text (speedup).
168 diffs = [
168 diffs = [
169 (self.DIFF_INSERT, longtext[:i]),
169 (self.DIFF_INSERT, longtext[:i]),
170 (self.DIFF_EQUAL, shorttext),
170 (self.DIFF_EQUAL, shorttext),
171 (self.DIFF_INSERT, longtext[i + len(shorttext) :]),
171 (self.DIFF_INSERT, longtext[i + len(shorttext) :]),
172 ]
172 ]
173 # Swap insertions for deletions if diff is reversed.
173 # Swap insertions for deletions if diff is reversed.
174 if len(text1) > len(text2):
174 if len(text1) > len(text2):
175 diffs[0] = (self.DIFF_DELETE, diffs[0][1])
175 diffs[0] = (self.DIFF_DELETE, diffs[0][1])
176 diffs[2] = (self.DIFF_DELETE, diffs[2][1])
176 diffs[2] = (self.DIFF_DELETE, diffs[2][1])
177 return diffs
177 return diffs
178
178
179 if len(shorttext) == 1:
179 if len(shorttext) == 1:
180 # Single character string.
180 # Single character string.
181 # After the previous speedup, the character can't be an equality.
181 # After the previous speedup, the character can't be an equality.
182 return [(self.DIFF_DELETE, text1), (self.DIFF_INSERT, text2)]
182 return [(self.DIFF_DELETE, text1), (self.DIFF_INSERT, text2)]
183
183
184 # Check to see if the problem can be split in two.
184 # Check to see if the problem can be split in two.
185 hm = self.diff_halfMatch(text1, text2)
185 hm = self.diff_halfMatch(text1, text2)
186 if hm:
186 if hm:
187 # A half-match was found, sort out the return data.
187 # A half-match was found, sort out the return data.
188 (text1_a, text1_b, text2_a, text2_b, mid_common) = hm
188 (text1_a, text1_b, text2_a, text2_b, mid_common) = hm
189 # Send both pairs off for separate processing.
189 # Send both pairs off for separate processing.
190 diffs_a = self.diff_main(text1_a, text2_a, checklines, deadline)
190 diffs_a = self.diff_main(text1_a, text2_a, checklines, deadline)
191 diffs_b = self.diff_main(text1_b, text2_b, checklines, deadline)
191 diffs_b = self.diff_main(text1_b, text2_b, checklines, deadline)
192 # Merge the results.
192 # Merge the results.
193 return diffs_a + [(self.DIFF_EQUAL, mid_common)] + diffs_b
193 return diffs_a + [(self.DIFF_EQUAL, mid_common)] + diffs_b
194
194
195 if checklines and len(text1) > 100 and len(text2) > 100:
195 if checklines and len(text1) > 100 and len(text2) > 100:
196 return self.diff_lineMode(text1, text2, deadline)
196 return self.diff_lineMode(text1, text2, deadline)
197
197
198 return self.diff_bisect(text1, text2, deadline)
198 return self.diff_bisect(text1, text2, deadline)
199
199
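# Speedup-path sketch for diff_compute, reached via diff_main (illustrative;
# import path assumed):
from diff_match_patch import diff_match_patch

dmp = diff_match_patch()
assert dmp.diff_main("abc", "") == [(dmp.DIFF_DELETE, "abc")]  # pure deletion
assert dmp.diff_main("abc", "xabcy") == [
    (dmp.DIFF_INSERT, "x"),  # shorter text contained in the longer one
    (dmp.DIFF_EQUAL, "abc"),
    (dmp.DIFF_INSERT, "y"),
]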
200 def diff_lineMode(self, text1, text2, deadline):
200 def diff_lineMode(self, text1, text2, deadline):
201 """Do a quick line-level diff on both strings, then rediff the parts for
201 """Do a quick line-level diff on both strings, then rediff the parts for
202 greater accuracy.
202 greater accuracy.
203 This speedup can produce non-minimal diffs.
203 This speedup can produce non-minimal diffs.
204
204
205 Args:
205 Args:
206 text1: Old string to be diffed.
206 text1: Old string to be diffed.
207 text2: New string to be diffed.
207 text2: New string to be diffed.
208 deadline: Time when the diff should be complete by.
208 deadline: Time when the diff should be complete by.
209
209
210 Returns:
210 Returns:
211 Array of changes.
211 Array of changes.
212 """
212 """
213
213
214 # Scan the text on a line-by-line basis first.
214 # Scan the text on a line-by-line basis first.
215 (text1, text2, linearray) = self.diff_linesToChars(text1, text2)
215 (text1, text2, linearray) = self.diff_linesToChars(text1, text2)
216
216
217 diffs = self.diff_main(text1, text2, False, deadline)
217 diffs = self.diff_main(text1, text2, False, deadline)
218
218
219 # Convert the diff back to original text.
219 # Convert the diff back to original text.
220 self.diff_charsToLines(diffs, linearray)
220 self.diff_charsToLines(diffs, linearray)
221 # Eliminate freak matches (e.g. blank lines)
221 # Eliminate freak matches (e.g. blank lines)
222 self.diff_cleanupSemantic(diffs)
222 self.diff_cleanupSemantic(diffs)
223
223
224 # Rediff any replacement blocks, this time character-by-character.
224 # Rediff any replacement blocks, this time character-by-character.
225 # Add a dummy entry at the end.
225 # Add a dummy entry at the end.
226 diffs.append((self.DIFF_EQUAL, ""))
226 diffs.append((self.DIFF_EQUAL, ""))
227 pointer = 0
227 pointer = 0
228 count_delete = 0
228 count_delete = 0
229 count_insert = 0
229 count_insert = 0
230 text_delete = ""
230 text_delete = ""
231 text_insert = ""
231 text_insert = ""
232 while pointer < len(diffs):
232 while pointer < len(diffs):
233 if diffs[pointer][0] == self.DIFF_INSERT:
233 if diffs[pointer][0] == self.DIFF_INSERT:
234 count_insert += 1
234 count_insert += 1
235 text_insert += diffs[pointer][1]
235 text_insert += diffs[pointer][1]
236 elif diffs[pointer][0] == self.DIFF_DELETE:
236 elif diffs[pointer][0] == self.DIFF_DELETE:
237 count_delete += 1
237 count_delete += 1
238 text_delete += diffs[pointer][1]
238 text_delete += diffs[pointer][1]
239 elif diffs[pointer][0] == self.DIFF_EQUAL:
239 elif diffs[pointer][0] == self.DIFF_EQUAL:
240 # Upon reaching an equality, check for prior redundancies.
240 # Upon reaching an equality, check for prior redundancies.
241 if count_delete >= 1 and count_insert >= 1:
241 if count_delete >= 1 and count_insert >= 1:
242 # Delete the offending records and add the merged ones.
242 # Delete the offending records and add the merged ones.
243 a = self.diff_main(text_delete, text_insert, False, deadline)
243 a = self.diff_main(text_delete, text_insert, False, deadline)
244 diffs[pointer - count_delete - count_insert : pointer] = a
244 diffs[pointer - count_delete - count_insert : pointer] = a
245 pointer = pointer - count_delete - count_insert + len(a)
245 pointer = pointer - count_delete - count_insert + len(a)
246 count_insert = 0
246 count_insert = 0
247 count_delete = 0
247 count_delete = 0
248 text_delete = ""
248 text_delete = ""
249 text_insert = ""
249 text_insert = ""
250
250
251 pointer += 1
251 pointer += 1
252
252
253 diffs.pop() # Remove the dummy entry at the end.
253 diffs.pop() # Remove the dummy entry at the end.
254
254
255 return diffs
255 return diffs
256
256
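# Line-mode sketch: with both inputs over 100 characters and checklines left
# at its default of True, diff_main routes through diff_lineMode (illustrative;
# import path assumed):
from diff_match_patch import diff_match_patch

dmp = diff_match_patch()
dmp.Diff_Timeout = 0  # unlimited time; also disables the half-match shortcut
old = "one\ntwo\nthree\n" * 20
new = "one\n2\nthree\n" * 20
diffs = dmp.diff_main(old, new)  # takes the diff_lineMode path internally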
257 def diff_bisect(self, text1, text2, deadline):
257 def diff_bisect(self, text1, text2, deadline):
258 """Find the 'middle snake' of a diff, split the problem in two
258 """Find the 'middle snake' of a diff, split the problem in two
259 and return the recursively constructed diff.
259 and return the recursively constructed diff.
260 See Myers 1986 paper: An O(ND) Difference Algorithm and Its Variations.
260 See Myers 1986 paper: An O(ND) Difference Algorithm and Its Variations.
261
261
262 Args:
262 Args:
263 text1: Old string to be diffed.
263 text1: Old string to be diffed.
264 text2: New string to be diffed.
264 text2: New string to be diffed.
265 deadline: Time at which to bail if not yet complete.
265 deadline: Time at which to bail if not yet complete.
266
266
267 Returns:
267 Returns:
268 Array of diff tuples.
268 Array of diff tuples.
269 """
269 """
270
270
271 # Cache the text lengths to prevent multiple calls.
271 # Cache the text lengths to prevent multiple calls.
272 text1_length = len(text1)
272 text1_length = len(text1)
273 text2_length = len(text2)
273 text2_length = len(text2)
274 max_d = (text1_length + text2_length + 1) // 2
274 max_d = (text1_length + text2_length + 1) // 2
275 v_offset = max_d
275 v_offset = max_d
276 v_length = 2 * max_d
276 v_length = 2 * max_d
277 v1 = [-1] * v_length
277 v1 = [-1] * v_length
278 v1[v_offset + 1] = 0
278 v1[v_offset + 1] = 0
279 v2 = v1[:]
279 v2 = v1[:]
280 delta = text1_length - text2_length
280 delta = text1_length - text2_length
281 # If the total number of characters is odd, then the front path will
281 # If the total number of characters is odd, then the front path will
282 # collide with the reverse path.
282 # collide with the reverse path.
283 front = delta % 2 != 0
283 front = delta % 2 != 0
284 # Offsets for start and end of k loop.
284 # Offsets for start and end of k loop.
285 # Prevents mapping of space beyond the grid.
285 # Prevents mapping of space beyond the grid.
286 k1start = 0
286 k1start = 0
287 k1end = 0
287 k1end = 0
288 k2start = 0
288 k2start = 0
289 k2end = 0
289 k2end = 0
290 for d in range(max_d):
290 for d in range(max_d):
291 # Bail out if deadline is reached.
291 # Bail out if deadline is reached.
292 if time.time() > deadline:
292 if time.time() > deadline:
293 break
293 break
294
294
295 # Walk the front path one step.
295 # Walk the front path one step.
296 for k1 in range(-d + k1start, d + 1 - k1end, 2):
296 for k1 in range(-d + k1start, d + 1 - k1end, 2):
297 k1_offset = v_offset + k1
297 k1_offset = v_offset + k1
298 if k1 == -d or (k1 != d and v1[k1_offset - 1] < v1[k1_offset + 1]):
298 if k1 == -d or (k1 != d and v1[k1_offset - 1] < v1[k1_offset + 1]):
299 x1 = v1[k1_offset + 1]
299 x1 = v1[k1_offset + 1]
300 else:
300 else:
301 x1 = v1[k1_offset - 1] + 1
301 x1 = v1[k1_offset - 1] + 1
302 y1 = x1 - k1
302 y1 = x1 - k1
303 while (
303 while (
304 x1 < text1_length and y1 < text2_length and text1[x1] == text2[y1]
304 x1 < text1_length and y1 < text2_length and text1[x1] == text2[y1]
305 ):
305 ):
306 x1 += 1
306 x1 += 1
307 y1 += 1
307 y1 += 1
308 v1[k1_offset] = x1
308 v1[k1_offset] = x1
309 if x1 > text1_length:
309 if x1 > text1_length:
310 # Ran off the right of the graph.
310 # Ran off the right of the graph.
311 k1end += 2
311 k1end += 2
312 elif y1 > text2_length:
312 elif y1 > text2_length:
313 # Ran off the bottom of the graph.
313 # Ran off the bottom of the graph.
314 k1start += 2
314 k1start += 2
315 elif front:
315 elif front:
316 k2_offset = v_offset + delta - k1
316 k2_offset = v_offset + delta - k1
317 if k2_offset >= 0 and k2_offset < v_length and v2[k2_offset] != -1:
317 if k2_offset >= 0 and k2_offset < v_length and v2[k2_offset] != -1:
318 # Mirror x2 onto top-left coordinate system.
318 # Mirror x2 onto top-left coordinate system.
319 x2 = text1_length - v2[k2_offset]
319 x2 = text1_length - v2[k2_offset]
320 if x1 >= x2:
320 if x1 >= x2:
321 # Overlap detected.
321 # Overlap detected.
322 return self.diff_bisectSplit(text1, text2, x1, y1, deadline)
322 return self.diff_bisectSplit(text1, text2, x1, y1, deadline)
323
323
324 # Walk the reverse path one step.
324 # Walk the reverse path one step.
325 for k2 in range(-d + k2start, d + 1 - k2end, 2):
325 for k2 in range(-d + k2start, d + 1 - k2end, 2):
326 k2_offset = v_offset + k2
326 k2_offset = v_offset + k2
327 if k2 == -d or (k2 != d and v2[k2_offset - 1] < v2[k2_offset + 1]):
327 if k2 == -d or (k2 != d and v2[k2_offset - 1] < v2[k2_offset + 1]):
328 x2 = v2[k2_offset + 1]
328 x2 = v2[k2_offset + 1]
329 else:
329 else:
330 x2 = v2[k2_offset - 1] + 1
330 x2 = v2[k2_offset - 1] + 1
331 y2 = x2 - k2
331 y2 = x2 - k2
332 while (
332 while (
333 x2 < text1_length
333 x2 < text1_length
334 and y2 < text2_length
334 and y2 < text2_length
335 and text1[-x2 - 1] == text2[-y2 - 1]
335 and text1[-x2 - 1] == text2[-y2 - 1]
336 ):
336 ):
337 x2 += 1
337 x2 += 1
338 y2 += 1
338 y2 += 1
339 v2[k2_offset] = x2
339 v2[k2_offset] = x2
340 if x2 > text1_length:
340 if x2 > text1_length:
341 # Ran off the left of the graph.
341 # Ran off the left of the graph.
342 k2end += 2
342 k2end += 2
343 elif y2 > text2_length:
343 elif y2 > text2_length:
344 # Ran off the top of the graph.
344 # Ran off the top of the graph.
345 k2start += 2
345 k2start += 2
346 elif not front:
346 elif not front:
347 k1_offset = v_offset + delta - k2
347 k1_offset = v_offset + delta - k2
348 if k1_offset >= 0 and k1_offset < v_length and v1[k1_offset] != -1:
348 if k1_offset >= 0 and k1_offset < v_length and v1[k1_offset] != -1:
349 x1 = v1[k1_offset]
349 x1 = v1[k1_offset]
350 y1 = v_offset + x1 - k1_offset
350 y1 = v_offset + x1 - k1_offset
351 # Mirror x2 onto top-left coordinate system.
351 # Mirror x2 onto top-left coordinate system.
352 x2 = text1_length - x2
352 x2 = text1_length - x2
353 if x1 >= x2:
353 if x1 >= x2:
354 # Overlap detected.
354 # Overlap detected.
355 return self.diff_bisectSplit(text1, text2, x1, y1, deadline)
355 return self.diff_bisectSplit(text1, text2, x1, y1, deadline)
356
356
357 # Diff took too long and hit the deadline or
357 # Diff took too long and hit the deadline or
358 # number of diffs equals number of characters, no commonality at all.
358 # number of diffs equals number of characters, no commonality at all.
359 return [(self.DIFF_DELETE, text1), (self.DIFF_INSERT, text2)]
359 return [(self.DIFF_DELETE, text1), (self.DIFF_INSERT, text2)]
360
360
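# Bisect sketch (illustrative; expected value follows the upstream test suite):
import sys
from diff_match_patch import diff_match_patch  # import path assumed

dmp = diff_match_patch()
diffs = dmp.diff_bisect("cat", "map", sys.maxsize)
# [(-1, "c"), (1, "m"), (0, "a"), (-1, "t"), (1, "p")]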
361 def diff_bisectSplit(self, text1, text2, x, y, deadline):
361 def diff_bisectSplit(self, text1, text2, x, y, deadline):
362 """Given the location of the 'middle snake', split the diff in two parts
362 """Given the location of the 'middle snake', split the diff in two parts
363 and recurse.
363 and recurse.
364
364
365 Args:
365 Args:
366 text1: Old string to be diffed.
366 text1: Old string to be diffed.
367 text2: New string to be diffed.
367 text2: New string to be diffed.
368 x: Index of split point in text1.
368 x: Index of split point in text1.
369 y: Index of split point in text2.
369 y: Index of split point in text2.
370 deadline: Time at which to bail if not yet complete.
370 deadline: Time at which to bail if not yet complete.
371
371
372 Returns:
372 Returns:
373 Array of diff tuples.
373 Array of diff tuples.
374 """
374 """
375 text1a = text1[:x]
375 text1a = text1[:x]
376 text2a = text2[:y]
376 text2a = text2[:y]
377 text1b = text1[x:]
377 text1b = text1[x:]
378 text2b = text2[y:]
378 text2b = text2[y:]
379
379
380 # Compute both diffs serially.
380 # Compute both diffs serially.
381 diffs = self.diff_main(text1a, text2a, False, deadline)
381 diffs = self.diff_main(text1a, text2a, False, deadline)
382 diffsb = self.diff_main(text1b, text2b, False, deadline)
382 diffsb = self.diff_main(text1b, text2b, False, deadline)
383
383
384 return diffs + diffsb
384 return diffs + diffsb
385
385
386 def diff_linesToChars(self, text1, text2):
386 def diff_linesToChars(self, text1, text2):
387 """Split two texts into an array of strings. Reduce the texts to a string
387 """Split two texts into an array of strings. Reduce the texts to a string
388 of hashes where each Unicode character represents one line.
388 of hashes where each Unicode character represents one line.
389
389
390 Args:
390 Args:
391 text1: First string.
391 text1: First string.
392 text2: Second string.
392 text2: Second string.
393
393
394 Returns:
394 Returns:
395 Three element tuple, containing the encoded text1, the encoded text2 and
395 Three element tuple, containing the encoded text1, the encoded text2 and
396 the array of unique strings. The zeroth element of the array of unique
396 the array of unique strings. The zeroth element of the array of unique
397 strings is intentionally blank.
397 strings is intentionally blank.
398 """
398 """
399 lineArray = [] # e.g. lineArray[4] == "Hello\n"
399 lineArray = [] # e.g. lineArray[4] == "Hello\n"
400 lineHash = {} # e.g. lineHash["Hello\n"] == 4
400 lineHash = {} # e.g. lineHash["Hello\n"] == 4
401
401
402 # "\x00" is a valid character, but various debuggers don't like it.
402 # "\x00" is a valid character, but various debuggers don't like it.
403 # So we'll insert a junk entry to avoid generating a null character.
403 # So we'll insert a junk entry to avoid generating a null character.
404 lineArray.append("")
404 lineArray.append("")
405
405
406 def diff_linesToCharsMunge(text):
406 def diff_linesToCharsMunge(text):
407 """Split a text into an array of strings. Reduce the texts to a string
407 """Split a text into an array of strings. Reduce the texts to a string
408 of hashes where each Unicode character represents one line.
408 of hashes where each Unicode character represents one line.
409 Modifies lineArray and lineHash through being a closure.
409 Modifies lineArray and lineHash through being a closure.
410
410
411 Args:
411 Args:
412 text: String to encode.
412 text: String to encode.
413
413
414 Returns:
414 Returns:
415 Encoded string.
415 Encoded string.
416 """
416 """
417 chars = []
417 chars = []
418 # Walk the text, pulling out a substring for each line.
418 # Walk the text, pulling out a substring for each line.
419 # text.split('\n') would temporarily double our memory footprint.
419 # text.split('\n') would temporarily double our memory footprint.
420 # Modifying text would create many large strings to garbage collect.
420 # Modifying text would create many large strings to garbage collect.
421 lineStart = 0
421 lineStart = 0
422 lineEnd = -1
422 lineEnd = -1
423 while lineEnd < len(text) - 1:
423 while lineEnd < len(text) - 1:
424 lineEnd = text.find("\n", lineStart)
424 lineEnd = text.find("\n", lineStart)
425 if lineEnd == -1:
425 if lineEnd == -1:
426 lineEnd = len(text) - 1
426 lineEnd = len(text) - 1
427 line = text[lineStart : lineEnd + 1]
427 line = text[lineStart : lineEnd + 1]
428 lineStart = lineEnd + 1
428 lineStart = lineEnd + 1
429
429
430 if line in lineHash:
430 if line in lineHash:
431 chars.append(chr(lineHash[line]))
431 chars.append(chr(lineHash[line]))
432 else:
432 else:
433 lineArray.append(line)
433 lineArray.append(line)
434 lineHash[line] = len(lineArray) - 1
434 lineHash[line] = len(lineArray) - 1
435 chars.append(chr(len(lineArray) - 1))
435 chars.append(chr(len(lineArray) - 1))
436 return "".join(chars)
436 return "".join(chars)
437
437
438 chars1 = diff_linesToCharsMunge(text1)
438 chars1 = diff_linesToCharsMunge(text1)
439 chars2 = diff_linesToCharsMunge(text2)
439 chars2 = diff_linesToCharsMunge(text2)
440 return (chars1, chars2, lineArray)
440 return (chars1, chars2, lineArray)
441
441
442 def diff_charsToLines(self, diffs, lineArray):
442 def diff_charsToLines(self, diffs, lineArray):
443 """Rehydrate the text in a diff from a string of line hashes to real lines
443 """Rehydrate the text in a diff from a string of line hashes to real lines
444 of text.
444 of text.
445
445
446 Args:
446 Args:
447 diffs: Array of diff tuples.
447 diffs: Array of diff tuples.
448 lineArray: Array of unique strings.
448 lineArray: Array of unique strings.
449 """
449 """
450 for x in range(len(diffs)):
450 for x in range(len(diffs)):
451 text = []
451 text = []
452 for char in diffs[x][1]:
452 for char in diffs[x][1]:
453 text.append(lineArray[ord(char)])
453 text.append(lineArray[ord(char)])
454 diffs[x] = (diffs[x][0], "".join(text))
454 diffs[x] = (diffs[x][0], "".join(text))
455
455
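# Round-trip sketch: encode lines to characters, diff, then rehydrate; this
# mirrors what diff_lineMode does internally (illustrative; import path assumed):
from diff_match_patch import diff_match_patch

dmp = diff_match_patch()
t1, t2, line_array = dmp.diff_linesToChars("alpha\nbeta\n", "alpha\ngamma\n")
diffs = dmp.diff_main(t1, t2, False)
dmp.diff_charsToLines(diffs, line_array)
# diffs now holds whole lines: [(0, "alpha\n"), (-1, "beta\n"), (1, "gamma\n")]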
456 def diff_commonPrefix(self, text1, text2):
456 def diff_commonPrefix(self, text1, text2):
457 """Determine the common prefix of two strings.
457 """Determine the common prefix of two strings.
458
458
459 Args:
459 Args:
460 text1: First string.
460 text1: First string.
461 text2: Second string.
461 text2: Second string.
462
462
463 Returns:
463 Returns:
464 The number of characters common to the start of each string.
464 The number of characters common to the start of each string.
465 """
465 """
466 # Quick check for common null cases.
466 # Quick check for common null cases.
467 if not text1 or not text2 or text1[0] != text2[0]:
467 if not text1 or not text2 or text1[0] != text2[0]:
468 return 0
468 return 0
469 # Binary search.
469 # Binary search.
470 # Performance analysis: http://neil.fraser.name/news/2007/10/09/
470 # Performance analysis: http://neil.fraser.name/news/2007/10/09/
471 pointermin = 0
471 pointermin = 0
472 pointermax = min(len(text1), len(text2))
472 pointermax = min(len(text1), len(text2))
473 pointermid = pointermax
473 pointermid = pointermax
474 pointerstart = 0
474 pointerstart = 0
475 while pointermin < pointermid:
475 while pointermin < pointermid:
476 if text1[pointerstart:pointermid] == text2[pointerstart:pointermid]:
476 if text1[pointerstart:pointermid] == text2[pointerstart:pointermid]:
477 pointermin = pointermid
477 pointermin = pointermid
478 pointerstart = pointermin
478 pointerstart = pointermin
479 else:
479 else:
480 pointermax = pointermid
480 pointermax = pointermid
481 pointermid = (pointermax - pointermin) // 2 + pointermin
481 pointermid = (pointermax - pointermin) // 2 + pointermin
482 return pointermid
482 return pointermid
483
483
484 def diff_commonSuffix(self, text1, text2):
484 def diff_commonSuffix(self, text1, text2):
485 """Determine the common suffix of two strings.
485 """Determine the common suffix of two strings.
486
486
487 Args:
487 Args:
488 text1: First string.
488 text1: First string.
489 text2: Second string.
489 text2: Second string.
490
490
491 Returns:
491 Returns:
492 The number of characters common to the end of each string.
492 The number of characters common to the end of each string.
493 """
493 """
494 # Quick check for common null cases.
494 # Quick check for common null cases.
495 if not text1 or not text2 or text1[-1] != text2[-1]:
495 if not text1 or not text2 or text1[-1] != text2[-1]:
496 return 0
496 return 0
497 # Binary search.
497 # Binary search.
498 # Performance analysis: http://neil.fraser.name/news/2007/10/09/
498 # Performance analysis: http://neil.fraser.name/news/2007/10/09/
499 pointermin = 0
499 pointermin = 0
500 pointermax = min(len(text1), len(text2))
500 pointermax = min(len(text1), len(text2))
501 pointermid = pointermax
501 pointermid = pointermax
502 pointerend = 0
502 pointerend = 0
503 while pointermin < pointermid:
503 while pointermin < pointermid:
504 if (
504 if (
505 text1[-pointermid : len(text1) - pointerend]
505 text1[-pointermid : len(text1) - pointerend]
506 == text2[-pointermid : len(text2) - pointerend]
506 == text2[-pointermid : len(text2) - pointerend]
507 ):
507 ):
508 pointermin = pointermid
508 pointermin = pointermid
509 pointerend = pointermin
509 pointerend = pointermin
510 else:
510 else:
511 pointermax = pointermid
511 pointermax = pointermid
512 pointermid = (pointermax - pointermin) // 2 + pointermin
512 pointermid = (pointermax - pointermin) // 2 + pointermin
513 return pointermid
513 return pointermid
514
514
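# Behaviour sketch for the two helpers above (values follow the upstream
# test suite; import path assumed):
from diff_match_patch import diff_match_patch

dmp = diff_match_patch()
assert dmp.diff_commonPrefix("abc", "xyz") == 0
assert dmp.diff_commonPrefix("1234abcdef", "1234xyz") == 4
assert dmp.diff_commonSuffix("abcdef1234", "xyz1234") == 4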
515 def diff_commonOverlap(self, text1, text2):
515 def diff_commonOverlap(self, text1, text2):
516 """Determine if the suffix of one string is the prefix of another.
516 """Determine if the suffix of one string is the prefix of another.
517
517
518 Args:
518 Args:
519 text1: First string.
519 text1: First string.
520 text2: Second string.
520 text2: Second string.
521
521
522 Returns:
522 Returns:
523 The number of characters common to the end of the first
523 The number of characters common to the end of the first
524 string and the start of the second string.
524 string and the start of the second string.
525 """
525 """
526 # Cache the text lengths to prevent multiple calls.
526 # Cache the text lengths to prevent multiple calls.
527 text1_length = len(text1)
527 text1_length = len(text1)
528 text2_length = len(text2)
528 text2_length = len(text2)
529 # Eliminate the null case.
529 # Eliminate the null case.
530 if text1_length == 0 or text2_length == 0:
530 if text1_length == 0 or text2_length == 0:
531 return 0
531 return 0
532 # Truncate the longer string.
532 # Truncate the longer string.
533 if text1_length > text2_length:
533 if text1_length > text2_length:
534 text1 = text1[-text2_length:]
534 text1 = text1[-text2_length:]
535 elif text1_length < text2_length:
535 elif text1_length < text2_length:
536 text2 = text2[:text1_length]
536 text2 = text2[:text1_length]
537 text_length = min(text1_length, text2_length)
537 text_length = min(text1_length, text2_length)
538 # Quick check for the worst case.
538 # Quick check for the worst case.
539 if text1 == text2:
539 if text1 == text2:
540 return text_length
540 return text_length
541
541
542 # Start by looking for a single character match
542 # Start by looking for a single character match
543 # and increase length until no match is found.
543 # and increase length until no match is found.
544 # Performance analysis: http://neil.fraser.name/news/2010/11/04/
544 # Performance analysis: http://neil.fraser.name/news/2010/11/04/
545 best = 0
545 best = 0
546 length = 1
546 length = 1
547 while True:
547 while True:
548 pattern = text1[-length:]
548 pattern = text1[-length:]
549 found = text2.find(pattern)
549 found = text2.find(pattern)
550 if found == -1:
550 if found == -1:
551 return best
551 return best
552 length += found
552 length += found
553 if found == 0 or text1[-length:] == text2[:length]:
553 if found == 0 or text1[-length:] == text2[:length]:
554 best = length
554 best = length
555 length += 1
555 length += 1
556
556
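# Overlap sketch (values follow the upstream test suite; import path assumed):
from diff_match_patch import diff_match_patch

dmp = diff_match_patch()
assert dmp.diff_commonOverlap("", "abcd") == 0
assert dmp.diff_commonOverlap("abc", "abcd") == 3           # suffix == prefix
assert dmp.diff_commonOverlap("123456xxx", "xxxabcd") == 3  # partial overlap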
557 def diff_halfMatch(self, text1, text2):
557 def diff_halfMatch(self, text1, text2):
558 """Do the two texts share a substring which is at least half the length of
558 """Do the two texts share a substring which is at least half the length of
559 the longer text?
559 the longer text?
560 This speedup can produce non-minimal diffs.
560 This speedup can produce non-minimal diffs.
561
561
562 Args:
562 Args:
563 text1: First string.
563 text1: First string.
564 text2: Second string.
564 text2: Second string.
565
565
566 Returns:
566 Returns:
567 Five element Array, containing the prefix of text1, the suffix of text1,
567 Five element Array, containing the prefix of text1, the suffix of text1,
568 the prefix of text2, the suffix of text2 and the common middle. Or None
568 the prefix of text2, the suffix of text2 and the common middle. Or None
569 if there was no match.
569 if there was no match.
570 """
570 """
571 if self.Diff_Timeout <= 0:
571 if self.Diff_Timeout <= 0:
572 # Don't risk returning a non-optimal diff if we have unlimited time.
572 # Don't risk returning a non-optimal diff if we have unlimited time.
573 return None
573 return None
574 if len(text1) > len(text2):
574 if len(text1) > len(text2):
575 (longtext, shorttext) = (text1, text2)
575 (longtext, shorttext) = (text1, text2)
576 else:
576 else:
577 (shorttext, longtext) = (text1, text2)
577 (shorttext, longtext) = (text1, text2)
578 if len(longtext) < 4 or len(shorttext) * 2 < len(longtext):
578 if len(longtext) < 4 or len(shorttext) * 2 < len(longtext):
579 return None # Pointless.
579 return None # Pointless.
580
580
581 def diff_halfMatchI(longtext, shorttext, i):
581 def diff_halfMatchI(longtext, shorttext, i):
582 """Does a substring of shorttext exist within longtext such that the
582 """Does a substring of shorttext exist within longtext such that the
583 substring is at least half the length of longtext?
583 substring is at least half the length of longtext?
584 Closure, but does not reference any external variables.
584 Closure, but does not reference any external variables.
585
585
586 Args:
586 Args:
587 longtext: Longer string.
587 longtext: Longer string.
588 shorttext: Shorter string.
588 shorttext: Shorter string.
589 i: Start index of quarter length substring within longtext.
589 i: Start index of quarter length substring within longtext.
590
590
591 Returns:
591 Returns:
592 Five element Array, containing the prefix of longtext, the suffix of
592 Five element Array, containing the prefix of longtext, the suffix of
593 longtext, the prefix of shorttext, the suffix of shorttext and the
593 longtext, the prefix of shorttext, the suffix of shorttext and the
594 common middle. Or None if there was no match.
594 common middle. Or None if there was no match.
595 """
595 """
596 seed = longtext[i : i + len(longtext) // 4]
596 seed = longtext[i : i + len(longtext) // 4]
597 best_common = ""
597 best_common = ""
598 j = shorttext.find(seed)
598 j = shorttext.find(seed)
599 while j != -1:
599 while j != -1:
600 prefixLength = self.diff_commonPrefix(longtext[i:], shorttext[j:])
600 prefixLength = self.diff_commonPrefix(longtext[i:], shorttext[j:])
601 suffixLength = self.diff_commonSuffix(longtext[:i], shorttext[:j])
601 suffixLength = self.diff_commonSuffix(longtext[:i], shorttext[:j])
602 if len(best_common) < suffixLength + prefixLength:
602 if len(best_common) < suffixLength + prefixLength:
603 best_common = (
603 best_common = (
604 shorttext[j - suffixLength : j]
604 shorttext[j - suffixLength : j]
605 + shorttext[j : j + prefixLength]
605 + shorttext[j : j + prefixLength]
606 )
606 )
607 best_longtext_a = longtext[: i - suffixLength]
607 best_longtext_a = longtext[: i - suffixLength]
608 best_longtext_b = longtext[i + prefixLength :]
608 best_longtext_b = longtext[i + prefixLength :]
609 best_shorttext_a = shorttext[: j - suffixLength]
609 best_shorttext_a = shorttext[: j - suffixLength]
610 best_shorttext_b = shorttext[j + prefixLength :]
610 best_shorttext_b = shorttext[j + prefixLength :]
611 j = shorttext.find(seed, j + 1)
611 j = shorttext.find(seed, j + 1)
612
612
613 if len(best_common) * 2 >= len(longtext):
613 if len(best_common) * 2 >= len(longtext):
614 return (
614 return (
615 best_longtext_a,
615 best_longtext_a,
616 best_longtext_b,
616 best_longtext_b,
617 best_shorttext_a,
617 best_shorttext_a,
618 best_shorttext_b,
618 best_shorttext_b,
619 best_common,
619 best_common,
620 )
620 )
621 else:
621 else:
622 return None
622 return None
623
623
624 # First check if the second quarter is the seed for a half-match.
624 # First check if the second quarter is the seed for a half-match.
625 hm1 = diff_halfMatchI(longtext, shorttext, (len(longtext) + 3) // 4)
625 hm1 = diff_halfMatchI(longtext, shorttext, (len(longtext) + 3) // 4)
626 # Check again based on the third quarter.
626 # Check again based on the third quarter.
627 hm2 = diff_halfMatchI(longtext, shorttext, (len(longtext) + 1) // 2)
627 hm2 = diff_halfMatchI(longtext, shorttext, (len(longtext) + 1) // 2)
628 if not hm1 and not hm2:
628 if not hm1 and not hm2:
629 return None
629 return None
630 elif not hm2:
630 elif not hm2:
631 hm = hm1
631 hm = hm1
632 elif not hm1:
632 elif not hm1:
633 hm = hm2
633 hm = hm2
634 else:
634 else:
635 # Both matched. Select the longest.
635 # Both matched. Select the longest.
636 if len(hm1[4]) > len(hm2[4]):
636 if len(hm1[4]) > len(hm2[4]):
637 hm = hm1
637 hm = hm1
638 else:
638 else:
639 hm = hm2
639 hm = hm2
640
640
641 # A half-match was found, sort out the return data.
641 # A half-match was found, sort out the return data.
642 if len(text1) > len(text2):
642 if len(text1) > len(text2):
643 (text1_a, text1_b, text2_a, text2_b, mid_common) = hm
643 (text1_a, text1_b, text2_a, text2_b, mid_common) = hm
644 else:
644 else:
645 (text2_a, text2_b, text1_a, text1_b, mid_common) = hm
645 (text2_a, text2_b, text1_a, text1_b, mid_common) = hm
646 return (text1_a, text1_b, text2_a, text2_b, mid_common)
646 return (text1_a, text1_b, text2_a, text2_b, mid_common)
647
647
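# Half-match sketch: only attempted when Diff_Timeout > 0, since it can yield
# a non-minimal diff (expected value follows the upstream test suite):
from diff_match_patch import diff_match_patch  # import path assumed

dmp = diff_match_patch()
dmp.Diff_Timeout = 1.0
hm = dmp.diff_halfMatch("1234567890", "a345678z")
# hm == ("12", "90", "a", "z", "345678")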
648 def diff_cleanupSemantic(self, diffs):
648 def diff_cleanupSemantic(self, diffs):
649 """Reduce the number of edits by eliminating semantically trivial
649 """Reduce the number of edits by eliminating semantically trivial
650 equalities.
650 equalities.
651
651
652 Args:
652 Args:
653 diffs: Array of diff tuples.
653 diffs: Array of diff tuples.
654 """
654 """
655 changes = False
655 changes = False
656 equalities = [] # Stack of indices where equalities are found.
656 equalities = [] # Stack of indices where equalities are found.
657 lastequality = None # Always equal to diffs[equalities[-1]][1]
657 lastequality = None # Always equal to diffs[equalities[-1]][1]
658 pointer = 0 # Index of current position.
658 pointer = 0 # Index of current position.
659 # Number of chars that changed prior to the equality.
659 # Number of chars that changed prior to the equality.
660 length_insertions1, length_deletions1 = 0, 0
660 length_insertions1, length_deletions1 = 0, 0
661 # Number of chars that changed after the equality.
661 # Number of chars that changed after the equality.
662 length_insertions2, length_deletions2 = 0, 0
662 length_insertions2, length_deletions2 = 0, 0
663 while pointer < len(diffs):
663 while pointer < len(diffs):
664 if diffs[pointer][0] == self.DIFF_EQUAL: # Equality found.
664 if diffs[pointer][0] == self.DIFF_EQUAL: # Equality found.
665 equalities.append(pointer)
665 equalities.append(pointer)
666 length_insertions1, length_insertions2 = length_insertions2, 0
666 length_insertions1, length_insertions2 = length_insertions2, 0
667 length_deletions1, length_deletions2 = length_deletions2, 0
667 length_deletions1, length_deletions2 = length_deletions2, 0
668 lastequality = diffs[pointer][1]
668 lastequality = diffs[pointer][1]
669 else: # An insertion or deletion.
669 else: # An insertion or deletion.
670 if diffs[pointer][0] == self.DIFF_INSERT:
670 if diffs[pointer][0] == self.DIFF_INSERT:
671 length_insertions2 += len(diffs[pointer][1])
671 length_insertions2 += len(diffs[pointer][1])
672 else:
672 else:
673 length_deletions2 += len(diffs[pointer][1])
673 length_deletions2 += len(diffs[pointer][1])
674 # Eliminate an equality that is smaller or equal to the edits on both
674 # Eliminate an equality that is smaller or equal to the edits on both
675 # sides of it.
675 # sides of it.
676 if (
676 if (
677 lastequality
677 lastequality
678 and (
678 and (
679 len(lastequality) <= max(length_insertions1, length_deletions1)
679 len(lastequality) <= max(length_insertions1, length_deletions1)
680 )
680 )
681 and (
681 and (
682 len(lastequality) <= max(length_insertions2, length_deletions2)
682 len(lastequality) <= max(length_insertions2, length_deletions2)
683 )
683 )
684 ):
684 ):
685 # Duplicate record.
685 # Duplicate record.
686 diffs.insert(equalities[-1], (self.DIFF_DELETE, lastequality))
686 diffs.insert(equalities[-1], (self.DIFF_DELETE, lastequality))
687 # Change second copy to insert.
687 # Change second copy to insert.
688 diffs[equalities[-1] + 1] = (
688 diffs[equalities[-1] + 1] = (
689 self.DIFF_INSERT,
689 self.DIFF_INSERT,
690 diffs[equalities[-1] + 1][1],
690 diffs[equalities[-1] + 1][1],
691 )
691 )
692 # Throw away the equality we just deleted.
692 # Throw away the equality we just deleted.
693 equalities.pop()
693 equalities.pop()
694 # Throw away the previous equality (it needs to be reevaluated).
694 # Throw away the previous equality (it needs to be reevaluated).
695 if len(equalities):
695 if len(equalities):
696 equalities.pop()
696 equalities.pop()
697 if len(equalities):
697 if len(equalities):
698 pointer = equalities[-1]
698 pointer = equalities[-1]
699 else:
699 else:
700 pointer = -1
700 pointer = -1
701 # Reset the counters.
701 # Reset the counters.
702 length_insertions1, length_deletions1 = 0, 0
702 length_insertions1, length_deletions1 = 0, 0
703 length_insertions2, length_deletions2 = 0, 0
703 length_insertions2, length_deletions2 = 0, 0
704 lastequality = None
704 lastequality = None
705 changes = True
705 changes = True
706 pointer += 1
706 pointer += 1

        # Normalize the diff.
        if changes:
            self.diff_cleanupMerge(diffs)
        self.diff_cleanupSemanticLossless(diffs)

        # Find any overlaps between deletions and insertions.
        # e.g: <del>abcxxx</del><ins>xxxdef</ins>
        #   -> <del>abc</del>xxx<ins>def</ins>
        # e.g: <del>xxxabc</del><ins>defxxx</ins>
        #   -> <ins>def</ins>xxx<del>abc</del>
        # Only extract an overlap if it is as big as the edit ahead or behind it.
        pointer = 1
        while pointer < len(diffs):
            if (
                diffs[pointer - 1][0] == self.DIFF_DELETE
                and diffs[pointer][0] == self.DIFF_INSERT
            ):
                deletion = diffs[pointer - 1][1]
                insertion = diffs[pointer][1]
                overlap_length1 = self.diff_commonOverlap(deletion, insertion)
                overlap_length2 = self.diff_commonOverlap(insertion, deletion)
                if overlap_length1 >= overlap_length2:
                    if (
                        overlap_length1 >= len(deletion) / 2.0
                        or overlap_length1 >= len(insertion) / 2.0
                    ):
                        # Overlap found.  Insert an equality and trim the surrounding edits.
                        diffs.insert(
                            pointer, (self.DIFF_EQUAL, insertion[:overlap_length1])
                        )
                        diffs[pointer - 1] = (
                            self.DIFF_DELETE,
                            deletion[: len(deletion) - overlap_length1],
                        )
                        diffs[pointer + 1] = (
                            self.DIFF_INSERT,
                            insertion[overlap_length1:],
                        )
                        pointer += 1
                else:
                    if (
                        overlap_length2 >= len(deletion) / 2.0
                        or overlap_length2 >= len(insertion) / 2.0
                    ):
                        # Reverse overlap found.
                        # Insert an equality and swap and trim the surrounding edits.
                        diffs.insert(
                            pointer, (self.DIFF_EQUAL, deletion[:overlap_length2])
                        )
                        diffs[pointer - 1] = (
                            self.DIFF_INSERT,
                            insertion[: len(insertion) - overlap_length2],
                        )
                        diffs[pointer + 1] = (
                            self.DIFF_DELETE,
                            deletion[overlap_length2:],
                        )
                        pointer += 1
                pointer += 1
            pointer += 1

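    # Illustrative sketch (assumes `dmp` is an instance of this class): the
    # overlap pass above turns mirrored edits into an explicit equality, as in
    # the <del>abcxxx</del><ins>xxxdef</ins> example from the comments:
    #
    #   diffs = [(dmp.DIFF_DELETE, "abcxxx"), (dmp.DIFF_INSERT, "xxxdef")]
    #   dmp.diff_cleanupSemantic(diffs)
    #   # -> [(dmp.DIFF_DELETE, "abc"), (dmp.DIFF_EQUAL, "xxx"),
    #   #     (dmp.DIFF_INSERT, "def")]
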
    def diff_cleanupSemanticLossless(self, diffs):
        """Look for single edits surrounded on both sides by equalities
        which can be shifted sideways to align the edit to a word boundary.
        e.g: The c<ins>at c</ins>ame. -> The <ins>cat </ins>came.

        Args:
            diffs: Array of diff tuples.
        """

        def diff_cleanupSemanticScore(one, two):
            """Given two strings, compute a score representing whether the
            internal boundary falls on logical boundaries.
            Scores range from 6 (best) to 0 (worst).
            Closure, but does not reference any external variables.

            Args:
                one: First string.
                two: Second string.

            Returns:
                The score.
            """
            if not one or not two:
                # Edges are the best.
                return 6

            # Each port of this function behaves slightly differently due to
            # subtle differences in each language's definition of things like
            # 'whitespace'.  Since this function's purpose is largely cosmetic,
            # the choice has been made to use each language's native features
            # rather than force total conformity.
            char1 = one[-1]
            char2 = two[0]
            nonAlphaNumeric1 = not char1.isalnum()
            nonAlphaNumeric2 = not char2.isalnum()
            whitespace1 = nonAlphaNumeric1 and char1.isspace()
            whitespace2 = nonAlphaNumeric2 and char2.isspace()
            lineBreak1 = whitespace1 and (char1 == "\r" or char1 == "\n")
            lineBreak2 = whitespace2 and (char2 == "\r" or char2 == "\n")
            blankLine1 = lineBreak1 and self.BLANKLINEEND.search(one)
            blankLine2 = lineBreak2 and self.BLANKLINESTART.match(two)

            if blankLine1 or blankLine2:
                # Five points for blank lines.
                return 5
            elif lineBreak1 or lineBreak2:
                # Four points for line breaks.
                return 4
            elif nonAlphaNumeric1 and not whitespace1 and whitespace2:
                # Three points for end of sentences.
                return 3
            elif whitespace1 or whitespace2:
                # Two points for whitespace.
                return 2
            elif nonAlphaNumeric1 or nonAlphaNumeric2:
                # One point for non-alphanumeric.
                return 1
            return 0

        pointer = 1
        # Intentionally ignore the first and last element (don't need checking).
        while pointer < len(diffs) - 1:
            if (
                diffs[pointer - 1][0] == self.DIFF_EQUAL
                and diffs[pointer + 1][0] == self.DIFF_EQUAL
            ):
                # This is a single edit surrounded by equalities.
                equality1 = diffs[pointer - 1][1]
                edit = diffs[pointer][1]
                equality2 = diffs[pointer + 1][1]

                # First, shift the edit as far left as possible.
                commonOffset = self.diff_commonSuffix(equality1, edit)
                if commonOffset:
                    commonString = edit[-commonOffset:]
                    equality1 = equality1[:-commonOffset]
                    edit = commonString + edit[:-commonOffset]
                    equality2 = commonString + equality2

                # Second, step character by character right, looking for the best fit.
                bestEquality1 = equality1
                bestEdit = edit
                bestEquality2 = equality2
                bestScore = diff_cleanupSemanticScore(
                    equality1, edit
                ) + diff_cleanupSemanticScore(edit, equality2)
                while edit and equality2 and edit[0] == equality2[0]:
                    equality1 += edit[0]
                    edit = edit[1:] + equality2[0]
                    equality2 = equality2[1:]
                    score = diff_cleanupSemanticScore(
                        equality1, edit
                    ) + diff_cleanupSemanticScore(edit, equality2)
                    # The >= encourages trailing rather than leading whitespace on edits.
                    if score >= bestScore:
                        bestScore = score
                        bestEquality1 = equality1
                        bestEdit = edit
                        bestEquality2 = equality2

                if diffs[pointer - 1][1] != bestEquality1:
                    # We have an improvement, save it back to the diff.
                    if bestEquality1:
                        diffs[pointer - 1] = (diffs[pointer - 1][0], bestEquality1)
                    else:
                        del diffs[pointer - 1]
                        pointer -= 1
                    diffs[pointer] = (diffs[pointer][0], bestEdit)
                    if bestEquality2:
                        diffs[pointer + 1] = (diffs[pointer + 1][0], bestEquality2)
                    else:
                        del diffs[pointer + 1]
                        pointer -= 1
            pointer += 1

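    # Illustrative sketch (assumes `dmp` is an instance of this class): the
    # docstring example above, expressed as diff tuples. The edit slides right
    # until it lands on the word boundary with the best semantic score:
    #
    #   diffs = [(dmp.DIFF_EQUAL, "The c"), (dmp.DIFF_INSERT, "at c"),
    #            (dmp.DIFF_EQUAL, "ame.")]
    #   dmp.diff_cleanupSemanticLossless(diffs)
    #   # -> [(dmp.DIFF_EQUAL, "The "), (dmp.DIFF_INSERT, "cat "),
    #   #     (dmp.DIFF_EQUAL, "came.")]
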
    # Define some regex patterns for matching boundaries.
    BLANKLINEEND = re.compile(r"\n\r?\n$")
    BLANKLINESTART = re.compile(r"^\r?\n\r?\n")

    def diff_cleanupEfficiency(self, diffs):
        """Reduce the number of edits by eliminating operationally trivial
        equalities.

        Args:
            diffs: Array of diff tuples.
        """
        changes = False
        equalities = []  # Stack of indices where equalities are found.
        lastequality = None  # Always equal to diffs[equalities[-1]][1]
        pointer = 0  # Index of current position.
        pre_ins = False  # Is there an insertion operation before the last equality.
        pre_del = False  # Is there a deletion operation before the last equality.
        post_ins = False  # Is there an insertion operation after the last equality.
        post_del = False  # Is there a deletion operation after the last equality.
        while pointer < len(diffs):
            if diffs[pointer][0] == self.DIFF_EQUAL:  # Equality found.
                if len(diffs[pointer][1]) < self.Diff_EditCost and (
                    post_ins or post_del
                ):
                    # Candidate found.
                    equalities.append(pointer)
                    pre_ins = post_ins
                    pre_del = post_del
                    lastequality = diffs[pointer][1]
                else:
                    # Not a candidate, and can never become one.
                    equalities = []
                    lastequality = None

                post_ins = post_del = False
            else:  # An insertion or deletion.
                if diffs[pointer][0] == self.DIFF_DELETE:
                    post_del = True
                else:
                    post_ins = True

                # Five types to be split:
                # <ins>A</ins><del>B</del>XY<ins>C</ins><del>D</del>
                # <ins>A</ins>X<ins>C</ins><del>D</del>
                # <ins>A</ins><del>B</del>X<ins>C</ins>
                # <del>A</del>X<ins>C</ins><del>D</del>
                # <ins>A</ins><del>B</del>X<del>C</del>

                if lastequality and (
                    (pre_ins and pre_del and post_ins and post_del)
                    or (
                        (len(lastequality) < self.Diff_EditCost / 2)
                        and (pre_ins + pre_del + post_ins + post_del) == 3
                    )
                ):
                    # Duplicate record.
                    diffs.insert(equalities[-1], (self.DIFF_DELETE, lastequality))
                    # Change second copy to insert.
                    diffs[equalities[-1] + 1] = (
                        self.DIFF_INSERT,
                        diffs[equalities[-1] + 1][1],
                    )
                    equalities.pop()  # Throw away the equality we just deleted.
                    lastequality = None
                    if pre_ins and pre_del:
                        # No changes made which could affect previous entry, keep going.
                        post_ins = post_del = True
                        equalities = []
                    else:
                        if len(equalities):
                            equalities.pop()  # Throw away the previous equality.
                        if len(equalities):
                            pointer = equalities[-1]
                        else:
                            pointer = -1
                        post_ins = post_del = False
                    changes = True
            pointer += 1

        if changes:
            self.diff_cleanupMerge(diffs)

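    # Hand-traced sketch (assumes `dmp` with the default Diff_EditCost of 4):
    # a short equality wedged between edits on both sides costs more to keep
    # than to fold away, so it is eliminated and the edits coalesce:
    #
    #   diffs = [(dmp.DIFF_INSERT, "ab"), (dmp.DIFF_DELETE, "cd"),
    #            (dmp.DIFF_EQUAL, "xy"), (dmp.DIFF_INSERT, "34"),
    #            (dmp.DIFF_DELETE, "56")]
    #   dmp.diff_cleanupEfficiency(diffs)
    #   # -> [(dmp.DIFF_DELETE, "cdxy56"), (dmp.DIFF_INSERT, "abxy34")]
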
    def diff_cleanupMerge(self, diffs):
        """Reorder and merge like edit sections.  Merge equalities.
        Any edit section can move as long as it doesn't cross an equality.

        Args:
            diffs: Array of diff tuples.
        """
        diffs.append((self.DIFF_EQUAL, ""))  # Add a dummy entry at the end.
        pointer = 0
        count_delete = 0
        count_insert = 0
        text_delete = ""
        text_insert = ""
        while pointer < len(diffs):
            if diffs[pointer][0] == self.DIFF_INSERT:
                count_insert += 1
                text_insert += diffs[pointer][1]
                pointer += 1
            elif diffs[pointer][0] == self.DIFF_DELETE:
                count_delete += 1
                text_delete += diffs[pointer][1]
                pointer += 1
            elif diffs[pointer][0] == self.DIFF_EQUAL:
                # Upon reaching an equality, check for prior redundancies.
                if count_delete + count_insert > 1:
                    if count_delete != 0 and count_insert != 0:
                        # Factor out any common prefixes.
                        commonlength = self.diff_commonPrefix(text_insert, text_delete)
                        if commonlength != 0:
                            x = pointer - count_delete - count_insert - 1
                            if x >= 0 and diffs[x][0] == self.DIFF_EQUAL:
                                diffs[x] = (
                                    diffs[x][0],
                                    diffs[x][1] + text_insert[:commonlength],
                                )
                            else:
                                diffs.insert(
                                    0, (self.DIFF_EQUAL, text_insert[:commonlength])
                                )
                                pointer += 1
                            text_insert = text_insert[commonlength:]
                            text_delete = text_delete[commonlength:]
                        # Factor out any common suffixes.
                        commonlength = self.diff_commonSuffix(text_insert, text_delete)
                        if commonlength != 0:
                            diffs[pointer] = (
                                diffs[pointer][0],
                                text_insert[-commonlength:] + diffs[pointer][1],
                            )
                            text_insert = text_insert[:-commonlength]
                            text_delete = text_delete[:-commonlength]
                    # Delete the offending records and add the merged ones.
                    if count_delete == 0:
                        diffs[pointer - count_insert : pointer] = [
                            (self.DIFF_INSERT, text_insert)
                        ]
                    elif count_insert == 0:
                        diffs[pointer - count_delete : pointer] = [
                            (self.DIFF_DELETE, text_delete)
                        ]
                    else:
                        diffs[pointer - count_delete - count_insert : pointer] = [
                            (self.DIFF_DELETE, text_delete),
                            (self.DIFF_INSERT, text_insert),
                        ]
                    pointer = pointer - count_delete - count_insert + 1
                    if count_delete != 0:
                        pointer += 1
                    if count_insert != 0:
                        pointer += 1
                elif pointer != 0 and diffs[pointer - 1][0] == self.DIFF_EQUAL:
                    # Merge this equality with the previous one.
                    diffs[pointer - 1] = (
                        diffs[pointer - 1][0],
                        diffs[pointer - 1][1] + diffs[pointer][1],
                    )
                    del diffs[pointer]
                else:
                    pointer += 1

                count_insert = 0
                count_delete = 0
                text_delete = ""
                text_insert = ""

        if diffs[-1][1] == "":
            diffs.pop()  # Remove the dummy entry at the end.

        # Second pass: look for single edits surrounded on both sides by equalities
        # which can be shifted sideways to eliminate an equality.
        # e.g: A<ins>BA</ins>C -> <ins>AB</ins>AC
        changes = False
        pointer = 1
        # Intentionally ignore the first and last element (don't need checking).
        while pointer < len(diffs) - 1:
            if (
                diffs[pointer - 1][0] == self.DIFF_EQUAL
                and diffs[pointer + 1][0] == self.DIFF_EQUAL
            ):
                # This is a single edit surrounded by equalities.
                if diffs[pointer][1].endswith(diffs[pointer - 1][1]):
                    # Shift the edit over the previous equality.
                    diffs[pointer] = (
                        diffs[pointer][0],
                        diffs[pointer - 1][1]
                        + diffs[pointer][1][: -len(diffs[pointer - 1][1])],
                    )
                    diffs[pointer + 1] = (
                        diffs[pointer + 1][0],
                        diffs[pointer - 1][1] + diffs[pointer + 1][1],
                    )
                    del diffs[pointer - 1]
                    changes = True
                elif diffs[pointer][1].startswith(diffs[pointer + 1][1]):
                    # Shift the edit over the next equality.
                    diffs[pointer - 1] = (
                        diffs[pointer - 1][0],
                        diffs[pointer - 1][1] + diffs[pointer + 1][1],
                    )
                    diffs[pointer] = (
                        diffs[pointer][0],
                        diffs[pointer][1][len(diffs[pointer + 1][1]) :]
                        + diffs[pointer + 1][1],
                    )
                    del diffs[pointer + 1]
                    changes = True
            pointer += 1

        # If shifts were made, the diff needs reordering and another shift sweep.
        if changes:
            self.diff_cleanupMerge(diffs)

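    # Illustrative sketch (assumes `dmp` is an instance of this class): common
    # prefixes of paired edits are factored out into an equality:
    #
    #   diffs = [(dmp.DIFF_DELETE, "ab"), (dmp.DIFF_INSERT, "ac")]
    #   dmp.diff_cleanupMerge(diffs)
    #   # -> [(dmp.DIFF_EQUAL, "a"), (dmp.DIFF_DELETE, "b"),
    #   #     (dmp.DIFF_INSERT, "c")]
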
    def diff_xIndex(self, diffs, loc):
        """loc is a location in text1, compute and return the equivalent location
        in text2.  e.g. "The cat" vs "The big cat", 1->1, 5->8

        Args:
            diffs: Array of diff tuples.
            loc: Location within text1.

        Returns:
            Location within text2.
        """
        chars1 = 0
        chars2 = 0
        last_chars1 = 0
        last_chars2 = 0
        for x in range(len(diffs)):
            (op, text) = diffs[x]
            if op != self.DIFF_INSERT:  # Equality or deletion.
                chars1 += len(text)
            if op != self.DIFF_DELETE:  # Equality or insertion.
                chars2 += len(text)
            if chars1 > loc:  # Overshot the location.
                break
            last_chars1 = chars1
            last_chars2 = chars2

        if len(diffs) != x and diffs[x][0] == self.DIFF_DELETE:
            # The location was deleted.
            return last_chars2
        # Add the remaining character length.
        return last_chars2 + (loc - last_chars1)

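    # Illustrative sketch (assumes `dmp` is an instance of this class), using
    # the docstring's own texts:
    #
    #   diffs = dmp.diff_main("The cat", "The big cat")
    #   dmp.diff_xIndex(diffs, 1)  # location 1 precedes the insertion -> 1
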
    def diff_prettyHtml(self, diffs):
        """Convert a diff array into a pretty HTML report.

        Args:
            diffs: Array of diff tuples.

        Returns:
            HTML representation.
        """
        html = []
        for op, data in diffs:
            text = (
                data.replace("&", "&amp;")
                .replace("<", "&lt;")
                .replace(">", "&gt;")
                .replace("\n", "&para;<br>")
            )
            if op == self.DIFF_INSERT:
                html.append('<ins style="background:#e6ffe6;">%s</ins>' % text)
            elif op == self.DIFF_DELETE:
                html.append('<del style="background:#ffe6e6;">%s</del>' % text)
            elif op == self.DIFF_EQUAL:
                html.append("<span>%s</span>" % text)
        return "".join(html)

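    # Illustrative sketch: HTML metacharacters are escaped and newlines made
    # visible, e.g. a single tuple [(dmp.DIFF_EQUAL, "a\n")] should render as
    # "<span>a&para;<br></span>".
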
    def diff_text1(self, diffs):
        """Compute and return the source text (all equalities and deletions).

        Args:
            diffs: Array of diff tuples.

        Returns:
            Source text.
        """
        text = []
        for op, data in diffs:
            if op != self.DIFF_INSERT:
                text.append(data)
        return "".join(text)

    def diff_text2(self, diffs):
        """Compute and return the destination text (all equalities and insertions).

        Args:
            diffs: Array of diff tuples.

        Returns:
            Destination text.
        """
        text = []
        for op, data in diffs:
            if op != self.DIFF_DELETE:
                text.append(data)
        return "".join(text)

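    # Illustrative sketch: diff_text1 and diff_text2 recover the two sides of
    # a diff. For
    #
    #   diffs = [(dmp.DIFF_EQUAL, "jump"), (dmp.DIFF_DELETE, "s"),
    #            (dmp.DIFF_INSERT, "ed")]
    #
    # diff_text1(diffs) returns "jumps" and diff_text2(diffs) returns "jumped".
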
    def diff_levenshtein(self, diffs):
        """Compute the Levenshtein distance; the number of inserted, deleted or
        substituted characters.

        Args:
            diffs: Array of diff tuples.

        Returns:
            Number of changes.
        """
        levenshtein = 0
        insertions = 0
        deletions = 0
        for op, data in diffs:
            if op == self.DIFF_INSERT:
                insertions += len(data)
            elif op == self.DIFF_DELETE:
                deletions += len(data)
            elif op == self.DIFF_EQUAL:
                # A deletion and an insertion is one substitution.
                levenshtein += max(insertions, deletions)
                insertions = 0
                deletions = 0
        levenshtein += max(insertions, deletions)
        return levenshtein

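    # Worked example: for
    #
    #   [(dmp.DIFF_DELETE, "abc"), (dmp.DIFF_INSERT, "12"),
    #    (dmp.DIFF_EQUAL, "xyz"), (dmp.DIFF_INSERT, "!")]
    #
    # the equality flushes max(2, 3) = 3 and the trailing insertion adds
    # max(1, 0) = 1, giving a distance of 4.
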
    def diff_toDelta(self, diffs):
        """Crush the diff into an encoded string which describes the operations
        required to transform text1 into text2.
        E.g. =3\t-2\t+ing -> Keep 3 chars, delete 2 chars, insert 'ing'.
        Operations are tab-separated.  Inserted text is escaped using %xx notation.

        Args:
            diffs: Array of diff tuples.

        Returns:
            Delta text.
        """
        text = []
        for op, data in diffs:
            if op == self.DIFF_INSERT:
                # High ascii will raise UnicodeDecodeError.  Use Unicode instead.
                data = data.encode("utf-8")
                text.append("+" + urllib.parse.quote(data, "!~*'();/?:@&=+$,# "))
            elif op == self.DIFF_DELETE:
                text.append("-%d" % len(data))
            elif op == self.DIFF_EQUAL:
                text.append("=%d" % len(data))
        return "\t".join(text)

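    # Illustrative sketch: the diff
    # [(dmp.DIFF_EQUAL, "jump"), (dmp.DIFF_DELETE, "s"), (dmp.DIFF_INSERT, "ed")]
    # encodes to the tab-separated delta string "=4\t-1\t+ed".
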
    def diff_fromDelta(self, text1, delta):
        """Given the original text1, and an encoded string which describes the
        operations required to transform text1 into text2, compute the full diff.

        Args:
            text1: Source string for the diff.
            delta: Delta text.

        Returns:
            Array of diff tuples.

        Raises:
            ValueError: If invalid input.
        """
        if isinstance(delta, str):
            # Deltas should be composed of a subset of ascii chars, Unicode not
            # required.  If this encode raises UnicodeEncodeError, delta is invalid.
            # Validate only; assigning the bytes result back to delta would break
            # the str operations below under Python 3.
            delta.encode("ascii")
        diffs = []
        pointer = 0  # Cursor in text1
        tokens = delta.split("\t")
        for token in tokens:
            if token == "":
                # Blank tokens are ok (from a trailing \t).
                continue
            # Each token begins with a one character parameter which specifies the
            # operation of this token (delete, insert, equality).
            param = token[1:]
            if token[0] == "+":
                param = urllib.parse.unquote(param)
                diffs.append((self.DIFF_INSERT, param))
            elif token[0] == "-" or token[0] == "=":
                try:
                    n = int(param)
                except ValueError:
                    raise ValueError("Invalid number in diff_fromDelta: " + param)
                if n < 0:
                    raise ValueError("Negative number in diff_fromDelta: " + param)
                text = text1[pointer : pointer + n]
                pointer += n
                if token[0] == "=":
                    diffs.append((self.DIFF_EQUAL, text))
                else:
                    diffs.append((self.DIFF_DELETE, text))
            else:
                # Anything else is an error.
                raise ValueError(
                    "Invalid diff operation in diff_fromDelta: " + token[0]
                )
        if pointer != len(text1):
            raise ValueError(
                "Delta length (%d) does not equal source text length (%d)."
                % (pointer, len(text1))
            )
        return diffs

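    # Illustrative sketch: diff_fromDelta inverts diff_toDelta given the
    # source text:
    #
    #   dmp.diff_fromDelta("jumps", "=4\t-1\t+ed")
    #   # -> [(dmp.DIFF_EQUAL, "jump"), (dmp.DIFF_DELETE, "s"),
    #   #     (dmp.DIFF_INSERT, "ed")]
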
    # MATCH FUNCTIONS

    def match_main(self, text, pattern, loc):
        """Locate the best instance of 'pattern' in 'text' near 'loc'.

        Args:
            text: The text to search.
            pattern: The pattern to search for.
            loc: The location to search around.

        Returns:
            Best match index or -1.
        """
        # Check for null inputs.
        if text is None or pattern is None:
            raise ValueError("Null inputs. (match_main)")

        loc = max(0, min(loc, len(text)))
        if text == pattern:
            # Shortcut (potentially not guaranteed by the algorithm)
            return 0
        elif not text:
            # Nothing to match.
            return -1
        elif text[loc : loc + len(pattern)] == pattern:
            # Perfect match at the perfect spot!  (Includes case of null pattern)
            return loc
        else:
            # Do a fuzzy compare.
            match = self.match_bitap(text, pattern, loc)
            return match

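    # Illustrative sketch (assumes `dmp` with default match settings): when
    # there is no hit at `loc` itself, the Bitap search still finds the
    # nearby exact match:
    #
    #   dmp.match_main("abcdef", "cde", 0)  # -> 2
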
    def match_bitap(self, text, pattern, loc):
        """Locate the best instance of 'pattern' in 'text' near 'loc' using the
        Bitap algorithm.

        Args:
            text: The text to search.
            pattern: The pattern to search for.
            loc: The location to search around.

        Returns:
            Best match index or -1.
        """
        # Python doesn't have a maxint limit, so ignore this check.
        # if self.Match_MaxBits != 0 and len(pattern) > self.Match_MaxBits:
        #     raise ValueError("Pattern too long for this application.")

        # Initialise the alphabet.
        s = self.match_alphabet(pattern)

        def match_bitapScore(e, x):
            """Compute and return the score for a match with e errors and x location.
            Accesses loc and pattern through being a closure.

            Args:
                e: Number of errors in match.
                x: Location of match.

            Returns:
                Overall score for match (0.0 = good, 1.0 = bad).
            """
            accuracy = float(e) / len(pattern)
            proximity = abs(loc - x)
            if not self.Match_Distance:
                # Dodge divide by zero error.
                return 1.0 if proximity else accuracy
            return accuracy + (proximity / float(self.Match_Distance))

        # Highest score beyond which we give up.
        score_threshold = self.Match_Threshold
        # Is there a nearby exact match? (speedup)
        best_loc = text.find(pattern, loc)
        if best_loc != -1:
            score_threshold = min(match_bitapScore(0, best_loc), score_threshold)
        # What about in the other direction? (speedup)
        best_loc = text.rfind(pattern, loc + len(pattern))
        if best_loc != -1:
            score_threshold = min(match_bitapScore(0, best_loc), score_threshold)

        # Initialise the bit arrays.
        matchmask = 1 << (len(pattern) - 1)
        best_loc = -1

        bin_max = len(pattern) + len(text)
        # Empty initialization added to appease pychecker.
        last_rd = None
        for d in range(len(pattern)):
            # Scan for the best match each iteration allows for one more error.
            # Run a binary search to determine how far from 'loc' we can stray at
            # this error level.
            bin_min = 0
            bin_mid = bin_max
            while bin_min < bin_mid:
                if match_bitapScore(d, loc + bin_mid) <= score_threshold:
                    bin_min = bin_mid
                else:
                    bin_max = bin_mid
                bin_mid = (bin_max - bin_min) // 2 + bin_min

            # Use the result from this iteration as the maximum for the next.
            bin_max = bin_mid
            start = max(1, loc - bin_mid + 1)
            finish = min(loc + bin_mid, len(text)) + len(pattern)

            rd = [0] * (finish + 2)
            rd[finish + 1] = (1 << d) - 1
            for j in range(finish, start - 1, -1):
                if len(text) <= j - 1:
                    # Out of range.
                    charMatch = 0
                else:
                    charMatch = s.get(text[j - 1], 0)
                if d == 0:  # First pass: exact match.
                    rd[j] = ((rd[j + 1] << 1) | 1) & charMatch
                else:  # Subsequent passes: fuzzy match.
                    rd[j] = (
                        (((rd[j + 1] << 1) | 1) & charMatch)
                        | (((last_rd[j + 1] | last_rd[j]) << 1) | 1)
                        | last_rd[j + 1]
                    )
                if rd[j] & matchmask:
                    score = match_bitapScore(d, j - 1)
                    # This match will almost certainly be better than any existing match.
                    # But check anyway.
                    if score <= score_threshold:
                        # Told you so.
                        score_threshold = score
                        best_loc = j - 1
                        if best_loc > loc:
                            # When passing loc, don't exceed our current distance from loc.
                            start = max(1, 2 * loc - best_loc)
                        else:
                            # Already passed loc, downhill from here on in.
                            break
            # No hope for a (better) match at greater error levels.
            if match_bitapScore(d + 1, loc) > score_threshold:
                break
            last_rd = rd
        return best_loc

    def match_alphabet(self, pattern):
        """Initialise the alphabet for the Bitap algorithm.

        Args:
            pattern: The text to encode.

        Returns:
            Hash of character locations.
        """
        s = {}
        for char in pattern:
            s[char] = 0
        for i in range(len(pattern)):
            s[pattern[i]] |= 1 << (len(pattern) - i - 1)
        return s

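    # Worked example: match_alphabet("abc") yields
    #
    #   {"a": 0b100, "b": 0b010, "c": 0b001}
    #
    # i.e. each character carries a bit at its position in the pattern,
    # counted from the end, which the Bitap passes above AND against rd.
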
    # PATCH FUNCTIONS

    def patch_addContext(self, patch, text):
        """Increase the context until it is unique,
        but don't let the pattern expand beyond Match_MaxBits.

        Args:
            patch: The patch to grow.
            text: Source text.
        """
        if len(text) == 0:
            return
        pattern = text[patch.start2 : patch.start2 + patch.length1]
        padding = 0

        # Look for the first and last matches of pattern in text.  If two different
        # matches are found, increase the pattern length.
        while text.find(pattern) != text.rfind(pattern) and (
            self.Match_MaxBits == 0
            or len(pattern) < self.Match_MaxBits - self.Patch_Margin - self.Patch_Margin
        ):
            padding += self.Patch_Margin
            pattern = text[
                max(0, patch.start2 - padding) : patch.start2 + patch.length1 + padding
            ]
        # Add one chunk for good luck.
        padding += self.Patch_Margin

        # Add the prefix.
        prefix = text[max(0, patch.start2 - padding) : patch.start2]
        if prefix:
            patch.diffs[:0] = [(self.DIFF_EQUAL, prefix)]
        # Add the suffix.
        suffix = text[
            patch.start2 + patch.length1 : patch.start2 + patch.length1 + padding
        ]
        if suffix:
            patch.diffs.append((self.DIFF_EQUAL, suffix))

        # Roll back the start points.
        patch.start1 -= len(prefix)
        patch.start2 -= len(prefix)
        # Extend lengths.
        patch.length1 += len(prefix) + len(suffix)
        patch.length2 += len(prefix) + len(suffix)

1494 def patch_make(self, a, b=None, c=None):
1494 def patch_make(self, a, b=None, c=None):
1495 """Compute a list of patches to turn text1 into text2.
1495 """Compute a list of patches to turn text1 into text2.
1496 Use diffs if provided, otherwise compute it ourselves.
1496 Use diffs if provided, otherwise compute it ourselves.
1497 There are four ways to call this function, depending on what data is
1497 There are four ways to call this function, depending on what data is
1498 available to the caller:
1498 available to the caller:
1499 Method 1:
1499 Method 1:
1500 a = text1, b = text2
1500 a = text1, b = text2
1501 Method 2:
1501 Method 2:
1502 a = diffs
1502 a = diffs
1503 Method 3 (optimal):
1503 Method 3 (optimal):
1504 a = text1, b = diffs
1504 a = text1, b = diffs
1505 Method 4 (deprecated, use method 3):
1505 Method 4 (deprecated, use method 3):
1506 a = text1, b = text2, c = diffs
1506 a = text1, b = text2, c = diffs
1507
1507
1508 Args:
1508 Args:
1509 a: text1 (methods 1,3,4) or Array of diff tuples for text1 to
1509 a: text1 (methods 1,3,4) or Array of diff tuples for text1 to
1510 text2 (method 2).
1510 text2 (method 2).
1511 b: text2 (methods 1,4) or Array of diff tuples for text1 to
1511 b: text2 (methods 1,4) or Array of diff tuples for text1 to
1512 text2 (method 3) or undefined (method 2).
1512 text2 (method 3) or undefined (method 2).
1513 c: Array of diff tuples for text1 to text2 (method 4) or
1513 c: Array of diff tuples for text1 to text2 (method 4) or
1514 undefined (methods 1,2,3).
1514 undefined (methods 1,2,3).
1515
1515
1516 Returns:
1516 Returns:
1517 Array of Patch objects.
1517 Array of Patch objects.
1518 """
1518 """
1519 text1 = None
1519 text1 = None
1520 diffs = None
1520 diffs = None
1521 # Note that texts may arrive as 'str' or 'unicode'.
1521 # Note that texts may arrive as 'str' or 'unicode'.
1522 if isinstance(a, str) and isinstance(b, str) and c is None:
1522 if isinstance(a, str) and isinstance(b, str) and c is None:
1523 # Method 1: text1, text2
1523 # Method 1: text1, text2
1524 # Compute diffs from text1 and text2.
1524 # Compute diffs from text1 and text2.
1525 text1 = a
1525 text1 = a
1526 diffs = self.diff_main(text1, b, True)
1526 diffs = self.diff_main(text1, b, True)
1527 if len(diffs) > 2:
1527 if len(diffs) > 2:
1528 self.diff_cleanupSemantic(diffs)
1528 self.diff_cleanupSemantic(diffs)
1529 self.diff_cleanupEfficiency(diffs)
1529 self.diff_cleanupEfficiency(diffs)
1530 elif isinstance(a, list) and b is None and c is None:
1530 elif isinstance(a, list) and b is None and c is None:
1531 # Method 2: diffs
1531 # Method 2: diffs
1532 # Compute text1 from diffs.
1532 # Compute text1 from diffs.
1533 diffs = a
1533 diffs = a
1534 text1 = self.diff_text1(diffs)
1534 text1 = self.diff_text1(diffs)
1535 elif isinstance(a, str) and isinstance(b, list) and c is None:
1535 elif isinstance(a, str) and isinstance(b, list) and c is None:
1536 # Method 3: text1, diffs
1536 # Method 3: text1, diffs
1537 text1 = a
1537 text1 = a
1538 diffs = b
1538 diffs = b
1539 elif isinstance(a, str) and isinstance(b, str) and isinstance(c, list):
1539 elif isinstance(a, str) and isinstance(b, str) and isinstance(c, list):
1540 # Method 4: text1, text2, diffs
1540 # Method 4: text1, text2, diffs
1541 # text2 is not used.
1541 # text2 is not used.
1542 text1 = a
1542 text1 = a
1543 diffs = c
1543 diffs = c
1544 else:
1544 else:
1545 raise ValueError("Unknown call format to patch_make.")
1545 raise ValueError("Unknown call format to patch_make.")
1546
1546
1547 if not diffs:
1547 if not diffs:
1548 return [] # Get rid of the None case.
1548 return [] # Get rid of the None case.
1549 patches = []
1549 patches = []
1550 patch = patch_obj()
1550 patch = patch_obj()
1551 char_count1 = 0 # Number of characters into the text1 string.
1551 char_count1 = 0 # Number of characters into the text1 string.
1552 char_count2 = 0 # Number of characters into the text2 string.
1552 char_count2 = 0 # Number of characters into the text2 string.
1553 prepatch_text = text1 # Recreate the patches to determine context info.
1553 prepatch_text = text1 # Recreate the patches to determine context info.
1554 postpatch_text = text1
1554 postpatch_text = text1
1555 for x in range(len(diffs)):
1555 for x in range(len(diffs)):
1556 (diff_type, diff_text) = diffs[x]
1556 (diff_type, diff_text) = diffs[x]
1557 if len(patch.diffs) == 0 and diff_type != self.DIFF_EQUAL:
1557 if len(patch.diffs) == 0 and diff_type != self.DIFF_EQUAL:
1558 # A new patch starts here.
1558 # A new patch starts here.
1559 patch.start1 = char_count1
1559 patch.start1 = char_count1
1560 patch.start2 = char_count2
1560 patch.start2 = char_count2
1561 if diff_type == self.DIFF_INSERT:
1561 if diff_type == self.DIFF_INSERT:
1562 # Insertion
1562 # Insertion
1563 patch.diffs.append(diffs[x])
1563 patch.diffs.append(diffs[x])
1564 patch.length2 += len(diff_text)
1564 patch.length2 += len(diff_text)
1565 postpatch_text = (
1565 postpatch_text = (
1566 postpatch_text[:char_count2]
1566 postpatch_text[:char_count2]
1567 + diff_text
1567 + diff_text
1568 + postpatch_text[char_count2:]
1568 + postpatch_text[char_count2:]
1569 )
1569 )
1570 elif diff_type == self.DIFF_DELETE:
1570 elif diff_type == self.DIFF_DELETE:
1571 # Deletion.
1571 # Deletion.
1572 patch.length1 += len(diff_text)
1572 patch.length1 += len(diff_text)
1573 patch.diffs.append(diffs[x])
1573 patch.diffs.append(diffs[x])
1574 postpatch_text = (
1574 postpatch_text = (
1575 postpatch_text[:char_count2]
1575 postpatch_text[:char_count2]
1576 + postpatch_text[char_count2 + len(diff_text) :]
1576 + postpatch_text[char_count2 + len(diff_text) :]
1577 )
1577 )
1578 elif (
1578 elif (
1579 diff_type == self.DIFF_EQUAL
1579 diff_type == self.DIFF_EQUAL
1580 and len(diff_text) <= 2 * self.Patch_Margin
1580 and len(diff_text) <= 2 * self.Patch_Margin
1581 and len(patch.diffs) != 0
1581 and len(patch.diffs) != 0
1582 and len(diffs) != x + 1
1582 and len(diffs) != x + 1
1583 ):
1583 ):
1584 # Small equality inside a patch.
1584 # Small equality inside a patch.
1585 patch.diffs.append(diffs[x])
1585 patch.diffs.append(diffs[x])
1586 patch.length1 += len(diff_text)
1586 patch.length1 += len(diff_text)
1587 patch.length2 += len(diff_text)
1587 patch.length2 += len(diff_text)
1588
1588
1589 if diff_type == self.DIFF_EQUAL and len(diff_text) >= 2 * self.Patch_Margin:
1589 if diff_type == self.DIFF_EQUAL and len(diff_text) >= 2 * self.Patch_Margin:
1590 # Time for a new patch.
1590 # Time for a new patch.
1591 if len(patch.diffs) != 0:
1591 if len(patch.diffs) != 0:
1592 self.patch_addContext(patch, prepatch_text)
1592 self.patch_addContext(patch, prepatch_text)
1593 patches.append(patch)
1593 patches.append(patch)
1594 patch = patch_obj()
1594 patch = patch_obj()
1595 # Unlike Unidiff, our patch lists have a rolling context.
1595 # Unlike Unidiff, our patch lists have a rolling context.
1596 # http://code.google.com/p/google-diff-match-patch/wiki/Unidiff
1596 # http://code.google.com/p/google-diff-match-patch/wiki/Unidiff
1597 # Update prepatch text & pos to reflect the application of the
1597 # Update prepatch text & pos to reflect the application of the
1598 # just completed patch.
1598 # just completed patch.
1599 prepatch_text = postpatch_text
1599 prepatch_text = postpatch_text
1600 char_count1 = char_count2
1600 char_count1 = char_count2
1601
1601
1602 # Update the current character count.
1602 # Update the current character count.
1603 if diff_type != self.DIFF_INSERT:
1603 if diff_type != self.DIFF_INSERT:
1604 char_count1 += len(diff_text)
1604 char_count1 += len(diff_text)
1605 if diff_type != self.DIFF_DELETE:
1605 if diff_type != self.DIFF_DELETE:
1606 char_count2 += len(diff_text)
1606 char_count2 += len(diff_text)
1607
1607
1608 # Pick up the leftover patch if not empty.
1608 # Pick up the leftover patch if not empty.
1609 if len(patch.diffs) != 0:
1609 if len(patch.diffs) != 0:
1610 self.patch_addContext(patch, prepatch_text)
1610 self.patch_addContext(patch, prepatch_text)
1611 patches.append(patch)
1611 patches.append(patch)
1612 return patches
1612 return patches
1613
1613
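
As a quick orientation for the patch pipeline above, here is a minimal usage sketch. It assumes the enclosing class is the standard diff_match_patch and that the method ending above is its usual patch_make entry point; the texts are made up.

    dmp = diff_match_patch()
    old_text = "The quick brown fox."
    new_text = "The quick red fox jumps."
    # patch_make computes diffs and wraps them into patch_obj instances.
    patches = dmp.patch_make(old_text, new_text)
    # patch_apply replays them; `applied` flags which patches matched.
    patched_text, applied = dmp.patch_apply(patches, old_text)
    assert patched_text == new_text
    assert all(applied)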
1614 def patch_deepCopy(self, patches):
1614 def patch_deepCopy(self, patches):
1615 """Given an array of patches, return another array that is identical.
1615 """Given an array of patches, return another array that is identical.
1616
1616
1617 Args:
1617 Args:
1618 patches: Array of Patch objects.
1618 patches: Array of Patch objects.
1619
1619
1620 Returns:
1620 Returns:
1621 Array of Patch objects.
1621 Array of Patch objects.
1622 """
1622 """
1623 patchesCopy = []
1623 patchesCopy = []
1624 for patch in patches:
1624 for patch in patches:
1625 patchCopy = patch_obj()
1625 patchCopy = patch_obj()
1626 # No need to deep copy the tuples since they are immutable.
1626 # No need to deep copy the tuples since they are immutable.
1627 patchCopy.diffs = patch.diffs[:]
1627 patchCopy.diffs = patch.diffs[:]
1628 patchCopy.start1 = patch.start1
1628 patchCopy.start1 = patch.start1
1629 patchCopy.start2 = patch.start2
1629 patchCopy.start2 = patch.start2
1630 patchCopy.length1 = patch.length1
1630 patchCopy.length1 = patch.length1
1631 patchCopy.length2 = patch.length2
1631 patchCopy.length2 = patch.length2
1632 patchesCopy.append(patchCopy)
1632 patchesCopy.append(patchCopy)
1633 return patchesCopy
1633 return patchesCopy
1634
1634
1635 def patch_apply(self, patches, text):
1635 def patch_apply(self, patches, text):
1636 """Merge a set of patches onto the text. Return a patched text, as well
1636 """Merge a set of patches onto the text. Return a patched text, as well
1637 as a list of true/false values indicating which patches were applied.
1637 as a list of true/false values indicating which patches were applied.
1638
1638
1639 Args:
1639 Args:
1640 patches: Array of Patch objects.
1640 patches: Array of Patch objects.
1641 text: Old text.
1641 text: Old text.
1642
1642
1643 Returns:
1643 Returns:
1644 Two element Array, containing the new text and an array of boolean values.
1644 Two element Array, containing the new text and an array of boolean values.
1645 """
1645 """
1646 if not patches:
1646 if not patches:
1647 return (text, [])
1647 return (text, [])
1648
1648
1649 # Deep copy the patches so that no changes are made to originals.
1649 # Deep copy the patches so that no changes are made to originals.
1650 patches = self.patch_deepCopy(patches)
1650 patches = self.patch_deepCopy(patches)
1651
1651
1652 nullPadding = self.patch_addPadding(patches)
1652 nullPadding = self.patch_addPadding(patches)
1653 text = nullPadding + text + nullPadding
1653 text = nullPadding + text + nullPadding
1654 self.patch_splitMax(patches)
1654 self.patch_splitMax(patches)
1655
1655
1656 # delta keeps track of the offset between the expected and actual location
1656 # delta keeps track of the offset between the expected and actual location
1657 # of the previous patch. If there are patches expected at positions 10 and
1657 # of the previous patch. If there are patches expected at positions 10 and
1658 # 20, but the first patch was found at 12, delta is 2 and the second patch
1658 # 20, but the first patch was found at 12, delta is 2 and the second patch
1659 # has an effective expected position of 22.
1659 # has an effective expected position of 22.
1660 delta = 0
1660 delta = 0
1661 results = []
1661 results = []
1662 for patch in patches:
1662 for patch in patches:
1663 expected_loc = patch.start2 + delta
1663 expected_loc = patch.start2 + delta
1664 text1 = self.diff_text1(patch.diffs)
1664 text1 = self.diff_text1(patch.diffs)
1665 end_loc = -1
1665 end_loc = -1
1666 if len(text1) > self.Match_MaxBits:
1666 if len(text1) > self.Match_MaxBits:
1667 # patch_splitMax will only provide an oversized pattern in the case of
1667 # patch_splitMax will only provide an oversized pattern in the case of
1668 # a monster delete.
1668 # a monster delete.
1669 start_loc = self.match_main(
1669 start_loc = self.match_main(
1670 text, text1[: self.Match_MaxBits], expected_loc
1670 text, text1[: self.Match_MaxBits], expected_loc
1671 )
1671 )
1672 if start_loc != -1:
1672 if start_loc != -1:
1673 end_loc = self.match_main(
1673 end_loc = self.match_main(
1674 text,
1674 text,
1675 text1[-self.Match_MaxBits :],
1675 text1[-self.Match_MaxBits :],
1676 expected_loc + len(text1) - self.Match_MaxBits,
1676 expected_loc + len(text1) - self.Match_MaxBits,
1677 )
1677 )
1678 if end_loc == -1 or start_loc >= end_loc:
1678 if end_loc == -1 or start_loc >= end_loc:
1679 # Can't find valid trailing context. Drop this patch.
1679 # Can't find valid trailing context. Drop this patch.
1680 start_loc = -1
1680 start_loc = -1
1681 else:
1681 else:
1682 start_loc = self.match_main(text, text1, expected_loc)
1682 start_loc = self.match_main(text, text1, expected_loc)
1683 if start_loc == -1:
1683 if start_loc == -1:
1684 # No match found. :(
1684 # No match found. :(
1685 results.append(False)
1685 results.append(False)
1686 # Subtract the delta for this failed patch from subsequent patches.
1686 # Subtract the delta for this failed patch from subsequent patches.
1687 delta -= patch.length2 - patch.length1
1687 delta -= patch.length2 - patch.length1
1688 else:
1688 else:
1689 # Found a match. :)
1689 # Found a match. :)
1690 results.append(True)
1690 results.append(True)
1691 delta = start_loc - expected_loc
1691 delta = start_loc - expected_loc
1692 if end_loc == -1:
1692 if end_loc == -1:
1693 text2 = text[start_loc : start_loc + len(text1)]
1693 text2 = text[start_loc : start_loc + len(text1)]
1694 else:
1694 else:
1695 text2 = text[start_loc : end_loc + self.Match_MaxBits]
1695 text2 = text[start_loc : end_loc + self.Match_MaxBits]
1696 if text1 == text2:
1696 if text1 == text2:
1697 # Perfect match, just shove the replacement text in.
1697 # Perfect match, just shove the replacement text in.
1698 text = (
1698 text = (
1699 text[:start_loc]
1699 text[:start_loc]
1700 + self.diff_text2(patch.diffs)
1700 + self.diff_text2(patch.diffs)
1701 + text[start_loc + len(text1) :]
1701 + text[start_loc + len(text1) :]
1702 )
1702 )
1703 else:
1703 else:
1704 # Imperfect match.
1704 # Imperfect match.
1705 # Run a diff to get a framework of equivalent indices.
1705 # Run a diff to get a framework of equivalent indices.
1706 diffs = self.diff_main(text1, text2, False)
1706 diffs = self.diff_main(text1, text2, False)
1707 if (
1707 if (
1708 len(text1) > self.Match_MaxBits
1708 len(text1) > self.Match_MaxBits
1709 and self.diff_levenshtein(diffs) / float(len(text1))
1709 and self.diff_levenshtein(diffs) / float(len(text1))
1710 > self.Patch_DeleteThreshold
1710 > self.Patch_DeleteThreshold
1711 ):
1711 ):
1712 # The end points match, but the content is unacceptably bad.
1712 # The end points match, but the content is unacceptably bad.
1713 results[-1] = False
1713 results[-1] = False
1714 else:
1714 else:
1715 self.diff_cleanupSemanticLossless(diffs)
1715 self.diff_cleanupSemanticLossless(diffs)
1716 index1 = 0
1716 index1 = 0
1717 for op, data in patch.diffs:
1717 for op, data in patch.diffs:
1718 if op != self.DIFF_EQUAL:
1718 if op != self.DIFF_EQUAL:
1719 index2 = self.diff_xIndex(diffs, index1)
1719 index2 = self.diff_xIndex(diffs, index1)
1720 if op == self.DIFF_INSERT: # Insertion
1720 if op == self.DIFF_INSERT: # Insertion
1721 text = (
1721 text = (
1722 text[: start_loc + index2]
1722 text[: start_loc + index2]
1723 + data
1723 + data
1724 + text[start_loc + index2 :]
1724 + text[start_loc + index2 :]
1725 )
1725 )
1726 elif op == self.DIFF_DELETE: # Deletion
1726 elif op == self.DIFF_DELETE: # Deletion
1727 text = (
1727 text = (
1728 text[: start_loc + index2]
1728 text[: start_loc + index2]
1729 + text[
1729 + text[
1730 start_loc
1730 start_loc
1731 + self.diff_xIndex(diffs, index1 + len(data)) :
1731 + self.diff_xIndex(diffs, index1 + len(data)) :
1732 ]
1732 ]
1733 )
1733 )
1734 if op != self.DIFF_DELETE:
1734 if op != self.DIFF_DELETE:
1735 index1 += len(data)
1735 index1 += len(data)
1736 # Strip the padding off.
1736 # Strip the padding off.
1737 text = text[len(nullPadding) : -len(nullPadding)]
1737 text = text[len(nullPadding) : -len(nullPadding)]
1738 return (text, results)
1738 return (text, results)
1739
1739
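
A small sketch of the fuzzy matching implemented above: a patch can still apply when the target text has drifted, and the boolean list reports per-patch success. Names and outputs are illustrative, assuming standard diff_match_patch defaults.

    dmp = diff_match_patch()
    patches = dmp.patch_make("one two three four", "one two THREE four")
    # The target gained a prefix and a suffix since the patch was made;
    # match_main still locates the context, so the patch applies.
    text, results = dmp.patch_apply(patches, "zero one two three four five")
    print(text)     # 'zero one two THREE four five'
    print(results)  # [True]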
1740 def patch_addPadding(self, patches):
1740 def patch_addPadding(self, patches):
1741 """Add some padding on text start and end so that edges can match
1741 """Add some padding on text start and end so that edges can match
1742 something. Intended to be called only from within patch_apply.
1742 something. Intended to be called only from within patch_apply.
1743
1743
1744 Args:
1744 Args:
1745 patches: Array of Patch objects.
1745 patches: Array of Patch objects.
1746
1746
1747 Returns:
1747 Returns:
1748 The padding string added to each side.
1748 The padding string added to each side.
1749 """
1749 """
1750 paddingLength = self.Patch_Margin
1750 paddingLength = self.Patch_Margin
1751 nullPadding = ""
1751 nullPadding = ""
1752 for x in range(1, paddingLength + 1):
1752 for x in range(1, paddingLength + 1):
1753 nullPadding += chr(x)
1753 nullPadding += chr(x)
1754
1754
1755 # Bump all the patches forward.
1755 # Bump all the patches forward.
1756 for patch in patches:
1756 for patch in patches:
1757 patch.start1 += paddingLength
1757 patch.start1 += paddingLength
1758 patch.start2 += paddingLength
1758 patch.start2 += paddingLength
1759
1759
1760 # Add some padding on start of first diff.
1760 # Add some padding on start of first diff.
1761 patch = patches[0]
1761 patch = patches[0]
1762 diffs = patch.diffs
1762 diffs = patch.diffs
1763 if not diffs or diffs[0][0] != self.DIFF_EQUAL:
1763 if not diffs or diffs[0][0] != self.DIFF_EQUAL:
1764 # Add nullPadding equality.
1764 # Add nullPadding equality.
1765 diffs.insert(0, (self.DIFF_EQUAL, nullPadding))
1765 diffs.insert(0, (self.DIFF_EQUAL, nullPadding))
1766 patch.start1 -= paddingLength # Should be 0.
1766 patch.start1 -= paddingLength # Should be 0.
1767 patch.start2 -= paddingLength # Should be 0.
1767 patch.start2 -= paddingLength # Should be 0.
1768 patch.length1 += paddingLength
1768 patch.length1 += paddingLength
1769 patch.length2 += paddingLength
1769 patch.length2 += paddingLength
1770 elif paddingLength > len(diffs[0][1]):
1770 elif paddingLength > len(diffs[0][1]):
1771 # Grow first equality.
1771 # Grow first equality.
1772 extraLength = paddingLength - len(diffs[0][1])
1772 extraLength = paddingLength - len(diffs[0][1])
1773 newText = nullPadding[len(diffs[0][1]) :] + diffs[0][1]
1773 newText = nullPadding[len(diffs[0][1]) :] + diffs[0][1]
1774 diffs[0] = (diffs[0][0], newText)
1774 diffs[0] = (diffs[0][0], newText)
1775 patch.start1 -= extraLength
1775 patch.start1 -= extraLength
1776 patch.start2 -= extraLength
1776 patch.start2 -= extraLength
1777 patch.length1 += extraLength
1777 patch.length1 += extraLength
1778 patch.length2 += extraLength
1778 patch.length2 += extraLength
1779
1779
1780 # Add some padding on end of last diff.
1780 # Add some padding on end of last diff.
1781 patch = patches[-1]
1781 patch = patches[-1]
1782 diffs = patch.diffs
1782 diffs = patch.diffs
1783 if not diffs or diffs[-1][0] != self.DIFF_EQUAL:
1783 if not diffs or diffs[-1][0] != self.DIFF_EQUAL:
1784 # Add nullPadding equality.
1784 # Add nullPadding equality.
1785 diffs.append((self.DIFF_EQUAL, nullPadding))
1785 diffs.append((self.DIFF_EQUAL, nullPadding))
1786 patch.length1 += paddingLength
1786 patch.length1 += paddingLength
1787 patch.length2 += paddingLength
1787 patch.length2 += paddingLength
1788 elif paddingLength > len(diffs[-1][1]):
1788 elif paddingLength > len(diffs[-1][1]):
1789 # Grow last equality.
1789 # Grow last equality.
1790 extraLength = paddingLength - len(diffs[-1][1])
1790 extraLength = paddingLength - len(diffs[-1][1])
1791 newText = diffs[-1][1] + nullPadding[:extraLength]
1791 newText = diffs[-1][1] + nullPadding[:extraLength]
1792 diffs[-1] = (diffs[-1][0], newText)
1792 diffs[-1] = (diffs[-1][0], newText)
1793 patch.length1 += extraLength
1793 patch.length1 += extraLength
1794 patch.length2 += extraLength
1794 patch.length2 += extraLength
1795
1795
1796 return nullPadding
1796 return nullPadding
1797
1797
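
For illustration only (patch_addPadding is normally invoked internally by patch_apply), a sketch of the padding it produces, assuming the default Patch_Margin of 4:

    dmp = diff_match_patch()
    patches = dmp.patch_make("abc", "Xabc")
    # Prepends/appends chr(1)..chr(Patch_Margin) so an edit at position 0
    # still has leading context to match against.
    null_padding = dmp.patch_addPadding(patches)
    print([ord(c) for c in null_padding])  # [1, 2, 3, 4]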
1798 def patch_splitMax(self, patches):
1798 def patch_splitMax(self, patches):
1799 """Look through the patches and break up any which are longer than the
1799 """Look through the patches and break up any which are longer than the
1800 maximum limit of the match algorithm.
1800 maximum limit of the match algorithm.
1801 Intended to be called only from within patch_apply.
1801 Intended to be called only from within patch_apply.
1802
1802
1803 Args:
1803 Args:
1804 patches: Array of Patch objects.
1804 patches: Array of Patch objects.
1805 """
1805 """
1806 patch_size = self.Match_MaxBits
1806 patch_size = self.Match_MaxBits
1807 if patch_size == 0:
1807 if patch_size == 0:
1808 # Python has the option of not splitting strings due to its ability
1808 # Python has the option of not splitting strings due to its ability
1809 # to handle integers of arbitrary precision.
1809 # to handle integers of arbitrary precision.
1810 return
1810 return
1811 for x in range(len(patches)):
1811 for x in range(len(patches)):
1812 if patches[x].length1 <= patch_size:
1812 if patches[x].length1 <= patch_size:
1813 continue
1813 continue
1814 bigpatch = patches[x]
1814 bigpatch = patches[x]
1815 # Remove the big old patch.
1815 # Remove the big old patch.
1816 del patches[x]
1816 del patches[x]
1817 x -= 1
1817 x -= 1
1818 start1 = bigpatch.start1
1818 start1 = bigpatch.start1
1819 start2 = bigpatch.start2
1819 start2 = bigpatch.start2
1820 precontext = ""
1820 precontext = ""
1821 while len(bigpatch.diffs) != 0:
1821 while len(bigpatch.diffs) != 0:
1822 # Create one of several smaller patches.
1822 # Create one of several smaller patches.
1823 patch = patch_obj()
1823 patch = patch_obj()
1824 empty = True
1824 empty = True
1825 patch.start1 = start1 - len(precontext)
1825 patch.start1 = start1 - len(precontext)
1826 patch.start2 = start2 - len(precontext)
1826 patch.start2 = start2 - len(precontext)
1827 if precontext:
1827 if precontext:
1828 patch.length1 = patch.length2 = len(precontext)
1828 patch.length1 = patch.length2 = len(precontext)
1829 patch.diffs.append((self.DIFF_EQUAL, precontext))
1829 patch.diffs.append((self.DIFF_EQUAL, precontext))
1830
1830
1831 while (
1831 while (
1832 len(bigpatch.diffs) != 0
1832 len(bigpatch.diffs) != 0
1833 and patch.length1 < patch_size - self.Patch_Margin
1833 and patch.length1 < patch_size - self.Patch_Margin
1834 ):
1834 ):
1835 (diff_type, diff_text) = bigpatch.diffs[0]
1835 (diff_type, diff_text) = bigpatch.diffs[0]
1836 if diff_type == self.DIFF_INSERT:
1836 if diff_type == self.DIFF_INSERT:
1837 # Insertions are harmless.
1837 # Insertions are harmless.
1838 patch.length2 += len(diff_text)
1838 patch.length2 += len(diff_text)
1839 start2 += len(diff_text)
1839 start2 += len(diff_text)
1840 patch.diffs.append(bigpatch.diffs.pop(0))
1840 patch.diffs.append(bigpatch.diffs.pop(0))
1841 empty = False
1841 empty = False
1842 elif (
1842 elif (
1843 diff_type == self.DIFF_DELETE
1843 diff_type == self.DIFF_DELETE
1844 and len(patch.diffs) == 1
1844 and len(patch.diffs) == 1
1845 and patch.diffs[0][0] == self.DIFF_EQUAL
1845 and patch.diffs[0][0] == self.DIFF_EQUAL
1846 and len(diff_text) > 2 * patch_size
1846 and len(diff_text) > 2 * patch_size
1847 ):
1847 ):
1848 # This is a large deletion. Let it pass in one chunk.
1848 # This is a large deletion. Let it pass in one chunk.
1849 patch.length1 += len(diff_text)
1849 patch.length1 += len(diff_text)
1850 start1 += len(diff_text)
1850 start1 += len(diff_text)
1851 empty = False
1851 empty = False
1852 patch.diffs.append((diff_type, diff_text))
1852 patch.diffs.append((diff_type, diff_text))
1853 del bigpatch.diffs[0]
1853 del bigpatch.diffs[0]
1854 else:
1854 else:
1855 # Deletion or equality. Only take as much as we can stomach.
1855 # Deletion or equality. Only take as much as we can stomach.
1856 diff_text = diff_text[
1856 diff_text = diff_text[
1857 : patch_size - patch.length1 - self.Patch_Margin
1857 : patch_size - patch.length1 - self.Patch_Margin
1858 ]
1858 ]
1859 patch.length1 += len(diff_text)
1859 patch.length1 += len(diff_text)
1860 start1 += len(diff_text)
1860 start1 += len(diff_text)
1861 if diff_type == self.DIFF_EQUAL:
1861 if diff_type == self.DIFF_EQUAL:
1862 patch.length2 += len(diff_text)
1862 patch.length2 += len(diff_text)
1863 start2 += len(diff_text)
1863 start2 += len(diff_text)
1864 else:
1864 else:
1865 empty = False
1865 empty = False
1866
1866
1867 patch.diffs.append((diff_type, diff_text))
1867 patch.diffs.append((diff_type, diff_text))
1868 if diff_text == bigpatch.diffs[0][1]:
1868 if diff_text == bigpatch.diffs[0][1]:
1869 del bigpatch.diffs[0]
1869 del bigpatch.diffs[0]
1870 else:
1870 else:
1871 bigpatch.diffs[0] = (
1871 bigpatch.diffs[0] = (
1872 bigpatch.diffs[0][0],
1872 bigpatch.diffs[0][0],
1873 bigpatch.diffs[0][1][len(diff_text) :],
1873 bigpatch.diffs[0][1][len(diff_text) :],
1874 )
1874 )
1875
1875
1876 # Compute the head context for the next patch.
1876 # Compute the head context for the next patch.
1877 precontext = self.diff_text2(patch.diffs)
1877 precontext = self.diff_text2(patch.diffs)
1878 precontext = precontext[-self.Patch_Margin :]
1878 precontext = precontext[-self.Patch_Margin :]
1879 # Append the end context for this patch.
1879 # Append the end context for this patch.
1880 postcontext = self.diff_text1(bigpatch.diffs)[: self.Patch_Margin]
1880 postcontext = self.diff_text1(bigpatch.diffs)[: self.Patch_Margin]
1881 if postcontext:
1881 if postcontext:
1882 patch.length1 += len(postcontext)
1882 patch.length1 += len(postcontext)
1883 patch.length2 += len(postcontext)
1883 patch.length2 += len(postcontext)
1884 if len(patch.diffs) != 0 and patch.diffs[-1][0] == self.DIFF_EQUAL:
1884 if len(patch.diffs) != 0 and patch.diffs[-1][0] == self.DIFF_EQUAL:
1885 patch.diffs[-1] = (
1885 patch.diffs[-1] = (
1886 self.DIFF_EQUAL,
1886 self.DIFF_EQUAL,
1887 patch.diffs[-1][1] + postcontext,
1887 patch.diffs[-1][1] + postcontext,
1888 )
1888 )
1889 else:
1889 else:
1890 patch.diffs.append((self.DIFF_EQUAL, postcontext))
1890 patch.diffs.append((self.DIFF_EQUAL, postcontext))
1891
1891
1892 if not empty:
1892 if not empty:
1893 x += 1
1893 x += 1
1894 patches.insert(x, patch)
1894 patches.insert(x, patch)
1895
1895
1896 def patch_toText(self, patches):
1896 def patch_toText(self, patches):
1897 """Take a list of patches and return a textual representation.
1897 """Take a list of patches and return a textual representation.
1898
1898
1899 Args:
1899 Args:
1900 patches: Array of Patch objects.
1900 patches: Array of Patch objects.
1901
1901
1902 Returns:
1902 Returns:
1903 Text representation of patches.
1903 Text representation of patches.
1904 """
1904 """
1905 text = []
1905 text = []
1906 for patch in patches:
1906 for patch in patches:
1907 text.append(str(patch))
1907 text.append(str(patch))
1908 return "".join(text)
1908 return "".join(text)
1909
1909
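
A round-trip sketch of the serialization pair: patch_toText emits the GNU-diff-like form built by patch_obj.__str__ further below, and patch_fromText parses it back. The printed patch body is illustrative.

    dmp = diff_match_patch()
    patches = dmp.patch_make("good dog", "bad dog")
    serialized = dmp.patch_toText(patches)
    print(serialized)  # e.g. '@@ -1,7 +1,6 @@\n-goo\n+ba\n d do\n'
    restored = dmp.patch_fromText(serialized)
    assert dmp.patch_toText(restored) == serialized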
1910 def patch_fromText(self, textline):
1910 def patch_fromText(self, textline):
1911 """Parse a textual representation of patches and return a list of patch
1911 """Parse a textual representation of patches and return a list of patch
1912 objects.
1912 objects.
1913
1913
1914 Args:
1914 Args:
1915 textline: Text representation of patches.
1915 textline: Text representation of patches.
1916
1916
1917 Returns:
1917 Returns:
1918 Array of Patch objects.
1918 Array of Patch objects.
1919
1919
1920 Raises:
1920 Raises:
1921 ValueError: If invalid input.
1921 ValueError: If invalid input.
1922 """
1922 """
1923 if type(textline) == unicode:
1923 if type(textline) == str:
1924 # Patches should be composed of a subset of ascii chars, Unicode not
1924 # Patches should be composed of a subset of ascii chars, Unicode not
1925 # required. If this encode raises UnicodeEncodeError, patch is invalid.
1925 # required. If this encode raises UnicodeEncodeError, patch is invalid.
1926 textline.encode("ascii")  # validation only; keep textline as str
1926 textline.encode("ascii")  # validation only; keep textline as str
1927 patches = []
1927 patches = []
1928 if not textline:
1928 if not textline:
1929 return patches
1929 return patches
1930 text = textline.split("\n")
1930 text = textline.split("\n")
1931 while len(text) != 0:
1931 while len(text) != 0:
1932 m = re.match(r"^@@ -(\d+),?(\d*) \+(\d+),?(\d*) @@$", text[0])
1932 m = re.match(r"^@@ -(\d+),?(\d*) \+(\d+),?(\d*) @@$", text[0])
1933 if not m:
1933 if not m:
1934 raise ValueError("Invalid patch string: " + text[0])
1934 raise ValueError("Invalid patch string: " + text[0])
1935 patch = patch_obj()
1935 patch = patch_obj()
1936 patches.append(patch)
1936 patches.append(patch)
1937 patch.start1 = int(m.group(1))
1937 patch.start1 = int(m.group(1))
1938 if m.group(2) == "":
1938 if m.group(2) == "":
1939 patch.start1 -= 1
1939 patch.start1 -= 1
1940 patch.length1 = 1
1940 patch.length1 = 1
1941 elif m.group(2) == "0":
1941 elif m.group(2) == "0":
1942 patch.length1 = 0
1942 patch.length1 = 0
1943 else:
1943 else:
1944 patch.start1 -= 1
1944 patch.start1 -= 1
1945 patch.length1 = int(m.group(2))
1945 patch.length1 = int(m.group(2))
1946
1946
1947 patch.start2 = int(m.group(3))
1947 patch.start2 = int(m.group(3))
1948 if m.group(4) == "":
1948 if m.group(4) == "":
1949 patch.start2 -= 1
1949 patch.start2 -= 1
1950 patch.length2 = 1
1950 patch.length2 = 1
1951 elif m.group(4) == "0":
1951 elif m.group(4) == "0":
1952 patch.length2 = 0
1952 patch.length2 = 0
1953 else:
1953 else:
1954 patch.start2 -= 1
1954 patch.start2 -= 1
1955 patch.length2 = int(m.group(4))
1955 patch.length2 = int(m.group(4))
1956
1956
1957 del text[0]
1957 del text[0]
1958
1958
1959 while len(text) != 0:
1959 while len(text) != 0:
1960 if text[0]:
1960 if text[0]:
1961 sign = text[0][0]
1961 sign = text[0][0]
1962 else:
1962 else:
1963 sign = ""
1963 sign = ""
1964 line = urllib.parse.unquote(text[0][1:])
1964 line = urllib.parse.unquote(text[0][1:])
1965 # unquote() already returns str decoded as UTF-8; no extra decode needed.
1965 # unquote() already returns str decoded as UTF-8; no extra decode needed.
1966 if sign == "+":
1966 if sign == "+":
1967 # Insertion.
1967 # Insertion.
1968 patch.diffs.append((self.DIFF_INSERT, line))
1968 patch.diffs.append((self.DIFF_INSERT, line))
1969 elif sign == "-":
1969 elif sign == "-":
1970 # Deletion.
1970 # Deletion.
1971 patch.diffs.append((self.DIFF_DELETE, line))
1971 patch.diffs.append((self.DIFF_DELETE, line))
1972 elif sign == " ":
1972 elif sign == " ":
1973 # Minor equality.
1973 # Minor equality.
1974 patch.diffs.append((self.DIFF_EQUAL, line))
1974 patch.diffs.append((self.DIFF_EQUAL, line))
1975 elif sign == "@":
1975 elif sign == "@":
1976 # Start of next patch.
1976 # Start of next patch.
1977 break
1977 break
1978 elif sign == "":
1978 elif sign == "":
1979 # Blank line? Whatever.
1979 # Blank line? Whatever.
1980 pass
1980 pass
1981 else:
1981 else:
1982 # WTF?
1982 # WTF?
1983 raise ValueError("Invalid patch mode: '%s'\n%s" % (sign, line))
1983 raise ValueError("Invalid patch mode: '%s'\n%s" % (sign, line))
1984 del text[0]
1984 del text[0]
1985 return patches
1985 return patches
1986
1986
1987
1987
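
As the parser above shows, malformed input is rejected with ValueError rather than silently ignored; a minimal sketch:

    dmp = diff_match_patch()
    try:
        dmp.patch_fromText("not a patch header")
    except ValueError as err:
        print(err)  # Invalid patch string: not a patch header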
1988 class patch_obj:
1988 class patch_obj:
1989 """Class representing one patch operation."""
1989 """Class representing one patch operation."""
1990
1990
1991 def __init__(self):
1991 def __init__(self):
1992 """Initializes with an empty list of diffs."""
1992 """Initializes with an empty list of diffs."""
1993 self.diffs = []
1993 self.diffs = []
1994 self.start1 = None
1994 self.start1 = None
1995 self.start2 = None
1995 self.start2 = None
1996 self.length1 = 0
1996 self.length1 = 0
1997 self.length2 = 0
1997 self.length2 = 0
1998
1998
1999 def __str__(self):
1999 def __str__(self):
2000 """Emmulate GNU diff's format.
2000 """Emmulate GNU diff's format.
2001 Header: @@ -382,8 +481,9 @@
2001 Header: @@ -382,8 +481,9 @@
2002 Indices are printed as 1-based, not 0-based.
2002 Indices are printed as 1-based, not 0-based.
2003
2003
2004 Returns:
2004 Returns:
2005 The GNU diff string.
2005 The GNU diff string.
2006 """
2006 """
2007 if self.length1 == 0:
2007 if self.length1 == 0:
2008 coords1 = str(self.start1) + ",0"
2008 coords1 = str(self.start1) + ",0"
2009 elif self.length1 == 1:
2009 elif self.length1 == 1:
2010 coords1 = str(self.start1 + 1)
2010 coords1 = str(self.start1 + 1)
2011 else:
2011 else:
2012 coords1 = str(self.start1 + 1) + "," + str(self.length1)
2012 coords1 = str(self.start1 + 1) + "," + str(self.length1)
2013 if self.length2 == 0:
2013 if self.length2 == 0:
2014 coords2 = str(self.start2) + ",0"
2014 coords2 = str(self.start2) + ",0"
2015 elif self.length2 == 1:
2015 elif self.length2 == 1:
2016 coords2 = str(self.start2 + 1)
2016 coords2 = str(self.start2 + 1)
2017 else:
2017 else:
2018 coords2 = str(self.start2 + 1) + "," + str(self.length2)
2018 coords2 = str(self.start2 + 1) + "," + str(self.length2)
2019 text = ["@@ -", coords1, " +", coords2, " @@\n"]
2019 text = ["@@ -", coords1, " +", coords2, " @@\n"]
2020 # Escape the body of the patch with %xx notation.
2020 # Escape the body of the patch with %xx notation.
2021 for op, data in self.diffs:
2021 for op, data in self.diffs:
2022 if op == diff_match_patch.DIFF_INSERT:
2022 if op == diff_match_patch.DIFF_INSERT:
2023 text.append("+")
2023 text.append("+")
2024 elif op == diff_match_patch.DIFF_DELETE:
2024 elif op == diff_match_patch.DIFF_DELETE:
2025 text.append("-")
2025 text.append("-")
2026 elif op == diff_match_patch.DIFF_EQUAL:
2026 elif op == diff_match_patch.DIFF_EQUAL:
2027 text.append(" ")
2027 text.append(" ")
2028 # High ascii will raise UnicodeDecodeError. Use Unicode instead.
2028 # High ascii will raise UnicodeDecodeError. Use Unicode instead.
2029 data = data.encode("utf-8")
2029 data = data.encode("utf-8")
2030 text.append(urllib.parse.quote(data, "!~*'();/?:@&=+$,# ") + "\n")
2030 text.append(urllib.parse.quote(data, "!~*'();/?:@&=+$,# ") + "\n")
2031 return "".join(text)
2031 return "".join(text)
@@ -1,1272 +1,1271 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2
2
3 # Copyright (C) 2011-2020 RhodeCode GmbH
3 # Copyright (C) 2011-2020 RhodeCode GmbH
4 #
4 #
5 # This program is free software: you can redistribute it and/or modify
5 # This program is free software: you can redistribute it and/or modify
6 # it under the terms of the GNU Affero General Public License, version 3
6 # it under the terms of the GNU Affero General Public License, version 3
7 # (only), as published by the Free Software Foundation.
7 # (only), as published by the Free Software Foundation.
8 #
8 #
9 # This program is distributed in the hope that it will be useful,
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 # GNU General Public License for more details.
12 # GNU General Public License for more details.
13 #
13 #
14 # You should have received a copy of the GNU Affero General Public License
14 # You should have received a copy of the GNU Affero General Public License
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 #
16 #
17 # This program is dual-licensed. If you wish to learn more about the
17 # This program is dual-licensed. If you wish to learn more about the
18 # RhodeCode Enterprise Edition, including its added features, Support services,
18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20
20
21
21
22 """
22 """
23 Set of diffing helpers, previously part of vcs
23 Set of diffing helpers, previously part of vcs
24 """
24 """
25
25
26 import os
26 import os
27 import re
27 import re
28 import bz2
28 import bz2
29 import gzip
29 import gzip
30 import time
30 import time
31
31
32 import collections
32 import collections
33 import difflib
33 import difflib
34 import logging
34 import logging
35 import pickle
35 import pickle
36 from itertools import tee
36 from itertools import tee
37
37
38 from rhodecode.lib.vcs.exceptions import VCSError
38 from rhodecode.lib.vcs.exceptions import VCSError
39 from rhodecode.lib.vcs.nodes import FileNode, SubModuleNode
39 from rhodecode.lib.vcs.nodes import FileNode, SubModuleNode
40 from rhodecode.lib.utils2 import safe_unicode, safe_str
40 from rhodecode.lib.utils2 import safe_unicode, safe_str
41
41
42 log = logging.getLogger(__name__)
42 log = logging.getLogger(__name__)
43
43
44 # define max context; a file with more than this number of lines is unusable
44 # define max context; a file with more than this number of lines is unusable
45 # in the browser anyway
45 # in the browser anyway
46 MAX_CONTEXT = 20 * 1024
46 MAX_CONTEXT = 20 * 1024
47 DEFAULT_CONTEXT = 3
47 DEFAULT_CONTEXT = 3
48
48
49
49
50 def get_diff_context(request):
50 def get_diff_context(request):
51 return MAX_CONTEXT if request.GET.get('fullcontext', '') == '1' else DEFAULT_CONTEXT
51 return MAX_CONTEXT if request.GET.get('fullcontext', '') == '1' else DEFAULT_CONTEXT
52
52
53
53
54 def get_diff_whitespace_flag(request):
54 def get_diff_whitespace_flag(request):
55 return request.GET.get('ignorews', '') == '1'
55 return request.GET.get('ignorews', '') == '1'
56
56
57
57
58 class OPS(object):
58 class OPS(object):
59 ADD = 'A'
59 ADD = 'A'
60 MOD = 'M'
60 MOD = 'M'
61 DEL = 'D'
61 DEL = 'D'
62
62
63
63
64 def get_gitdiff(filenode_old, filenode_new, ignore_whitespace=True, context=3):
64 def get_gitdiff(filenode_old, filenode_new, ignore_whitespace=True, context=3):
65 """
65 """
66 Returns git style diff between given ``filenode_old`` and ``filenode_new``.
66 Returns git style diff between given ``filenode_old`` and ``filenode_new``.
67
67
68 :param ignore_whitespace: ignore whitespaces in diff
68 :param ignore_whitespace: ignore whitespaces in diff
69 """
69 """
70 # make sure we pass in default context
70 # make sure we pass in default context
71 context = context or 3
71 context = context or 3
72 # protect against IntOverflow when passing HUGE context
72 # protect against IntOverflow when passing HUGE context
73 if context > MAX_CONTEXT:
73 if context > MAX_CONTEXT:
74 context = MAX_CONTEXT
74 context = MAX_CONTEXT
75
75
76 submodules = filter(lambda o: isinstance(o, SubModuleNode),
76 submodules = [o for o in [filenode_new, filenode_old] if isinstance(o, SubModuleNode)]
77 [filenode_new, filenode_old])
78 if submodules:
77 if submodules:
79 return ''
78 return ''
80
79
81 for filenode in (filenode_old, filenode_new):
80 for filenode in (filenode_old, filenode_new):
82 if not isinstance(filenode, FileNode):
81 if not isinstance(filenode, FileNode):
83 raise VCSError(
82 raise VCSError(
84 "Given object should be FileNode object, not %s"
83 "Given object should be FileNode object, not %s"
85 % filenode.__class__)
84 % filenode.__class__)
86
85
87 repo = filenode_new.commit.repository
86 repo = filenode_new.commit.repository
88 old_commit = filenode_old.commit or repo.EMPTY_COMMIT
87 old_commit = filenode_old.commit or repo.EMPTY_COMMIT
89 new_commit = filenode_new.commit
88 new_commit = filenode_new.commit
90
89
91 vcs_gitdiff = repo.get_diff(
90 vcs_gitdiff = repo.get_diff(
92 old_commit, new_commit, filenode_new.path,
91 old_commit, new_commit, filenode_new.path,
93 ignore_whitespace, context, path1=filenode_old.path)
92 ignore_whitespace, context, path1=filenode_old.path)
94 return vcs_gitdiff
93 return vcs_gitdiff
95
94
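
A hedged usage sketch for get_gitdiff. The `repo` object and commit ids below are hypothetical; any backend exposing get_commit()/get_node() in the rhodecode.lib.vcs style should fit.

    old_node = repo.get_commit('parent_id').get_node('setup.py')  # hypothetical
    new_node = repo.get_commit('commit_id').get_node('setup.py')  # hypothetical
    diff = get_gitdiff(old_node, new_node, ignore_whitespace=False, context=5)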
96 NEW_FILENODE = 1
95 NEW_FILENODE = 1
97 DEL_FILENODE = 2
96 DEL_FILENODE = 2
98 MOD_FILENODE = 3
97 MOD_FILENODE = 3
99 RENAMED_FILENODE = 4
98 RENAMED_FILENODE = 4
100 COPIED_FILENODE = 5
99 COPIED_FILENODE = 5
101 CHMOD_FILENODE = 6
100 CHMOD_FILENODE = 6
102 BIN_FILENODE = 7
101 BIN_FILENODE = 7
103
102
104
103
105 class LimitedDiffContainer(object):
104 class LimitedDiffContainer(object):
106
105
107 def __init__(self, diff_limit, cur_diff_size, diff):
106 def __init__(self, diff_limit, cur_diff_size, diff):
108 self.diff = diff
107 self.diff = diff
109 self.diff_limit = diff_limit
108 self.diff_limit = diff_limit
110 self.cur_diff_size = cur_diff_size
109 self.cur_diff_size = cur_diff_size
111
110
112 def __getitem__(self, key):
111 def __getitem__(self, key):
113 return self.diff.__getitem__(key)
112 return self.diff.__getitem__(key)
114
113
115 def __iter__(self):
114 def __iter__(self):
116 for l in self.diff:
115 for l in self.diff:
117 yield l
116 yield l
118
117
119
118
120 class Action(object):
119 class Action(object):
121 """
120 """
122 Contains constants for the action value of the lines in a parsed diff.
121 Contains constants for the action value of the lines in a parsed diff.
123 """
122 """
124
123
125 ADD = 'add'
124 ADD = 'add'
126 DELETE = 'del'
125 DELETE = 'del'
127 UNMODIFIED = 'unmod'
126 UNMODIFIED = 'unmod'
128
127
129 CONTEXT = 'context'
128 CONTEXT = 'context'
130 OLD_NO_NL = 'old-no-nl'
129 OLD_NO_NL = 'old-no-nl'
131 NEW_NO_NL = 'new-no-nl'
130 NEW_NO_NL = 'new-no-nl'
132
131
133
132
134 class DiffProcessor(object):
133 class DiffProcessor(object):
135 """
134 """
136 Give it a unified or git diff and it returns a list of the files that were
135 Give it a unified or git diff and it returns a list of the files that were
137 mentioned in the diff together with a dict of meta information that
136 mentioned in the diff together with a dict of meta information that
138 can be used to render it in an HTML template.
137 can be used to render it in an HTML template.
139
138
140 .. note:: Unicode handling
139 .. note:: Unicode handling
141
140
142 The original diffs are a byte sequence and can contain filenames
141 The original diffs are a byte sequence and can contain filenames
143 in mixed encodings. This class generally returns `unicode` objects
142 in mixed encodings. This class generally returns `unicode` objects
144 since the result is intended for presentation to the user.
143 since the result is intended for presentation to the user.
145
144
146 """
145 """
147 _chunk_re = re.compile(r'^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@(.*)')
146 _chunk_re = re.compile(r'^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@(.*)')
148 _newline_marker = re.compile(r'^\\ No newline at end of file')
147 _newline_marker = re.compile(r'^\\ No newline at end of file')
149
148
150 # used for inline highlighter word split
149 # used for inline highlighter word split
151 _token_re = re.compile(r'()(&gt;|&lt;|&amp;|\W+?)')
150 _token_re = re.compile(r'()(&gt;|&lt;|&amp;|\W+?)')
152
151
153 # collapse ranges of commits over given number
152 # collapse ranges of commits over given number
154 _collapse_commits_over = 5
153 _collapse_commits_over = 5
155
154
156 def __init__(self, diff, format='gitdiff', diff_limit=None,
155 def __init__(self, diff, format='gitdiff', diff_limit=None,
157 file_limit=None, show_full_diff=True):
156 file_limit=None, show_full_diff=True):
158 """
157 """
159 :param diff: A `Diff` object representing a diff from a vcs backend
158 :param diff: A `Diff` object representing a diff from a vcs backend
160 :param format: format of diff passed, `udiff` or `gitdiff`
159 :param format: format of diff passed, `udiff` or `gitdiff`
161 :param diff_limit: define the size of diff that is considered "big"
160 :param diff_limit: define the size of diff that is considered "big"
162 based on that parameter the cut-off will be triggered; set to None
161 based on that parameter the cut-off will be triggered; set to None
163 to show the full diff
162 to show the full diff
164 """
163 """
165 self._diff = diff
164 self._diff = diff
166 self._format = format
165 self._format = format
167 self.adds = 0
166 self.adds = 0
168 self.removes = 0
167 self.removes = 0
169 # calculate diff size
168 # calculate diff size
170 self.diff_limit = diff_limit
169 self.diff_limit = diff_limit
171 self.file_limit = file_limit
170 self.file_limit = file_limit
172 self.show_full_diff = show_full_diff
171 self.show_full_diff = show_full_diff
173 self.cur_diff_size = 0
172 self.cur_diff_size = 0
174 self.parsed = False
173 self.parsed = False
175 self.parsed_diff = []
174 self.parsed_diff = []
176
175
177 log.debug('Initialized DiffProcessor with %s mode', format)
176 log.debug('Initialized DiffProcessor with %s mode', format)
178 if format == 'gitdiff':
177 if format == 'gitdiff':
179 self.differ = self._highlight_line_difflib
178 self.differ = self._highlight_line_difflib
180 self._parser = self._parse_gitdiff
179 self._parser = self._parse_gitdiff
181 else:
180 else:
182 self.differ = self._highlight_line_udiff
181 self.differ = self._highlight_line_udiff
183 self._parser = self._new_parse_gitdiff
182 self._parser = self._new_parse_gitdiff
184
183
185 def _copy_iterator(self):
184 def _copy_iterator(self):
186 """
185 """
187 make a fresh copy of the generator; we should not iterate through
186 make a fresh copy of the generator; we should not iterate through
188 the original as it's needed for repeating operations on
187 the original as it's needed for repeating operations on
189 this instance of DiffProcessor
188 this instance of DiffProcessor
190 """
189 """
191 self.__udiff, iterator_copy = tee(self.__udiff)
190 self.__udiff, iterator_copy = tee(self.__udiff)
192 return iterator_copy
191 return iterator_copy
193
192
194 def _escaper(self, string):
193 def _escaper(self, string):
195 """
194 """
196 Escaper for diffs: escapes special chars and checks the diff limit
195 Escaper for diffs: escapes special chars and checks the diff limit
197
196
198 :param string:
197 :param string:
199 """
198 """
200 self.cur_diff_size += len(string)
199 self.cur_diff_size += len(string)
201
200
202 if not self.show_full_diff and (self.cur_diff_size > self.diff_limit):
201 if not self.show_full_diff and (self.cur_diff_size > self.diff_limit):
203 raise DiffLimitExceeded('Diff Limit Exceeded')
202 raise DiffLimitExceeded('Diff Limit Exceeded')
204
203
205 return string \
204 return string \
206 .replace('&', '&amp;')\
205 .replace('&', '&amp;')\
207 .replace('<', '&lt;')\
206 .replace('<', '&lt;')\
208 .replace('>', '&gt;')
207 .replace('>', '&gt;')
209
208
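
Note that '&' is escaped first above; otherwise the '&' inside the freshly inserted '&lt;'/'&gt;' entities would be escaped a second time. A sketch on a bare instance (a real one is built with a Diff object in __init__):

    dp = DiffProcessor.__new__(DiffProcessor)  # skip __init__ for the sketch
    dp.cur_diff_size, dp.show_full_diff, dp.diff_limit = 0, True, None
    print(dp._escaper('if a < b & c > d:'))  # if a &lt; b &amp; c &gt; d: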
210 def _line_counter(self, l):
209 def _line_counter(self, l):
211 """
210 """
212 Checks each line and bumps total adds/removes for this diff
211 Checks each line and bumps total adds/removes for this diff
213
212
214 :param l:
213 :param l:
215 """
214 """
216 if l.startswith('+') and not l.startswith('+++'):
215 if l.startswith('+') and not l.startswith('+++'):
217 self.adds += 1
216 self.adds += 1
218 elif l.startswith('-') and not l.startswith('---'):
217 elif l.startswith('-') and not l.startswith('---'):
219 self.removes += 1
218 self.removes += 1
220 return safe_unicode(l)
219 return safe_unicode(l)
221
220
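
A counting sketch for _line_counter: header lines ('+++', '---') are deliberately excluded from the totals. Illustrative only; a bare instance is used because a real DiffProcessor needs a Diff object in __init__.

    dp = DiffProcessor.__new__(DiffProcessor)  # skip __init__ for the sketch
    dp.adds = dp.removes = 0
    for raw in ['--- a/f', '+++ b/f', '+new line', '-old line', ' ctx']:
        dp._line_counter(raw)
    print(dp.adds, dp.removes)  # 1 1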
222 def _highlight_line_difflib(self, line, next_):
221 def _highlight_line_difflib(self, line, next_):
223 """
222 """
224 Highlight inline changes in both lines.
223 Highlight inline changes in both lines.
225 """
224 """
226
225
227 if line['action'] == Action.DELETE:
226 if line['action'] == Action.DELETE:
228 old, new = line, next_
227 old, new = line, next_
229 else:
228 else:
230 old, new = next_, line
229 old, new = next_, line
231
230
232 oldwords = self._token_re.split(old['line'])
231 oldwords = self._token_re.split(old['line'])
233 newwords = self._token_re.split(new['line'])
232 newwords = self._token_re.split(new['line'])
234 sequence = difflib.SequenceMatcher(None, oldwords, newwords)
233 sequence = difflib.SequenceMatcher(None, oldwords, newwords)
235
234
236 oldfragments, newfragments = [], []
235 oldfragments, newfragments = [], []
237 for tag, i1, i2, j1, j2 in sequence.get_opcodes():
236 for tag, i1, i2, j1, j2 in sequence.get_opcodes():
238 oldfrag = ''.join(oldwords[i1:i2])
237 oldfrag = ''.join(oldwords[i1:i2])
239 newfrag = ''.join(newwords[j1:j2])
238 newfrag = ''.join(newwords[j1:j2])
240 if tag != 'equal':
239 if tag != 'equal':
241 if oldfrag:
240 if oldfrag:
242 oldfrag = '<del>%s</del>' % oldfrag
241 oldfrag = '<del>%s</del>' % oldfrag
243 if newfrag:
242 if newfrag:
244 newfrag = '<ins>%s</ins>' % newfrag
243 newfrag = '<ins>%s</ins>' % newfrag
245 oldfragments.append(oldfrag)
244 oldfragments.append(oldfrag)
246 newfragments.append(newfrag)
245 newfragments.append(newfrag)
247
246
248 old['line'] = "".join(oldfragments)
247 old['line'] = "".join(oldfragments)
249 new['line'] = "".join(newfragments)
248 new['line'] = "".join(newfragments)
250
249
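
An inline-highlight sketch: the changed words get wrapped in <del>/<ins> while the common text is left alone. Bare instance again, purely illustrative.

    dp = DiffProcessor.__new__(DiffProcessor)
    old = {'action': Action.DELETE, 'line': 'color = red'}
    new = {'action': Action.ADD, 'line': 'color = blue'}
    dp._highlight_line_difflib(old, new)
    print(old['line'])  # color = <del>red</del>
    print(new['line'])  # color = <ins>blue</ins>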
251 def _highlight_line_udiff(self, line, next_):
250 def _highlight_line_udiff(self, line, next_):
252 """
251 """
253 Highlight inline changes in both lines.
252 Highlight inline changes in both lines.
254 """
253 """
255 start = 0
254 start = 0
256 limit = min(len(line['line']), len(next_['line']))
255 limit = min(len(line['line']), len(next_['line']))
257 while start < limit and line['line'][start] == next_['line'][start]:
256 while start < limit and line['line'][start] == next_['line'][start]:
258 start += 1
257 start += 1
259 end = -1
258 end = -1
260 limit -= start
259 limit -= start
261 while -end <= limit and line['line'][end] == next_['line'][end]:
260 while -end <= limit and line['line'][end] == next_['line'][end]:
262 end -= 1
261 end -= 1
263 end += 1
262 end += 1
264 if start or end:
263 if start or end:
265 def do(l):
264 def do(l):
266 last = end + len(l['line'])
265 last = end + len(l['line'])
267 if l['action'] == Action.ADD:
266 if l['action'] == Action.ADD:
268 tag = 'ins'
267 tag = 'ins'
269 else:
268 else:
270 tag = 'del'
269 tag = 'del'
271 l['line'] = '%s<%s>%s</%s>%s' % (
270 l['line'] = '%s<%s>%s</%s>%s' % (
272 l['line'][:start],
271 l['line'][:start],
273 tag,
272 tag,
274 l['line'][start:last],
273 l['line'][start:last],
275 tag,
274 tag,
276 l['line'][last:]
275 l['line'][last:]
277 )
276 )
278 do(line)
277 do(line)
279 do(next_)
278 do(next_)
280
279
281 def _clean_line(self, line, command):
280 def _clean_line(self, line, command):
282 if command in ['+', '-', ' ']:
281 if command in ['+', '-', ' ']:
283 # only modify the line if it's actually a diff line
282 # only modify the line if it's actually a diff line
284 line = line[1:]
283 line = line[1:]
285 return line
284 return line
286
285
287 def _parse_gitdiff(self, inline_diff=True):
286 def _parse_gitdiff(self, inline_diff=True):
288 _files = []
287 _files = []
289 diff_container = lambda arg: arg
288 diff_container = lambda arg: arg
290
289
291 for chunk in self._diff.chunks():
290 for chunk in self._diff.chunks():
292 head = chunk.header
291 head = chunk.header
293
292
294 diff = map(self._escaper, self.diff_splitter(chunk.diff))
293 diff = map(self._escaper, self.diff_splitter(chunk.diff))
295 raw_diff = chunk.raw
294 raw_diff = chunk.raw
296 limited_diff = False
295 limited_diff = False
297 exceeds_limit = False
296 exceeds_limit = False
298
297
299 op = None
298 op = None
300 stats = {
299 stats = {
301 'added': 0,
300 'added': 0,
302 'deleted': 0,
301 'deleted': 0,
303 'binary': False,
302 'binary': False,
304 'ops': {},
303 'ops': {},
305 }
304 }
306
305
307 if head['deleted_file_mode']:
306 if head['deleted_file_mode']:
308 op = OPS.DEL
307 op = OPS.DEL
309 stats['binary'] = True
308 stats['binary'] = True
310 stats['ops'][DEL_FILENODE] = 'deleted file'
309 stats['ops'][DEL_FILENODE] = 'deleted file'
311
310
312 elif head['new_file_mode']:
311 elif head['new_file_mode']:
313 op = OPS.ADD
312 op = OPS.ADD
314 stats['binary'] = True
313 stats['binary'] = True
315 stats['ops'][NEW_FILENODE] = 'new file %s' % head['new_file_mode']
314 stats['ops'][NEW_FILENODE] = 'new file %s' % head['new_file_mode']
316 else: # modify operation, can be copy, rename or chmod
315 else: # modify operation, can be copy, rename or chmod
317
316
318 # CHMOD
317 # CHMOD
319 if head['new_mode'] and head['old_mode']:
318 if head['new_mode'] and head['old_mode']:
320 op = OPS.MOD
319 op = OPS.MOD
321 stats['binary'] = True
320 stats['binary'] = True
322 stats['ops'][CHMOD_FILENODE] = (
321 stats['ops'][CHMOD_FILENODE] = (
323 'modified file chmod %s => %s' % (
322 'modified file chmod %s => %s' % (
324 head['old_mode'], head['new_mode']))
323 head['old_mode'], head['new_mode']))
325 # RENAME
324 # RENAME
326 if head['rename_from'] != head['rename_to']:
325 if head['rename_from'] != head['rename_to']:
327 op = OPS.MOD
326 op = OPS.MOD
328 stats['binary'] = True
327 stats['binary'] = True
329 stats['ops'][RENAMED_FILENODE] = (
328 stats['ops'][RENAMED_FILENODE] = (
330 'file renamed from %s to %s' % (
329 'file renamed from %s to %s' % (
331 head['rename_from'], head['rename_to']))
330 head['rename_from'], head['rename_to']))
332 # COPY
331 # COPY
333 if head.get('copy_from') and head.get('copy_to'):
332 if head.get('copy_from') and head.get('copy_to'):
334 op = OPS.MOD
333 op = OPS.MOD
335 stats['binary'] = True
334 stats['binary'] = True
336 stats['ops'][COPIED_FILENODE] = (
335 stats['ops'][COPIED_FILENODE] = (
337 'file copied from %s to %s' % (
336 'file copied from %s to %s' % (
338 head['copy_from'], head['copy_to']))
337 head['copy_from'], head['copy_to']))
339
338
340 # If our new parsed headers didn't match anything fallback to
339 # If our new parsed headers didn't match anything fallback to
341 # old style detection
340 # old style detection
342 if op is None:
341 if op is None:
343 if not head['a_file'] and head['b_file']:
342 if not head['a_file'] and head['b_file']:
344 op = OPS.ADD
343 op = OPS.ADD
345 stats['binary'] = True
344 stats['binary'] = True
346 stats['ops'][NEW_FILENODE] = 'new file'
345 stats['ops'][NEW_FILENODE] = 'new file'
347
346
348 elif head['a_file'] and not head['b_file']:
347 elif head['a_file'] and not head['b_file']:
349 op = OPS.DEL
348 op = OPS.DEL
350 stats['binary'] = True
349 stats['binary'] = True
351 stats['ops'][DEL_FILENODE] = 'deleted file'
350 stats['ops'][DEL_FILENODE] = 'deleted file'
352
351
353 # it's not ADD not DELETE
352 # it's not ADD not DELETE
354 if op is None:
353 if op is None:
355 op = OPS.MOD
354 op = OPS.MOD
356 stats['binary'] = True
355 stats['binary'] = True
357 stats['ops'][MOD_FILENODE] = 'modified file'
356 stats['ops'][MOD_FILENODE] = 'modified file'
358
357
359 # a real non-binary diff
358 # a real non-binary diff
360 if head['a_file'] or head['b_file']:
359 if head['a_file'] or head['b_file']:
361 try:
360 try:
362 raw_diff, chunks, _stats = self._parse_lines(diff)
361 raw_diff, chunks, _stats = self._parse_lines(diff)
363 stats['binary'] = False
362 stats['binary'] = False
364 stats['added'] = _stats[0]
363 stats['added'] = _stats[0]
365 stats['deleted'] = _stats[1]
364 stats['deleted'] = _stats[1]
366 # explicit mark that it's a modified file
365 # explicit mark that it's a modified file
367 if op == OPS.MOD:
366 if op == OPS.MOD:
368 stats['ops'][MOD_FILENODE] = 'modified file'
367 stats['ops'][MOD_FILENODE] = 'modified file'
369 exceeds_limit = len(raw_diff) > self.file_limit
368 exceeds_limit = len(raw_diff) > self.file_limit
370
369
371 # changed from _escaper function so we validate size of
370 # changed from _escaper function so we validate size of
372 # each file instead of the whole diff
371 # each file instead of the whole diff
373 # diff will hide big files but still show small ones
372 # diff will hide big files but still show small ones
374 # from my tests, big files are fairly safe to be parsed
373 # from my tests, big files are fairly safe to be parsed
375 # but the browser is the bottleneck
374 # but the browser is the bottleneck
376 if not self.show_full_diff and exceeds_limit:
375 if not self.show_full_diff and exceeds_limit:
377 raise DiffLimitExceeded('File Limit Exceeded')
376 raise DiffLimitExceeded('File Limit Exceeded')
378
377
379 except DiffLimitExceeded:
378 except DiffLimitExceeded:
380 diff_container = lambda _diff: \
379 diff_container = lambda _diff: \
381 LimitedDiffContainer(
380 LimitedDiffContainer(
382 self.diff_limit, self.cur_diff_size, _diff)
381 self.diff_limit, self.cur_diff_size, _diff)
383
382
384 exceeds_limit = len(raw_diff) > self.file_limit
383 exceeds_limit = len(raw_diff) > self.file_limit
385 limited_diff = True
384 limited_diff = True
386 chunks = []
385 chunks = []
387
386
388 else: # GIT format binary patch, or possibly empty diff
387 else: # GIT format binary patch, or possibly empty diff
389 if head['bin_patch']:
388 if head['bin_patch']:
390 # we have the operation already extracted, but we simply mark
389 # we have the operation already extracted, but we simply mark
391 # it as a diff we won't show for binary files
390 # it as a diff we won't show for binary files
392 stats['ops'][BIN_FILENODE] = 'binary diff hidden'
391 stats['ops'][BIN_FILENODE] = 'binary diff hidden'
393 chunks = []
392 chunks = []
394
393
395 if chunks and not self.show_full_diff and op == OPS.DEL:
394 if chunks and not self.show_full_diff and op == OPS.DEL:
396 # if not full diff mode show deleted file contents
395 # if not full diff mode show deleted file contents
397 # TODO: anderson: if the view is not too big, there is no way
396 # TODO: anderson: if the view is not too big, there is no way
398 # to see the content of the file
397 # to see the content of the file
399 chunks = []
398 chunks = []
400
399
401 chunks.insert(0, [{
400 chunks.insert(0, [{
402 'old_lineno': '',
401 'old_lineno': '',
403 'new_lineno': '',
402 'new_lineno': '',
404 'action': Action.CONTEXT,
403 'action': Action.CONTEXT,
405 'line': msg,
404 'line': msg,
406 } for _op, msg in stats['ops'].items()
405 } for _op, msg in stats['ops'].items()
407 if _op not in [MOD_FILENODE]])
406 if _op not in [MOD_FILENODE]])
408
407
409 _files.append({
408 _files.append({
410 'filename': safe_unicode(head['b_path']),
409 'filename': safe_unicode(head['b_path']),
411 'old_revision': head['a_blob_id'],
410 'old_revision': head['a_blob_id'],
412 'new_revision': head['b_blob_id'],
411 'new_revision': head['b_blob_id'],
413 'chunks': chunks,
412 'chunks': chunks,
414 'raw_diff': safe_unicode(raw_diff),
413 'raw_diff': safe_unicode(raw_diff),
415 'operation': op,
414 'operation': op,
416 'stats': stats,
415 'stats': stats,
417 'exceeds_limit': exceeds_limit,
416 'exceeds_limit': exceeds_limit,
418 'is_limited_diff': limited_diff,
417 'is_limited_diff': limited_diff,
419 })
418 })
420
419
421 sorter = lambda info: {OPS.ADD: 0, OPS.MOD: 1,
420 sorter = lambda info: {OPS.ADD: 0, OPS.MOD: 1,
422 OPS.DEL: 2}.get(info['operation'])
421 OPS.DEL: 2}.get(info['operation'])
423
422
424 if not inline_diff:
423 if not inline_diff:
425 return diff_container(sorted(_files, key=sorter))
424 return diff_container(sorted(_files, key=sorter))
426
425
427 # highlight inline changes
426 # highlight inline changes
428 for diff_data in _files:
427 for diff_data in _files:
429 for chunk in diff_data['chunks']:
428 for chunk in diff_data['chunks']:
430 lineiter = iter(chunk)
429 lineiter = iter(chunk)
431 try:
430 try:
432 while 1:
431 while 1:
433 line = next(lineiter)
432 line = next(lineiter)
434 if line['action'] not in (
433 if line['action'] not in (
435 Action.UNMODIFIED, Action.CONTEXT):
434 Action.UNMODIFIED, Action.CONTEXT):
436 nextline = next(lineiter)
435 nextline = next(lineiter)
437 if nextline['action'] in ['unmod', 'context'] or \
436 if nextline['action'] in ['unmod', 'context'] or \
438 nextline['action'] == line['action']:
437 nextline['action'] == line['action']:
439 continue
438 continue
440 self.differ(line, nextline)
439 self.differ(line, nextline)
441 except StopIteration:
440 except StopIteration:
442 pass
441 pass
443
442
444 return diff_container(sorted(_files, key=sorter))
443 return diff_container(sorted(_files, key=sorter))
445
444
446 def _check_large_diff(self):
445 def _check_large_diff(self):
447 if self.diff_limit:
446 if self.diff_limit:
448 log.debug('Checking if diff exceeds current diff_limit of %s', self.diff_limit)
447 log.debug('Checking if diff exceeds current diff_limit of %s', self.diff_limit)
449 if not self.show_full_diff and (self.cur_diff_size > self.diff_limit):
448 if not self.show_full_diff and (self.cur_diff_size > self.diff_limit):
450 raise DiffLimitExceeded('Diff Limit `%s` Exceeded' % self.diff_limit)
449 raise DiffLimitExceeded('Diff Limit `%s` Exceeded' % self.diff_limit)
451
450
452 # FIXME: NEWDIFFS: dan: this replaces _parse_gitdiff
451 # FIXME: NEWDIFFS: dan: this replaces _parse_gitdiff
453 def _new_parse_gitdiff(self, inline_diff=True):
452 def _new_parse_gitdiff(self, inline_diff=True):
454 _files = []
453 _files = []
455
454
456 # this can be overridden later to a LimitedDiffContainer type
455 # this can be overridden later to a LimitedDiffContainer type
457 diff_container = lambda arg: arg
456 diff_container = lambda arg: arg
458
457
459 for chunk in self._diff.chunks():
458 for chunk in self._diff.chunks():
460 head = chunk.header
459 head = chunk.header
461 log.debug('parsing diff %r', head)
460 log.debug('parsing diff %r', head)
462
461
463 raw_diff = chunk.raw
462 raw_diff = chunk.raw
464 limited_diff = False
463 limited_diff = False
465 exceeds_limit = False
464 exceeds_limit = False
466
465
467 op = None
466 op = None
468 stats = {
467 stats = {
469 'added': 0,
468 'added': 0,
470 'deleted': 0,
469 'deleted': 0,
471 'binary': False,
470 'binary': False,
472 'old_mode': None,
471 'old_mode': None,
473 'new_mode': None,
472 'new_mode': None,
474 'ops': {},
473 'ops': {},
475 }
474 }
476 if head['old_mode']:
475 if head['old_mode']:
477 stats['old_mode'] = head['old_mode']
476 stats['old_mode'] = head['old_mode']
478 if head['new_mode']:
477 if head['new_mode']:
479 stats['new_mode'] = head['new_mode']
478 stats['new_mode'] = head['new_mode']
480 if head['b_mode']:
479 if head['b_mode']:
481 stats['new_mode'] = head['b_mode']
480 stats['new_mode'] = head['b_mode']
482
481
483 # delete file
482 # delete file
484 if head['deleted_file_mode']:
483 if head['deleted_file_mode']:
485 op = OPS.DEL
484 op = OPS.DEL
486 stats['binary'] = True
485 stats['binary'] = True
487 stats['ops'][DEL_FILENODE] = 'deleted file'
486 stats['ops'][DEL_FILENODE] = 'deleted file'
488
487
489 # new file
488 # new file
490 elif head['new_file_mode']:
489 elif head['new_file_mode']:
491 op = OPS.ADD
490 op = OPS.ADD
492 stats['binary'] = True
491 stats['binary'] = True
493 stats['old_mode'] = None
492 stats['old_mode'] = None
494 stats['new_mode'] = head['new_file_mode']
493 stats['new_mode'] = head['new_file_mode']
495 stats['ops'][NEW_FILENODE] = 'new file %s' % head['new_file_mode']
494 stats['ops'][NEW_FILENODE] = 'new file %s' % head['new_file_mode']
496
495
497 # modify operation, can be copy, rename or chmod
496 # modify operation, can be copy, rename or chmod
498 else:
497 else:
499 # CHMOD
498 # CHMOD
500 if head['new_mode'] and head['old_mode']:
499 if head['new_mode'] and head['old_mode']:
501 op = OPS.MOD
500 op = OPS.MOD
502 stats['binary'] = True
501 stats['binary'] = True
503 stats['ops'][CHMOD_FILENODE] = (
502 stats['ops'][CHMOD_FILENODE] = (
504 'modified file chmod %s => %s' % (
503 'modified file chmod %s => %s' % (
505 head['old_mode'], head['new_mode']))
504 head['old_mode'], head['new_mode']))
506
505
507 # RENAME
506 # RENAME
508 if head['rename_from'] != head['rename_to']:
507 if head['rename_from'] != head['rename_to']:
509 op = OPS.MOD
508 op = OPS.MOD
510 stats['binary'] = True
509 stats['binary'] = True
511 stats['renamed'] = (head['rename_from'], head['rename_to'])
510 stats['renamed'] = (head['rename_from'], head['rename_to'])
512 stats['ops'][RENAMED_FILENODE] = (
511 stats['ops'][RENAMED_FILENODE] = (
513 'file renamed from %s to %s' % (
512 'file renamed from %s to %s' % (
514 head['rename_from'], head['rename_to']))
513 head['rename_from'], head['rename_to']))
515 # COPY
514 # COPY
516 if head.get('copy_from') and head.get('copy_to'):
515 if head.get('copy_from') and head.get('copy_to'):
517 op = OPS.MOD
516 op = OPS.MOD
518 stats['binary'] = True
517 stats['binary'] = True
519 stats['copied'] = (head['copy_from'], head['copy_to'])
518 stats['copied'] = (head['copy_from'], head['copy_to'])
520 stats['ops'][COPIED_FILENODE] = (
519 stats['ops'][COPIED_FILENODE] = (
521 'file copied from %s to %s' % (
520 'file copied from %s to %s' % (
522 head['copy_from'], head['copy_to']))
521 head['copy_from'], head['copy_to']))
523
522
524 # If our new parsed headers didn't match anything fallback to
523 # If our new parsed headers didn't match anything fallback to
525 # old style detection
524 # old style detection
526 if op is None:
525 if op is None:
527 if not head['a_file'] and head['b_file']:
526 if not head['a_file'] and head['b_file']:
528 op = OPS.ADD
527 op = OPS.ADD
529 stats['binary'] = True
528 stats['binary'] = True
530 stats['new_file'] = True
529 stats['new_file'] = True
531 stats['ops'][NEW_FILENODE] = 'new file'
530 stats['ops'][NEW_FILENODE] = 'new file'
532
531
533 elif head['a_file'] and not head['b_file']:
532 elif head['a_file'] and not head['b_file']:
534 op = OPS.DEL
533 op = OPS.DEL
535 stats['binary'] = True
534 stats['binary'] = True
536 stats['ops'][DEL_FILENODE] = 'deleted file'
535 stats['ops'][DEL_FILENODE] = 'deleted file'
537
536
538 # it's not ADD not DELETE
537 # it's not ADD not DELETE
539 if op is None:
538 if op is None:
540 op = OPS.MOD
539 op = OPS.MOD
541 stats['binary'] = True
540 stats['binary'] = True
542 stats['ops'][MOD_FILENODE] = 'modified file'
541 stats['ops'][MOD_FILENODE] = 'modified file'
543
542
544 # a real non-binary diff
543 # a real non-binary diff
545 if head['a_file'] or head['b_file']:
544 if head['a_file'] or head['b_file']:
546 # simulate splitlines, so we keep the line end part
545 # simulate splitlines, so we keep the line end part
547 diff = self.diff_splitter(chunk.diff)
546 diff = self.diff_splitter(chunk.diff)
548
547
549 # append each file to the diff size
548 # append each file to the diff size
550 raw_chunk_size = len(raw_diff)
549 raw_chunk_size = len(raw_diff)
551
550
552 exceeds_limit = raw_chunk_size > self.file_limit
551 exceeds_limit = raw_chunk_size > self.file_limit
553 self.cur_diff_size += raw_chunk_size
552 self.cur_diff_size += raw_chunk_size
554
553
555 try:
554 try:
556 # Check each file instead of the whole diff.
555 # Check each file instead of the whole diff.
557 # Diff will hide big files but still show small ones.
556 # Diff will hide big files but still show small ones.
558 # From the tests big files are fairly safe to be parsed
557 # From the tests big files are fairly safe to be parsed
559 # but the browser is the bottleneck.
558 # but the browser is the bottleneck.
560 if not self.show_full_diff and exceeds_limit:
559 if not self.show_full_diff and exceeds_limit:
561 log.debug('File `%s` exceeds current file_limit of %s',
560 log.debug('File `%s` exceeds current file_limit of %s',
562 safe_unicode(head['b_path']), self.file_limit)
561 safe_unicode(head['b_path']), self.file_limit)
563 raise DiffLimitExceeded(
562 raise DiffLimitExceeded(
564 'File Limit %s Exceeded', self.file_limit)
563 'File Limit %s Exceeded', self.file_limit)
565
564
566 self._check_large_diff()
565 self._check_large_diff()
567
566
568 raw_diff, chunks, _stats = self._new_parse_lines(diff)
567 raw_diff, chunks, _stats = self._new_parse_lines(diff)
569 stats['binary'] = False
568 stats['binary'] = False
570 stats['added'] = _stats[0]
569 stats['added'] = _stats[0]
571 stats['deleted'] = _stats[1]
570 stats['deleted'] = _stats[1]
572 # explicit mark that it's a modified file
571 # explicit mark that it's a modified file
573 if op == OPS.MOD:
572 if op == OPS.MOD:
574 stats['ops'][MOD_FILENODE] = 'modified file'
573 stats['ops'][MOD_FILENODE] = 'modified file'
575
574
576 except DiffLimitExceeded:
575 except DiffLimitExceeded:
577 diff_container = lambda _diff: \
576 diff_container = lambda _diff: \
578 LimitedDiffContainer(
577 LimitedDiffContainer(
579 self.diff_limit, self.cur_diff_size, _diff)
578 self.diff_limit, self.cur_diff_size, _diff)
580
579
581 limited_diff = True
580 limited_diff = True
582 chunks = []
581 chunks = []
583
582
584 else: # GIT format binary patch, or possibly empty diff
583 else: # GIT format binary patch, or possibly empty diff
585 if head['bin_patch']:
584 if head['bin_patch']:
586 # we have operation already extracted, but we mark simply
585 # we have operation already extracted, but we mark simply
587 # it's a diff we wont show for binary files
586 # it's a diff we wont show for binary files
588 stats['ops'][BIN_FILENODE] = 'binary diff hidden'
587 stats['ops'][BIN_FILENODE] = 'binary diff hidden'
589 chunks = []
588 chunks = []
590
589
591 # Hide content of deleted node by setting empty chunks
590 # Hide content of deleted node by setting empty chunks
592 if chunks and not self.show_full_diff and op == OPS.DEL:
591 if chunks and not self.show_full_diff and op == OPS.DEL:
593 # if not full diff mode show deleted file contents
592 # if not full diff mode show deleted file contents
594 # TODO: anderson: if the view is not too big, there is no way
593 # TODO: anderson: if the view is not too big, there is no way
595 # to see the content of the file
594 # to see the content of the file
596 chunks = []
595 chunks = []
597
596
598 chunks.insert(
597 chunks.insert(
599 0, [{'old_lineno': '',
598 0, [{'old_lineno': '',
600 'new_lineno': '',
599 'new_lineno': '',
601 'action': Action.CONTEXT,
600 'action': Action.CONTEXT,
602 'line': msg,
601 'line': msg,
603 } for _op, msg in stats['ops'].items()
602 } for _op, msg in stats['ops'].items()
604 if _op not in [MOD_FILENODE]])
603 if _op not in [MOD_FILENODE]])
605
604
606 original_filename = safe_unicode(head['a_path'])
605 original_filename = safe_unicode(head['a_path'])
607 _files.append({
606 _files.append({
608 'original_filename': original_filename,
607 'original_filename': original_filename,
609 'filename': safe_unicode(head['b_path']),
608 'filename': safe_unicode(head['b_path']),
610 'old_revision': head['a_blob_id'],
609 'old_revision': head['a_blob_id'],
611 'new_revision': head['b_blob_id'],
610 'new_revision': head['b_blob_id'],
612 'chunks': chunks,
611 'chunks': chunks,
613 'raw_diff': safe_unicode(raw_diff),
612 'raw_diff': safe_unicode(raw_diff),
614 'operation': op,
613 'operation': op,
615 'stats': stats,
614 'stats': stats,
616 'exceeds_limit': exceeds_limit,
615 'exceeds_limit': exceeds_limit,
617 'is_limited_diff': limited_diff,
616 'is_limited_diff': limited_diff,
618 })
617 })
619
618
620 sorter = lambda info: {OPS.ADD: 0, OPS.MOD: 1,
619 sorter = lambda info: {OPS.ADD: 0, OPS.MOD: 1,
621 OPS.DEL: 2}.get(info['operation'])
620 OPS.DEL: 2}.get(info['operation'])
622
621
623 return diff_container(sorted(_files, key=sorter))
622 return diff_container(sorted(_files, key=sorter))
624
623
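    # Illustrative note (not part of the original module): each entry that
    # _new_parse_gitdiff() appends to ``_files`` is a plain dict using the
    # keys built above; a sketch of one entry for a modified file, with
    # assumed example values, might look like:
    #
    #   {'filename': u'setup.py', 'original_filename': u'setup.py',
    #    'operation': OPS.MOD, 'chunks': [...], 'raw_diff': u'@@ ...',
    #    'stats': {'added': 2, 'deleted': 1, 'binary': False, 'ops': {...},
    #              'old_mode': None, 'new_mode': None},
    #    'old_revision': 'abc123', 'new_revision': 'def456',
    #    'exceeds_limit': False, 'is_limited_diff': False}
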
    # FIXME: NEWDIFFS: dan: this gets replaced by _new_parse_lines
    def _parse_lines(self, diff_iter):
        """
        Parse the diff and return data for the template.
        """

        stats = [0, 0]
        chunks = []
        raw_diff = []

        try:
            line = next(diff_iter)

            while line:
                raw_diff.append(line)
                lines = []
                chunks.append(lines)

                match = self._chunk_re.match(line)

                if not match:
                    break

                gr = match.groups()
                (old_line, old_end,
                 new_line, new_end) = [int(x or 1) for x in gr[:-1]]
                old_line -= 1
                new_line -= 1

                context = len(gr) == 5
                old_end += old_line
                new_end += new_line

                if context:
                    # skip context only if it's first line
                    if int(gr[0]) > 1:
                        lines.append({
                            'old_lineno': '...',
                            'new_lineno': '...',
                            'action': Action.CONTEXT,
                            'line': line,
                        })

                line = next(diff_iter)

                while old_line < old_end or new_line < new_end:
                    command = ' '
                    if line:
                        command = line[0]

                    affects_old = affects_new = False

                    # ignore those if we don't expect them
                    if command in '#@':
                        continue
                    elif command == '+':
                        affects_new = True
                        action = Action.ADD
                        stats[0] += 1
                    elif command == '-':
                        affects_old = True
                        action = Action.DELETE
                        stats[1] += 1
                    else:
                        affects_old = affects_new = True
                        action = Action.UNMODIFIED

                    if not self._newline_marker.match(line):
                        old_line += affects_old
                        new_line += affects_new
                        lines.append({
                            'old_lineno': affects_old and old_line or '',
                            'new_lineno': affects_new and new_line or '',
                            'action': action,
                            'line': self._clean_line(line, command)
                        })
                        raw_diff.append(line)

                    line = next(diff_iter)

                    if self._newline_marker.match(line):
                        # we need to append to lines, since this is not
                        # counted in the line specs of diff
                        lines.append({
                            'old_lineno': '...',
                            'new_lineno': '...',
                            'action': Action.CONTEXT,
                            'line': self._clean_line(line, command)
                        })

        except StopIteration:
            pass
        return ''.join(raw_diff), chunks, stats

    # FIXME: NEWDIFFS: dan: this replaces _parse_lines
    def _new_parse_lines(self, diff_iter):
        """
        Parse the diff and return data for the template.
        """

        stats = [0, 0]
        chunks = []
        raw_diff = []

        try:
            line = next(diff_iter)

            while line:
                raw_diff.append(line)
                # match header e.g. @@ -0,0 +1 @@
                match = self._chunk_re.match(line)

                if not match:
                    break

                gr = match.groups()
                (old_line, old_end,
                 new_line, new_end) = [int(x or 1) for x in gr[:-1]]

                lines = []
                hunk = {
                    'section_header': gr[-1],
                    'source_start': old_line,
                    'source_length': old_end,
                    'target_start': new_line,
                    'target_length': new_end,
                    'lines': lines,
                }
                chunks.append(hunk)

                old_line -= 1
                new_line -= 1

                context = len(gr) == 5
                old_end += old_line
                new_end += new_line

                line = next(diff_iter)

                while old_line < old_end or new_line < new_end:
                    command = ' '
                    if line:
                        command = line[0]

                    affects_old = affects_new = False

                    # ignore those if we don't expect them
                    if command in '#@':
                        continue
                    elif command == '+':
                        affects_new = True
                        action = Action.ADD
                        stats[0] += 1
                    elif command == '-':
                        affects_old = True
                        action = Action.DELETE
                        stats[1] += 1
                    else:
                        affects_old = affects_new = True
                        action = Action.UNMODIFIED

                    if not self._newline_marker.match(line):
                        old_line += affects_old
                        new_line += affects_new
                        lines.append({
                            'old_lineno': affects_old and old_line or '',
                            'new_lineno': affects_new and new_line or '',
                            'action': action,
                            'line': self._clean_line(line, command)
                        })
                        raw_diff.append(line)

                    line = next(diff_iter)

                    if self._newline_marker.match(line):
                        # we need to append to lines, since this is not
                        # counted in the line specs of diff
                        if affects_old:
                            action = Action.OLD_NO_NL
                        elif affects_new:
                            action = Action.NEW_NO_NL
                        else:
                            raise Exception('invalid context for no newline')

                        lines.append({
                            'old_lineno': None,
                            'new_lineno': None,
                            'action': action,
                            'line': self._clean_line(line, command)
                        })

        except StopIteration:
            pass

        return ''.join(raw_diff), chunks, stats

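    # Illustrative note (not part of the original module): for a hunk header
    # such as "@@ -1,3 +1,4 @@ def foo():", _new_parse_lines() appends one
    # dict of the shape below; the values shown are an assumed example, not
    # real data:
    #
    #   {'section_header': ' def foo():',
    #    'source_start': 1, 'source_length': 3,
    #    'target_start': 1, 'target_length': 4,
    #    'lines': [{'old_lineno': 1, 'new_lineno': 1,
    #               'action': Action.UNMODIFIED, 'line': '...'}, ...]}
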
    def _safe_id(self, idstring):
        r"""Make a string safe for including in an id attribute.

        The HTML spec says that id attributes 'must begin with
        a letter ([A-Za-z]) and may be followed by any number
        of letters, digits ([0-9]), hyphens ("-"), underscores
        ("_"), colons (":"), and periods (".")'. These regexps
        are slightly over-zealous, in that they remove colons
        and periods unnecessarily.

        Whitespace is transformed into underscores, and then
        anything which is not a hyphen or a character that
        matches \w (alphanumerics and underscore) is removed.

        """
        # Transform all whitespace to underscore
        idstring = re.sub(r'\s', "_", '%s' % idstring)
        # Remove everything that is not a hyphen or a member of \w
        idstring = re.sub(r'(?!-)\W', "", idstring).lower()
        return idstring

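    # Illustrative example (assumed input, not from the original source):
    #
    #   >>> processor._safe_id('My File v2.txt')
    #   'my_file_v2txt'
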
    @classmethod
    def diff_splitter(cls, string):
        r"""
        Diff split that emulates .splitlines() but works only on \n
        """
        if not string:
            return
        elif string == '\n':
-            yield u'\n'
+            yield '\n'
        else:

            has_newline = string.endswith('\n')
            elements = string.split('\n')
            if has_newline:
                # skip the last element, as it's an empty string from the
                # trailing newline
                elements = elements[:-1]

            len_elements = len(elements)

            for cnt, line in enumerate(elements, start=1):
                last_line = cnt == len_elements
                if last_line and not has_newline:
                    yield safe_unicode(line)
                else:
                    yield safe_unicode(line) + '\n'

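    # Illustrative example (assumed input, not from the original source):
    # unlike str.splitlines(), the trailing newline of every line except an
    # unterminated last one is preserved:
    #
    #   >>> list(DiffProcessor.diff_splitter('a\nb\nc'))
    #   ['a\n', 'b\n', 'c']
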
    def prepare(self, inline_diff=True):
        """
        Prepare the passed udiff for HTML rendering.

        :return: A list of dicts with diff information.
        """
        parsed = self._parser(inline_diff=inline_diff)
        self.parsed = True
        self.parsed_diff = parsed
        return parsed

    def as_raw(self, diff_lines=None):
        """
        Returns the raw diff as a byte string
        """
        return self._diff.raw

    def as_html(self, table_class='code-difftable', line_class='line',
                old_lineno_class='lineno old', new_lineno_class='lineno new',
                code_class='code', enable_comments=False, parsed_lines=None):
        """
        Return the given diff as an html table with customized css classes
        """
        # TODO(marcink): not sure how to pass in translator
        # here in an efficient way, leave the _ for proper gettext extraction
        _ = lambda s: s

        def _link_to_if(condition, label, url):
            """
            Generates a link if the condition is met, or just the label if not.
            """

            if condition:
                return '''<a href="%(url)s" class="tooltip"
                        title="%(title)s">%(label)s</a>''' % {
                    'title': _('Click to select line'),
                    'url': url,
                    'label': label
                }
            else:
                return label
        if not self.parsed:
            self.prepare()

        diff_lines = self.parsed_diff
        if parsed_lines:
            diff_lines = parsed_lines

        _html_empty = True
        _html = []
        _html.append('''<table class="%(table_class)s">\n''' % {
            'table_class': table_class
        })

        for diff in diff_lines:
            for line in diff['chunks']:
                _html_empty = False
                for change in line:
                    _html.append('''<tr class="%(lc)s %(action)s">\n''' % {
                        'lc': line_class,
                        'action': change['action']
                    })
                    anchor_old_id = ''
                    anchor_new_id = ''
                    anchor_old = "%(filename)s_o%(oldline_no)s" % {
                        'filename': self._safe_id(diff['filename']),
                        'oldline_no': change['old_lineno']
                    }
                    anchor_new = "%(filename)s_n%(oldline_no)s" % {
                        'filename': self._safe_id(diff['filename']),
                        'oldline_no': change['new_lineno']
                    }
                    cond_old = (change['old_lineno'] != '...' and
                                change['old_lineno'])
                    cond_new = (change['new_lineno'] != '...' and
                                change['new_lineno'])
                    if cond_old:
                        anchor_old_id = 'id="%s"' % anchor_old
                    if cond_new:
                        anchor_new_id = 'id="%s"' % anchor_new

                    if change['action'] != Action.CONTEXT:
                        anchor_link = True
                    else:
                        anchor_link = False

                    ###########################################################
                    # COMMENT ICONS
                    ###########################################################
                    _html.append('''\t<td class="add-comment-line"><span class="add-comment-content">''')

                    if enable_comments and change['action'] != Action.CONTEXT:
                        _html.append('''<a href="#"><span class="icon-comment-add"></span></a>''')

                    _html.append('''</span></td><td class="comment-toggle tooltip" title="Toggle Comment Thread"><i class="icon-comment"></i></td>\n''')

                    ###########################################################
                    # OLD LINE NUMBER
                    ###########################################################
                    _html.append('''\t<td %(a_id)s class="%(olc)s">''' % {
                        'a_id': anchor_old_id,
                        'olc': old_lineno_class
                    })

                    _html.append('''%(link)s''' % {
                        'link': _link_to_if(anchor_link, change['old_lineno'],
                                            '#%s' % anchor_old)
                    })
                    _html.append('''</td>\n''')
                    ###########################################################
                    # NEW LINE NUMBER
                    ###########################################################

                    _html.append('''\t<td %(a_id)s class="%(nlc)s">''' % {
                        'a_id': anchor_new_id,
                        'nlc': new_lineno_class
                    })

                    _html.append('''%(link)s''' % {
                        'link': _link_to_if(anchor_link, change['new_lineno'],
                                            '#%s' % anchor_new)
                    })
                    _html.append('''</td>\n''')
                    ###########################################################
                    # CODE
                    ###########################################################
                    code_classes = [code_class]
                    if (not enable_comments or
                            change['action'] == Action.CONTEXT):
                        code_classes.append('no-comment')
                    _html.append('\t<td class="%s">' % ' '.join(code_classes))
                    _html.append('''\n\t\t<pre>%(code)s</pre>\n''' % {
                        'code': change['line']
                    })

                    _html.append('''\t</td>''')
                    _html.append('''\n</tr>\n''')
        _html.append('''</table>''')
        if _html_empty:
            return None
        return ''.join(_html)

    def stat(self):
        """
        Returns a tuple of (added, removed) line counts for this instance
        """
        return self.adds, self.removes

    def get_context_of_line(
            self, path, diff_line=None, context_before=3, context_after=3):
        """
        Returns the context lines for the specified diff line.

        :type diff_line: :class:`DiffLineNumber`
        """
        assert self.parsed, "DiffProcessor is not initialized."

        if None not in diff_line:
            raise ValueError(
                "Cannot specify both line numbers: {}".format(diff_line))

        file_diff = self._get_file_diff(path)
        chunk, idx = self._find_chunk_line_index(file_diff, diff_line)

        first_line_to_include = max(idx - context_before, 0)
        first_line_after_context = idx + context_after + 1
        context_lines = chunk[first_line_to_include:first_line_after_context]

        line_contents = [
            _context_line(line) for line in context_lines
            if _is_diff_content(line)]
        # TODO: johbo: Interim fixup, the diff chunks drop the final newline.
        # Once they are fixed, we can drop this line here.
        if line_contents:
            line_contents[-1] = (
                line_contents[-1][0], line_contents[-1][1].rstrip('\n') + '\n')
        return line_contents

    def find_context(self, path, context, offset=0):
        """
        Finds the given `context` inside of the diff.

        Use the parameter `offset` to specify which offset the target line has
        inside of the given `context`. This way the correct diff line will be
        returned.

        :param offset: Shall be used to specify the offset of the main line
            within the given `context`.
        """
        if offset < 0 or offset >= len(context):
            raise ValueError(
                "Only positive values up to the length of the context "
                "minus one are allowed.")

        matches = []
        file_diff = self._get_file_diff(path)

        for chunk in file_diff['chunks']:
            context_iter = iter(context)
            for line_idx, line in enumerate(chunk):
                try:
                    if _context_line(line) == next(context_iter):
                        continue
                except StopIteration:
                    matches.append((line_idx, chunk))
                    context_iter = iter(context)

            # Increment position and trigger StopIteration
            # if we had a match at the end
            line_idx += 1
            try:
                next(context_iter)
            except StopIteration:
                matches.append((line_idx, chunk))

        effective_offset = len(context) - offset
        found_at_diff_lines = [
            _line_to_diff_line_number(chunk[idx - effective_offset])
            for idx, chunk in matches]

        return found_at_diff_lines

    def _get_file_diff(self, path):
        for file_diff in self.parsed_diff:
            if file_diff['filename'] == path:
                break
        else:
            raise FileNotInDiffException("File {} not in diff".format(path))
        return file_diff

    def _find_chunk_line_index(self, file_diff, diff_line):
        for chunk in file_diff['chunks']:
            for idx, line in enumerate(chunk):
                if line['old_lineno'] == diff_line.old:
                    return chunk, idx
                if line['new_lineno'] == diff_line.new:
                    return chunk, idx
        raise LineNotInDiffException(
            "The line {} is not part of the diff.".format(diff_line))


def _is_diff_content(line):
    return line['action'] in (
        Action.UNMODIFIED, Action.ADD, Action.DELETE)


def _context_line(line):
    return (line['action'], line['line'])


DiffLineNumber = collections.namedtuple('DiffLineNumber', ['old', 'new'])


def _line_to_diff_line_number(line):
    new_line_no = line['new_lineno'] or None
    old_line_no = line['old_lineno'] or None
    return DiffLineNumber(old=old_line_no, new=new_line_no)


class FileNotInDiffException(Exception):
    """
    Raised when the context for a missing file is requested.

    If you request the context for a line in a file which is not part of the
    given diff, then this exception is raised.
    """


class LineNotInDiffException(Exception):
    """
    Raised when the context for a missing line is requested.

    If you request the context for a line in a file and this line is not
    part of the given diff, then this exception is raised.
    """


class DiffLimitExceeded(Exception):
    pass


# NOTE(marcink): if diffs.mako changes, this probably
# needs a bump to the next version
CURRENT_DIFF_VERSION = 'v5'


def _cleanup_cache_file(cached_diff_file):
    # clean up the file so we don't store it "damaged"
    try:
        os.remove(cached_diff_file)
    except Exception:
        log.exception('Failed to cleanup path %s', cached_diff_file)


def _get_compression_mode(cached_diff_file):
    mode = 'bz2'
    if 'mode:plain' in cached_diff_file:
        mode = 'plain'
    elif 'mode:gzip' in cached_diff_file:
        mode = 'gzip'
    return mode


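# Illustrative example (assumed file names, not from the original source):
# the compression mode is encoded into the cache file name itself, and
# anything without a marker falls back to bz2:
#
#   >>> _get_compression_mode('diff_cache_mode:gzip')
#   'gzip'
#   >>> _get_compression_mode('diff_cache')
#   'bz2'

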
def cache_diff(cached_diff_file, diff, commits):
    compression_mode = _get_compression_mode(cached_diff_file)

    struct = {
        'version': CURRENT_DIFF_VERSION,
        'diff': diff,
        'commits': commits
    }

    start = time.time()
    try:
        if compression_mode == 'plain':
            with open(cached_diff_file, 'wb') as f:
                pickle.dump(struct, f)
        elif compression_mode == 'gzip':
            with gzip.GzipFile(cached_diff_file, 'wb') as f:
                pickle.dump(struct, f)
        else:
            with bz2.BZ2File(cached_diff_file, 'wb') as f:
                pickle.dump(struct, f)
    except Exception:
        log.warn('Failed to save cache', exc_info=True)
        _cleanup_cache_file(cached_diff_file)

    log.debug('Saved diff cache under %s in %.4fs', cached_diff_file, time.time() - start)


def load_cached_diff(cached_diff_file):
    compression_mode = _get_compression_mode(cached_diff_file)

    default_struct = {
        'version': CURRENT_DIFF_VERSION,
        'diff': None,
        'commits': None
    }

    has_cache = os.path.isfile(cached_diff_file)
    if not has_cache:
        log.debug('Reading diff cache file failed %s', cached_diff_file)
        return default_struct

    data = None

    start = time.time()
    try:
        if compression_mode == 'plain':
            with open(cached_diff_file, 'rb') as f:
                data = pickle.load(f)
        elif compression_mode == 'gzip':
            with gzip.GzipFile(cached_diff_file, 'rb') as f:
                data = pickle.load(f)
        else:
            with bz2.BZ2File(cached_diff_file, 'rb') as f:
                data = pickle.load(f)
    except Exception:
        log.warn('Failed to read diff cache file', exc_info=True)

    if not data:
        data = default_struct

    if not isinstance(data, dict):
        # old version of the data?
        data = default_struct

    # check version
    if data.get('version') != CURRENT_DIFF_VERSION:
        # purge cache
        _cleanup_cache_file(cached_diff_file)
        return default_struct

    log.debug('Loaded diff cache from %s in %.4fs', cached_diff_file, time.time() - start)

    return data


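# Illustrative round-trip sketch (assumed path and values, not from the
# original source): the mode marker in the file name picks the pickle
# container on both the write and the read side:
#
#   path = '/tmp/diff_cache_mode:gzip'
#   cache_diff(path, diff={'files': []}, commits=['abc123'])
#   data = load_cached_diff(path)
#   # data == {'version': CURRENT_DIFF_VERSION,
#   #          'diff': {'files': []}, 'commits': ['abc123']}

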
def generate_diff_cache_key(*args):
    """
    Helper to generate a cache key using arguments
    """
    def arg_mapper(input_param):
        input_param = safe_str(input_param)
        # we cannot allow '/' in arguments since it would allow
        # subdirectory usage
        input_param = input_param.replace('/', '_')
        return input_param or None  # prevent empty string arguments

    return '_'.join([
        '{}' for i in range(len(args))]).format(*map(arg_mapper, args))


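# Illustrative example (assumed arguments, not from the original source);
# note how '/' is mapped to '_' and None is stringified:
#
#   >>> generate_diff_cache_key('repo/name', 'abc', None)
#   'repo_name_abc_None'

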
def diff_cache_exist(cache_storage, *args):
    """
    Based on all generated arguments check and return a cache path
    """
    args = list(args) + ['mode:gzip']
    cache_key = generate_diff_cache_key(*args)
    cache_file_path = os.path.join(cache_storage, cache_key)
    # prevent path traversal attacks using params that contain e.g. '../../'
    if not os.path.abspath(cache_file_path).startswith(cache_storage):
        raise ValueError('Final path must be within {}'.format(cache_storage))

    return cache_file_path
@@ -1,444 +1,444 @@
# Copyright (c) Django Software Foundation and individual contributors.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification,
# are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice,
#    this list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
#
# 3. Neither the name of Django nor the names of its contributors may be used
#    to endorse or promote products derived from this software without
#    specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
# ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

"""
For definitions of the different versions of RSS, see:
http://web.archive.org/web/20110718035220/http://diveintomark.org/archives/2004/02/04/incompatible-rss
"""


import datetime
-from io import StringIO
+import io

import pytz
from six.moves.urllib import parse as urlparse

from rhodecode.lib.feedgenerator import datetime_safe
from rhodecode.lib.feedgenerator.utils import SimplerXMLGenerator, iri_to_uri, force_text


#### The following code comes from ``django.utils.feedgenerator`` ####


def rfc2822_date(date):
    # We can't use strftime() because it produces locale-dependent results, so
    # we have to map english month and day names manually
    months = ('Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun', 'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec',)
    days = ('Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun')
    # Support datetime objects older than 1900
    date = datetime_safe.new_datetime(date)
    # We do this ourselves to be timezone aware, email.Utils is not tz aware.
    dow = days[date.weekday()]
    month = months[date.month - 1]
    time_str = date.strftime('%s, %%d %s %%Y %%H:%%M:%%S ' % (dow, month))

    offset = date.utcoffset()
    # Historically, this function assumes that naive datetimes are in UTC.
    if offset is None:
        return time_str + '-0000'
    else:
        timezone = (offset.days * 24 * 60) + (offset.seconds // 60)
        hour, minute = divmod(timezone, 60)
        return time_str + '%+03d%02d' % (hour, minute)


def rfc3339_date(date):
    # Support datetime objects older than 1900
    date = datetime_safe.new_datetime(date)
    time_str = date.strftime('%Y-%m-%dT%H:%M:%S')

    offset = date.utcoffset()
    # Historically, this function assumes that naive datetimes are in UTC.
    if offset is None:
        return time_str + 'Z'
    else:
        timezone = (offset.days * 24 * 60) + (offset.seconds // 60)
        hour, minute = divmod(timezone, 60)
        return time_str + '%+03d:%02d' % (hour, minute)


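# Illustrative example (assumed values, not from the original source): a
# naive datetime is treated as UTC by both formatters:
#
#   >>> import datetime
#   >>> rfc2822_date(datetime.datetime(2020, 1, 2, 3, 4, 5))
#   'Thu, 02 Jan 2020 03:04:05 -0000'
#   >>> rfc3339_date(datetime.datetime(2020, 1, 2, 3, 4, 5))
#   '2020-01-02T03:04:05Z'

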
def get_tag_uri(url, date):
    """
    Creates a TagURI.

    See http://web.archive.org/web/20110514113830/http://diveintomark.org/archives/2004/05/28/howto-atom-id
    """
    # ``urlparse`` is bound to the module here, so the function must be
    # called explicitly
    bits = urlparse.urlparse(url)
    d = ''
    if date is not None:
        d = ',%s' % datetime_safe.new_datetime(date).strftime('%Y-%m-%d')
    return 'tag:%s%s:%s/%s' % (bits.hostname, d, bits.path, bits.fragment)


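# Illustrative example (assumed URL and date, not from the original source):
#
#   >>> import datetime
#   >>> get_tag_uri('http://example.com/posts/1#c3',
#   ...             datetime.datetime(2020, 1, 2))
#   'tag:example.com,2020-01-02:/posts/1/c3'

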
98 class SyndicationFeed(object):
98 class SyndicationFeed(object):
99 """Base class for all syndication feeds. Subclasses should provide write()"""
99 """Base class for all syndication feeds. Subclasses should provide write()"""
100
100
101 def __init__(self, title, link, description, language=None, author_email=None,
101 def __init__(self, title, link, description, language=None, author_email=None,
102 author_name=None, author_link=None, subtitle=None, categories=None,
102 author_name=None, author_link=None, subtitle=None, categories=None,
103 feed_url=None, feed_copyright=None, feed_guid=None, ttl=None, **kwargs):
103 feed_url=None, feed_copyright=None, feed_guid=None, ttl=None, **kwargs):
104 def to_unicode(s):
104 def to_unicode(s):
105 return force_text(s, strings_only=True)
105 return force_text(s, strings_only=True)
106 if categories:
106 if categories:
107 categories = [force_text(c) for c in categories]
107 categories = [force_text(c) for c in categories]
108 if ttl is not None:
108 if ttl is not None:
109 # Force ints to unicode
109 # Force ints to unicode
110 ttl = force_text(ttl)
110 ttl = force_text(ttl)
111 self.feed = {
111 self.feed = {
112 'title': to_unicode(title),
112 'title': to_unicode(title),
113 'link': iri_to_uri(link),
113 'link': iri_to_uri(link),
114 'description': to_unicode(description),
114 'description': to_unicode(description),
115 'language': to_unicode(language),
115 'language': to_unicode(language),
116 'author_email': to_unicode(author_email),
116 'author_email': to_unicode(author_email),
117 'author_name': to_unicode(author_name),
117 'author_name': to_unicode(author_name),
118 'author_link': iri_to_uri(author_link),
118 'author_link': iri_to_uri(author_link),
119 'subtitle': to_unicode(subtitle),
119 'subtitle': to_unicode(subtitle),
120 'categories': categories or (),
120 'categories': categories or (),
121 'feed_url': iri_to_uri(feed_url),
121 'feed_url': iri_to_uri(feed_url),
122 'feed_copyright': to_unicode(feed_copyright),
122 'feed_copyright': to_unicode(feed_copyright),
123 'id': feed_guid or link,
123 'id': feed_guid or link,
124 'ttl': ttl,
124 'ttl': ttl,
125 }
125 }
126 self.feed.update(kwargs)
126 self.feed.update(kwargs)
127 self.items = []
127 self.items = []
128
128
129 def add_item(self, title, link, description, author_email=None,
129 def add_item(self, title, link, description, author_email=None,
130 author_name=None, author_link=None, pubdate=None, comments=None,
130 author_name=None, author_link=None, pubdate=None, comments=None,
131 unique_id=None, unique_id_is_permalink=None, enclosure=None,
131 unique_id=None, unique_id_is_permalink=None, enclosure=None,
132 categories=(), item_copyright=None, ttl=None, updateddate=None,
132 categories=(), item_copyright=None, ttl=None, updateddate=None,
133 enclosures=None, **kwargs):
133 enclosures=None, **kwargs):
134 """
134 """
135 Adds an item to the feed. All args are expected to be Python Unicode
135 Adds an item to the feed. All args are expected to be Python Unicode
136 objects except pubdate and updateddate, which are datetime.datetime
136 objects except pubdate and updateddate, which are datetime.datetime
137 objects, and enclosures, which is an iterable of instances of the
137 objects, and enclosures, which is an iterable of instances of the
138 Enclosure class.
138 Enclosure class.
139 """
139 """
140 def to_unicode(s):
140 def to_unicode(s):
141 return force_text(s, strings_only=True)
141 return force_text(s, strings_only=True)
142 if categories:
142 if categories:
143 categories = [to_unicode(c) for c in categories]
143 categories = [to_unicode(c) for c in categories]
144 if ttl is not None:
144 if ttl is not None:
145 # Force ints to unicode
145 # Force ints to unicode
146 ttl = force_text(ttl)
146 ttl = force_text(ttl)
147 if enclosure is None:
147 if enclosure is None:
148 enclosures = [] if enclosures is None else enclosures
148 enclosures = [] if enclosures is None else enclosures
149
149
150 item = {
150 item = {
151 'title': to_unicode(title),
151 'title': to_unicode(title),
152 'link': iri_to_uri(link),
152 'link': iri_to_uri(link),
153 'description': to_unicode(description),
153 'description': to_unicode(description),
154 'author_email': to_unicode(author_email),
154 'author_email': to_unicode(author_email),
155 'author_name': to_unicode(author_name),
155 'author_name': to_unicode(author_name),
156 'author_link': iri_to_uri(author_link),
156 'author_link': iri_to_uri(author_link),
157 'pubdate': pubdate,
157 'pubdate': pubdate,
158 'updateddate': updateddate,
158 'updateddate': updateddate,
159 'comments': to_unicode(comments),
159 'comments': to_unicode(comments),
160 'unique_id': to_unicode(unique_id),
160 'unique_id': to_unicode(unique_id),
161 'unique_id_is_permalink': unique_id_is_permalink,
161 'unique_id_is_permalink': unique_id_is_permalink,
162 'enclosures': enclosures,
162 'enclosures': enclosures,
163 'categories': categories or (),
163 'categories': categories or (),
164 'item_copyright': to_unicode(item_copyright),
164 'item_copyright': to_unicode(item_copyright),
165 'ttl': ttl,
165 'ttl': ttl,
166 }
166 }
167 item.update(kwargs)
167 item.update(kwargs)
168 self.items.append(item)
168 self.items.append(item)
169
169
170 def num_items(self):
170 def num_items(self):
171 return len(self.items)
171 return len(self.items)
172
172
173 def root_attributes(self):
173 def root_attributes(self):
174 """
174 """
175 Return extra attributes to place on the root (i.e. feed/channel) element.
175 Return extra attributes to place on the root (i.e. feed/channel) element.
176 Called from write().
176 Called from write().
177 """
177 """
178 return {}
178 return {}
179
179
180 def add_root_elements(self, handler):
180 def add_root_elements(self, handler):
181 """
181 """
182 Add elements to the root (i.e. feed/channel) element. Called
182 Add elements to the root (i.e. feed/channel) element. Called
183 from write().
183 from write().
184 """
184 """
185 pass
185 pass
186
186
187 def item_attributes(self, item):
187 def item_attributes(self, item):
188 """
188 """
189 Return extra attributes to place on each item (i.e. item/entry) element.
189 Return extra attributes to place on each item (i.e. item/entry) element.
190 """
190 """
191 return {}
191 return {}
192
192
193 def add_item_elements(self, handler, item):
193 def add_item_elements(self, handler, item):
194 """
194 """
195 Add elements to each item (i.e. item/entry) element.
195 Add elements to each item (i.e. item/entry) element.
196 """
196 """
197 pass
197 pass
198
198
199 def write(self, outfile, encoding):
199 def write(self, outfile, encoding):
200 """
200 """
201 Outputs the feed in the given encoding to outfile, which is a file-like
201 Outputs the feed in the given encoding to outfile, which is a file-like
202 object. Subclasses should override this.
202 object. Subclasses should override this.
203 """
203 """
204 raise NotImplementedError('subclasses of SyndicationFeed must provide a write() method')
204 raise NotImplementedError('subclasses of SyndicationFeed must provide a write() method')
205
205
206 def writeString(self, encoding):
206 def writeString(self, encoding):
207 """
207 """
208 Returns the feed in the given encoding as a string.
208 Returns the feed in the given encoding as a string.
209 """
209 """
210 s = StringIO()
210 s = io.StringIO()
211 self.write(s, encoding)
211 self.write(s, encoding)
212 return s.getvalue()
212 return s.getvalue()
213
213
214 def latest_post_date(self):
214 def latest_post_date(self):
215 """
215 """
216 Returns the latest item's pubdate or updateddate. If no items
216 Returns the latest item's pubdate or updateddate. If no items
217 have either of these attributes, this returns the current UTC date/time.
217 have either of these attributes, this returns the current UTC date/time.
218 """
218 """
219 latest_date = None
219 latest_date = None
220 date_keys = ('updateddate', 'pubdate')
220 date_keys = ('updateddate', 'pubdate')
221
221
222 for item in self.items:
222 for item in self.items:
223 for date_key in date_keys:
223 for date_key in date_keys:
224 item_date = item.get(date_key)
224 item_date = item.get(date_key)
225 if item_date:
225 if item_date:
226 if latest_date is None or item_date > latest_date:
226 if latest_date is None or item_date > latest_date:
227 latest_date = item_date
227 latest_date = item_date
228
228
229 # datetime.now(tz=utc) is slower, as documented in django.utils.timezone.now
229 # datetime.now(tz=utc) is slower, as documented in django.utils.timezone.now
230 return latest_date or datetime.datetime.utcnow().replace(tzinfo=pytz.utc)
230 return latest_date or datetime.datetime.utcnow().replace(tzinfo=pytz.utc)
231
231
232
232
233 class Enclosure(object):
233 class Enclosure(object):
234 """Represents an RSS enclosure"""
234 """Represents an RSS enclosure"""
235 def __init__(self, url, length, mime_type):
235 def __init__(self, url, length, mime_type):
236 """All args are expected to be Python Unicode objects"""
236 """All args are expected to be Python Unicode objects"""
237 self.length, self.mime_type = length, mime_type
237 self.length, self.mime_type = length, mime_type
238 self.url = iri_to_uri(url)
238 self.url = iri_to_uri(url)
239
239
240
240
241 class RssFeed(SyndicationFeed):
241 class RssFeed(SyndicationFeed):
242 content_type = 'application/rss+xml; charset=utf-8'
242 content_type = 'application/rss+xml; charset=utf-8'
243
243
244 def write(self, outfile, encoding):
244 def write(self, outfile, encoding):
245 handler = SimplerXMLGenerator(outfile, encoding)
245 handler = SimplerXMLGenerator(outfile, encoding)
246 handler.startDocument()
246 handler.startDocument()
247 handler.startElement("rss", self.rss_attributes())
247 handler.startElement("rss", self.rss_attributes())
248 handler.startElement("channel", self.root_attributes())
248 handler.startElement("channel", self.root_attributes())
249 self.add_root_elements(handler)
249 self.add_root_elements(handler)
250 self.write_items(handler)
250 self.write_items(handler)
251 self.endChannelElement(handler)
251 self.endChannelElement(handler)
252 handler.endElement("rss")
252 handler.endElement("rss")
253
253
254 def rss_attributes(self):
254 def rss_attributes(self):
255 return {"version": self._version,
255 return {"version": self._version,
256 "xmlns:atom": "http://www.w3.org/2005/Atom"}
256 "xmlns:atom": "http://www.w3.org/2005/Atom"}
257
257
258 def write_items(self, handler):
258 def write_items(self, handler):
259 for item in self.items:
259 for item in self.items:
260 handler.startElement('item', self.item_attributes(item))
260 handler.startElement('item', self.item_attributes(item))
261 self.add_item_elements(handler, item)
261 self.add_item_elements(handler, item)
262 handler.endElement("item")
262 handler.endElement("item")
263
263
264 def add_root_elements(self, handler):
264 def add_root_elements(self, handler):
265 handler.addQuickElement("title", self.feed['title'])
265 handler.addQuickElement("title", self.feed['title'])
266 handler.addQuickElement("link", self.feed['link'])
266 handler.addQuickElement("link", self.feed['link'])
267 handler.addQuickElement("description", self.feed['description'])
267 handler.addQuickElement("description", self.feed['description'])
268 if self.feed['feed_url'] is not None:
268 if self.feed['feed_url'] is not None:
269 handler.addQuickElement("atom:link", None, {"rel": "self", "href": self.feed['feed_url']})
269 handler.addQuickElement("atom:link", None, {"rel": "self", "href": self.feed['feed_url']})
270 if self.feed['language'] is not None:
270 if self.feed['language'] is not None:
271 handler.addQuickElement("language", self.feed['language'])
271 handler.addQuickElement("language", self.feed['language'])
272 for cat in self.feed['categories']:
272 for cat in self.feed['categories']:
273 handler.addQuickElement("category", cat)
273 handler.addQuickElement("category", cat)
274 if self.feed['feed_copyright'] is not None:
274 if self.feed['feed_copyright'] is not None:
275 handler.addQuickElement("copyright", self.feed['feed_copyright'])
275 handler.addQuickElement("copyright", self.feed['feed_copyright'])
276 handler.addQuickElement("lastBuildDate", rfc2822_date(self.latest_post_date()))
276 handler.addQuickElement("lastBuildDate", rfc2822_date(self.latest_post_date()))
277 if self.feed['ttl'] is not None:
277 if self.feed['ttl'] is not None:
278 handler.addQuickElement("ttl", self.feed['ttl'])
278 handler.addQuickElement("ttl", self.feed['ttl'])
279
279
280 def endChannelElement(self, handler):
280 def endChannelElement(self, handler):
281 handler.endElement("channel")
281 handler.endElement("channel")
282
282
283
283
284 class RssUserland091Feed(RssFeed):
284 class RssUserland091Feed(RssFeed):
285 _version = "0.91"
285 _version = "0.91"
286
286
287 def add_item_elements(self, handler, item):
287 def add_item_elements(self, handler, item):
288 handler.addQuickElement("title", item['title'])
288 handler.addQuickElement("title", item['title'])
289 handler.addQuickElement("link", item['link'])
289 handler.addQuickElement("link", item['link'])
290 if item['description'] is not None:
290 if item['description'] is not None:
291 handler.addQuickElement("description", item['description'])
291 handler.addQuickElement("description", item['description'])
292
292
293
293
294 class Rss201rev2Feed(RssFeed):
294 class Rss201rev2Feed(RssFeed):
295 # Spec: http://blogs.law.harvard.edu/tech/rss
295 # Spec: http://blogs.law.harvard.edu/tech/rss
296 _version = "2.0"
296 _version = "2.0"
297
297
298 def add_item_elements(self, handler, item):
298 def add_item_elements(self, handler, item):
299 handler.addQuickElement("title", item['title'])
299 handler.addQuickElement("title", item['title'])
300 handler.addQuickElement("link", item['link'])
300 handler.addQuickElement("link", item['link'])
301 if item['description'] is not None:
301 if item['description'] is not None:
302 handler.addQuickElement("description", item['description'])
302 handler.addQuickElement("description", item['description'])
303
303
304 # Author information.
304 # Author information.
305 if item["author_name"] and item["author_email"]:
305 if item["author_name"] and item["author_email"]:
306 handler.addQuickElement("author", "%s (%s)" % (item['author_email'], item['author_name']))
306 handler.addQuickElement("author", "%s (%s)" % (item['author_email'], item['author_name']))
307 elif item["author_email"]:
307 elif item["author_email"]:
308 handler.addQuickElement("author", item["author_email"])
308 handler.addQuickElement("author", item["author_email"])
309 elif item["author_name"]:
309 elif item["author_name"]:
310 handler.addQuickElement(
310 handler.addQuickElement(
311 "dc:creator", item["author_name"], {"xmlns:dc": "http://purl.org/dc/elements/1.1/"}
311 "dc:creator", item["author_name"], {"xmlns:dc": "http://purl.org/dc/elements/1.1/"}
312 )
312 )
313
313
314 if item['pubdate'] is not None:
314 if item['pubdate'] is not None:
315 handler.addQuickElement("pubDate", rfc2822_date(item['pubdate']))
315 handler.addQuickElement("pubDate", rfc2822_date(item['pubdate']))
316 if item['comments'] is not None:
316 if item['comments'] is not None:
317 handler.addQuickElement("comments", item['comments'])
317 handler.addQuickElement("comments", item['comments'])
318 if item['unique_id'] is not None:
318 if item['unique_id'] is not None:
319 guid_attrs = {}
319 guid_attrs = {}
320 if isinstance(item.get('unique_id_is_permalink'), bool):
320 if isinstance(item.get('unique_id_is_permalink'), bool):
321 guid_attrs['isPermaLink'] = str(item['unique_id_is_permalink']).lower()
321 guid_attrs['isPermaLink'] = str(item['unique_id_is_permalink']).lower()
322 handler.addQuickElement("guid", item['unique_id'], guid_attrs)
322 handler.addQuickElement("guid", item['unique_id'], guid_attrs)
323 if item['ttl'] is not None:
323 if item['ttl'] is not None:
324 handler.addQuickElement("ttl", item['ttl'])
324 handler.addQuickElement("ttl", item['ttl'])
325
325
326 # Enclosure.
326 # Enclosure.
327 if item['enclosures']:
327 if item['enclosures']:
328 enclosures = list(item['enclosures'])
328 enclosures = list(item['enclosures'])
329 if len(enclosures) > 1:
329 if len(enclosures) > 1:
330 raise ValueError(
330 raise ValueError(
331 "RSS feed items may only have one enclosure, see "
331 "RSS feed items may only have one enclosure, see "
332 "http://www.rssboard.org/rss-profile#element-channel-item-enclosure"
332 "http://www.rssboard.org/rss-profile#element-channel-item-enclosure"
333 )
333 )
334 enclosure = enclosures[0]
334 enclosure = enclosures[0]
335 handler.addQuickElement('enclosure', '', {
335 handler.addQuickElement('enclosure', '', {
336 'url': enclosure.url,
336 'url': enclosure.url,
337 'length': enclosure.length,
337 'length': enclosure.length,
338 'type': enclosure.mime_type,
338 'type': enclosure.mime_type,
339 })
339 })
340
340
341 # Categories.
341 # Categories.
342 for cat in item['categories']:
342 for cat in item['categories']:
343 handler.addQuickElement("category", cat)
343 handler.addQuickElement("category", cat)
344
344
345
345
346 class Atom1Feed(SyndicationFeed):
346 class Atom1Feed(SyndicationFeed):
347 # Spec: https://tools.ietf.org/html/rfc4287
347 # Spec: https://tools.ietf.org/html/rfc4287
348 content_type = 'application/atom+xml; charset=utf-8'
348 content_type = 'application/atom+xml; charset=utf-8'
349 ns = "http://www.w3.org/2005/Atom"
349 ns = "http://www.w3.org/2005/Atom"
350
350
351 def write(self, outfile, encoding):
351 def write(self, outfile, encoding):
352 handler = SimplerXMLGenerator(outfile, encoding)
352 handler = SimplerXMLGenerator(outfile, encoding)
353 handler.startDocument()
353 handler.startDocument()
354 handler.startElement('feed', self.root_attributes())
354 handler.startElement('feed', self.root_attributes())
355 self.add_root_elements(handler)
355 self.add_root_elements(handler)
356 self.write_items(handler)
356 self.write_items(handler)
357 handler.endElement("feed")
357 handler.endElement("feed")
358
358
359 def root_attributes(self):
359 def root_attributes(self):
360 if self.feed['language'] is not None:
360 if self.feed['language'] is not None:
361 return {"xmlns": self.ns, "xml:lang": self.feed['language']}
361 return {"xmlns": self.ns, "xml:lang": self.feed['language']}
362 else:
362 else:
363 return {"xmlns": self.ns}
363 return {"xmlns": self.ns}
364
364
365 def add_root_elements(self, handler):
365 def add_root_elements(self, handler):
366 handler.addQuickElement("title", self.feed['title'])
366 handler.addQuickElement("title", self.feed['title'])
367 handler.addQuickElement("link", "", {"rel": "alternate", "href": self.feed['link']})
367 handler.addQuickElement("link", "", {"rel": "alternate", "href": self.feed['link']})
368 if self.feed['feed_url'] is not None:
368 if self.feed['feed_url'] is not None:
369 handler.addQuickElement("link", "", {"rel": "self", "href": self.feed['feed_url']})
369 handler.addQuickElement("link", "", {"rel": "self", "href": self.feed['feed_url']})
370 handler.addQuickElement("id", self.feed['id'])
370 handler.addQuickElement("id", self.feed['id'])
371 handler.addQuickElement("updated", rfc3339_date(self.latest_post_date()))
371 handler.addQuickElement("updated", rfc3339_date(self.latest_post_date()))
372 if self.feed['author_name'] is not None:
372 if self.feed['author_name'] is not None:
373 handler.startElement("author", {})
373 handler.startElement("author", {})
374 handler.addQuickElement("name", self.feed['author_name'])
374 handler.addQuickElement("name", self.feed['author_name'])
375 if self.feed['author_email'] is not None:
375 if self.feed['author_email'] is not None:
376 handler.addQuickElement("email", self.feed['author_email'])
376 handler.addQuickElement("email", self.feed['author_email'])
377 if self.feed['author_link'] is not None:
377 if self.feed['author_link'] is not None:
378 handler.addQuickElement("uri", self.feed['author_link'])
378 handler.addQuickElement("uri", self.feed['author_link'])
379 handler.endElement("author")
379 handler.endElement("author")
380 if self.feed['subtitle'] is not None:
380 if self.feed['subtitle'] is not None:
381 handler.addQuickElement("subtitle", self.feed['subtitle'])
381 handler.addQuickElement("subtitle", self.feed['subtitle'])
382 for cat in self.feed['categories']:
382 for cat in self.feed['categories']:
383 handler.addQuickElement("category", "", {"term": cat})
383 handler.addQuickElement("category", "", {"term": cat})
384 if self.feed['feed_copyright'] is not None:
384 if self.feed['feed_copyright'] is not None:
385 handler.addQuickElement("rights", self.feed['feed_copyright'])
385 handler.addQuickElement("rights", self.feed['feed_copyright'])
386
386
387 def write_items(self, handler):
387 def write_items(self, handler):
388 for item in self.items:
388 for item in self.items:
389 handler.startElement("entry", self.item_attributes(item))
389 handler.startElement("entry", self.item_attributes(item))
390 self.add_item_elements(handler, item)
390 self.add_item_elements(handler, item)
391 handler.endElement("entry")
391 handler.endElement("entry")
392
392
393 def add_item_elements(self, handler, item):
393 def add_item_elements(self, handler, item):
394 handler.addQuickElement("title", item['title'])
394 handler.addQuickElement("title", item['title'])
395 handler.addQuickElement("link", "", {"href": item['link'], "rel": "alternate"})
395 handler.addQuickElement("link", "", {"href": item['link'], "rel": "alternate"})
396
396
397 if item['pubdate'] is not None:
397 if item['pubdate'] is not None:
398 handler.addQuickElement('published', rfc3339_date(item['pubdate']))
398 handler.addQuickElement('published', rfc3339_date(item['pubdate']))
399
399
400 if item['updateddate'] is not None:
400 if item['updateddate'] is not None:
401 handler.addQuickElement('updated', rfc3339_date(item['updateddate']))
401 handler.addQuickElement('updated', rfc3339_date(item['updateddate']))
402
402
403 # Author information.
403 # Author information.
404 if item['author_name'] is not None:
404 if item['author_name'] is not None:
405 handler.startElement("author", {})
405 handler.startElement("author", {})
406 handler.addQuickElement("name", item['author_name'])
406 handler.addQuickElement("name", item['author_name'])
407 if item['author_email'] is not None:
407 if item['author_email'] is not None:
408 handler.addQuickElement("email", item['author_email'])
408 handler.addQuickElement("email", item['author_email'])
409 if item['author_link'] is not None:
409 if item['author_link'] is not None:
410 handler.addQuickElement("uri", item['author_link'])
410 handler.addQuickElement("uri", item['author_link'])
411 handler.endElement("author")
411 handler.endElement("author")
412
412
413 # Unique ID.
413 # Unique ID.
414 if item['unique_id'] is not None:
414 if item['unique_id'] is not None:
415 unique_id = item['unique_id']
415 unique_id = item['unique_id']
416 else:
416 else:
417 unique_id = get_tag_uri(item['link'], item['pubdate'])
417 unique_id = get_tag_uri(item['link'], item['pubdate'])
418 handler.addQuickElement("id", unique_id)
418 handler.addQuickElement("id", unique_id)
419
419
420 # Summary.
420 # Summary.
421 if item['description'] is not None:
421 if item['description'] is not None:
422 handler.addQuickElement("summary", item['description'], {"type": "html"})
422 handler.addQuickElement("summary", item['description'], {"type": "html"})
423
423
424 # Enclosures.
424 # Enclosures.
425 for enclosure in item['enclosures']:
425 for enclosure in item['enclosures']:
426 handler.addQuickElement('link', '', {
426 handler.addQuickElement('link', '', {
427 'rel': 'enclosure',
427 'rel': 'enclosure',
428 'href': enclosure.url,
428 'href': enclosure.url,
429 'length': enclosure.length,
429 'length': enclosure.length,
430 'type': enclosure.mime_type,
430 'type': enclosure.mime_type,
431 })
431 })
432
432
433 # Categories.
433 # Categories.
434 for cat in item['categories']:
434 for cat in item['categories']:
435 handler.addQuickElement("category", "", {"term": cat})
435 handler.addQuickElement("category", "", {"term": cat})
436
436
437 # Rights.
437 # Rights.
438 if item['item_copyright'] is not None:
438 if item['item_copyright'] is not None:
439 handler.addQuickElement("rights", item['item_copyright'])
439 handler.addQuickElement("rights", item['item_copyright'])
440
440
441
441
442 # This isolates the decision of what the system default is, so calling code can
442 # This isolates the decision of what the system default is, so calling code can
443 # do "feedgenerator.DefaultFeed" instead of "feedgenerator.Rss201rev2Feed".
443 # do "feedgenerator.DefaultFeed" instead of "feedgenerator.Rss201rev2Feed".
444 DefaultFeed = Rss201rev2Feed No newline at end of file
444 DefaultFeed = Rss201rev2Feed
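
For orientation, here is a minimal end-to-end sketch of the feed API above. The import path is an assumption (the module is vendored inside RhodeCode), the constructor is assumed to take the Django-style (title, link, description, ...) arguments implied by the feed dict keys, and every value is illustrative.

    import datetime
    import pytz

    # Assumed import path for the vendored module shown in this diff.
    from rhodecode.lib.feedgenerator.feedgenerator import DefaultFeed

    feed = DefaultFeed(
        title='demo repository feed',
        link='https://example.com/repo',
        description='latest changesets',
        feed_url='https://example.com/repo/feed.rss',  # becomes atom:link rel="self"
    )
    feed.add_item(
        title='commit abc123',
        link='https://example.com/repo/changeset/abc123',
        description='fix typo in docs',
        pubdate=datetime.datetime.now(tz=pytz.utc),  # also drives latest_post_date()
    )
    xml = feed.writeString('utf-8')  # RSS 2.0 output, since DefaultFeed = Rss201rev2Feed
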
@@ -1,538 +1,538 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2
2
3 # Copyright (C) 2013-2020 RhodeCode GmbH
3 # Copyright (C) 2013-2020 RhodeCode GmbH
4 #
4 #
5 # This program is free software: you can redistribute it and/or modify
5 # This program is free software: you can redistribute it and/or modify
6 # it under the terms of the GNU Affero General Public License, version 3
6 # it under the terms of the GNU Affero General Public License, version 3
7 # (only), as published by the Free Software Foundation.
7 # (only), as published by the Free Software Foundation.
8 #
8 #
9 # This program is distributed in the hope that it will be useful,
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 # GNU General Public License for more details.
12 # GNU General Public License for more details.
13 #
13 #
14 # You should have received a copy of the GNU Affero General Public License
14 # You should have received a copy of the GNU Affero General Public License
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 #
16 #
17 # This program is dual-licensed. If you wish to learn more about the
17 # This program is dual-licensed. If you wish to learn more about the
18 # RhodeCode Enterprise Edition, including its added features, Support services,
18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20
20
21
21
22 """
22 """
23 Set of hooks run by RhodeCode Enterprise
23 Set of hooks run by RhodeCode Enterprise
24 """
24 """
25
25
26 import os
26 import os
27 import logging
27 import logging
28
28
29 import rhodecode
29 import rhodecode
30 from rhodecode import events
30 from rhodecode import events
31 from rhodecode.lib import helpers as h
31 from rhodecode.lib import helpers as h
32 from rhodecode.lib import audit_logger
32 from rhodecode.lib import audit_logger
33 from rhodecode.lib.utils2 import safe_str, user_agent_normalizer
33 from rhodecode.lib.utils2 import safe_str, user_agent_normalizer
34 from rhodecode.lib.exceptions import (
34 from rhodecode.lib.exceptions import (
35 HTTPLockedRC, HTTPBranchProtected, UserCreationError)
35 HTTPLockedRC, HTTPBranchProtected, UserCreationError)
36 from rhodecode.model.db import Repository, User
36 from rhodecode.model.db import Repository, User
37 from rhodecode.lib.statsd_client import StatsdClient
37 from rhodecode.lib.statsd_client import StatsdClient
38
38
39 log = logging.getLogger(__name__)
39 log = logging.getLogger(__name__)
40
40
41
41
42 class HookResponse(object):
42 class HookResponse(object):
43 def __init__(self, status, output):
43 def __init__(self, status, output):
44 self.status = status
44 self.status = status
45 self.output = output
45 self.output = output
46
46
47 def __add__(self, other):
47 def __add__(self, other):
48 other_status = getattr(other, 'status', 0)
48 other_status = getattr(other, 'status', 0)
49 new_status = max(self.status, other_status)
49 new_status = max(self.status, other_status)
50 other_output = getattr(other, 'output', '')
50 other_output = getattr(other, 'output', '')
51 new_output = self.output + other_output
51 new_output = self.output + other_output
52
52
53 return HookResponse(new_status, new_output)
53 return HookResponse(new_status, new_output)
54
54
55 def __bool__(self):
55 def __bool__(self):
56 return self.status == 0
56 return self.status == 0
57
57
58
58
59 def is_shadow_repo(extras):
59 def is_shadow_repo(extras):
60 """
60 """
61 Returns ``True`` if this is an action executed against a shadow repository.
61 Returns ``True`` if this is an action executed against a shadow repository.
62 """
62 """
63 return extras['is_shadow_repo']
63 return extras['is_shadow_repo']
64
64
65
65
66 def _get_scm_size(alias, root_path):
66 def _get_scm_size(alias, root_path):
67
67
68 if not alias.startswith('.'):
68 if not alias.startswith('.'):
69 alias += '.'
69 alias += '.'
70
70
71 size_scm, size_root = 0, 0
71 size_scm, size_root = 0, 0
72 for path, unused_dirs, files in os.walk(safe_str(root_path)):
72 for path, unused_dirs, files in os.walk(safe_str(root_path)):
73 if path.find(alias) != -1:
73 if path.find(alias) != -1:
74 for f in files:
74 for f in files:
75 try:
75 try:
76 size_scm += os.path.getsize(os.path.join(path, f))
76 size_scm += os.path.getsize(os.path.join(path, f))
77 except OSError:
77 except OSError:
78 pass
78 pass
79 else:
79 else:
80 for f in files:
80 for f in files:
81 try:
81 try:
82 size_root += os.path.getsize(os.path.join(path, f))
82 size_root += os.path.getsize(os.path.join(path, f))
83 except OSError:
83 except OSError:
84 pass
84 pass
85
85
86 size_scm_f = h.format_byte_size_binary(size_scm)
86 size_scm_f = h.format_byte_size_binary(size_scm)
87 size_root_f = h.format_byte_size_binary(size_root)
87 size_root_f = h.format_byte_size_binary(size_root)
88 size_total_f = h.format_byte_size_binary(size_root + size_scm)
88 size_total_f = h.format_byte_size_binary(size_root + size_scm)
89
89
90 return size_scm_f, size_root_f, size_total_f
90 return size_scm_f, size_root_f, size_total_f
91
91
92
92
93 # actual hooks, called by Mercurial internally and by Git via our Python hooks
93 # actual hooks, called by Mercurial internally and by Git via our Python hooks
94 def repo_size(extras):
94 def repo_size(extras):
95 """Present the size of the repository after a push."""
95 """Present the size of the repository after a push."""
96 repo = Repository.get_by_repo_name(extras.repository)
96 repo = Repository.get_by_repo_name(extras.repository)
97 vcs_part = safe_str(u'.%s' % repo.repo_type)
97 vcs_part = safe_str('.%s' % repo.repo_type)
98 size_vcs, size_root, size_total = _get_scm_size(vcs_part,
98 size_vcs, size_root, size_total = _get_scm_size(vcs_part,
99 repo.repo_full_path)
99 repo.repo_full_path)
100 msg = ('Repository `%s` size summary %s:%s repo:%s total:%s\n'
100 msg = ('Repository `%s` size summary %s:%s repo:%s total:%s\n'
101 % (repo.repo_name, vcs_part, size_vcs, size_root, size_total))
101 % (repo.repo_name, vcs_part, size_vcs, size_root, size_total))
102 return HookResponse(0, msg)
102 return HookResponse(0, msg)
103
103
104
104
105 def pre_push(extras):
105 def pre_push(extras):
106 """
106 """
107 Hook executed before pushing code.
107 Hook executed before pushing code.
108
108
109 It bans pushing when the repository is locked.
109 It bans pushing when the repository is locked.
110 """
110 """
111
111
112 user = User.get_by_username(extras.username)
112 user = User.get_by_username(extras.username)
113 output = ''
113 output = ''
114 if extras.locked_by[0] and user.user_id != int(extras.locked_by[0]):
114 if extras.locked_by[0] and user.user_id != int(extras.locked_by[0]):
115 locked_by = User.get(extras.locked_by[0]).username
115 locked_by = User.get(extras.locked_by[0]).username
116 reason = extras.locked_by[2]
116 reason = extras.locked_by[2]
117 # this exception is interpreted in git/hg middlewares and, based
117 # this exception is interpreted in git/hg middlewares and, based
118 # on that, a proper return code is served to the client
118 # on that, a proper return code is served to the client
119 _http_ret = HTTPLockedRC(
119 _http_ret = HTTPLockedRC(
120 _locked_by_explanation(extras.repository, locked_by, reason))
120 _locked_by_explanation(extras.repository, locked_by, reason))
121 if str(_http_ret.code).startswith('2'):
121 if str(_http_ret.code).startswith('2'):
122 # 2xx codes don't raise exceptions
122 # 2xx codes don't raise exceptions
123 output = _http_ret.title
123 output = _http_ret.title
124 else:
124 else:
125 raise _http_ret
125 raise _http_ret
126
126
127 hook_response = ''
127 hook_response = ''
128 if not is_shadow_repo(extras):
128 if not is_shadow_repo(extras):
129 if extras.commit_ids and extras.check_branch_perms:
129 if extras.commit_ids and extras.check_branch_perms:
130
130
131 auth_user = user.AuthUser()
131 auth_user = user.AuthUser()
132 repo = Repository.get_by_repo_name(extras.repository)
132 repo = Repository.get_by_repo_name(extras.repository)
133 affected_branches = []
133 affected_branches = []
134 if repo.repo_type == 'hg':
134 if repo.repo_type == 'hg':
135 for entry in extras.commit_ids:
135 for entry in extras.commit_ids:
136 if entry['type'] == 'branch':
136 if entry['type'] == 'branch':
137 is_forced = bool(entry['multiple_heads'])
137 is_forced = bool(entry['multiple_heads'])
138 affected_branches.append([entry['name'], is_forced])
138 affected_branches.append([entry['name'], is_forced])
139 elif repo.repo_type == 'git':
139 elif repo.repo_type == 'git':
140 for entry in extras.commit_ids:
140 for entry in extras.commit_ids:
141 if entry['type'] == 'heads':
141 if entry['type'] == 'heads':
142 is_forced = bool(entry['pruned_sha'])
142 is_forced = bool(entry['pruned_sha'])
143 affected_branches.append([entry['name'], is_forced])
143 affected_branches.append([entry['name'], is_forced])
144
144
145 for branch_name, is_forced in affected_branches:
145 for branch_name, is_forced in affected_branches:
146
146
147 rule, branch_perm = auth_user.get_rule_and_branch_permission(
147 rule, branch_perm = auth_user.get_rule_and_branch_permission(
148 extras.repository, branch_name)
148 extras.repository, branch_name)
149 if not branch_perm:
149 if not branch_perm:
150 # no branch permission found for this branch, just keep checking
150 # no branch permission found for this branch, just keep checking
151 continue
151 continue
152
152
153 if branch_perm == 'branch.push_force':
153 if branch_perm == 'branch.push_force':
154 continue
154 continue
155 elif branch_perm == 'branch.push' and is_forced is False:
155 elif branch_perm == 'branch.push' and is_forced is False:
156 continue
156 continue
157 elif branch_perm == 'branch.push' and is_forced is True:
157 elif branch_perm == 'branch.push' and is_forced is True:
158 halt_message = 'Branch `{}` changes rejected by rule {}. ' \
158 halt_message = 'Branch `{}` changes rejected by rule {}. ' \
159 'FORCE PUSH FORBIDDEN.'.format(branch_name, rule)
159 'FORCE PUSH FORBIDDEN.'.format(branch_name, rule)
160 else:
160 else:
161 halt_message = 'Branch `{}` changes rejected by rule {}.'.format(
161 halt_message = 'Branch `{}` changes rejected by rule {}.'.format(
162 branch_name, rule)
162 branch_name, rule)
163
163
164 if halt_message:
164 if halt_message:
165 _http_ret = HTTPBranchProtected(halt_message)
165 _http_ret = HTTPBranchProtected(halt_message)
166 raise _http_ret
166 raise _http_ret
167
167
168 # Propagate to external components. This is done after checking the
168 # Propagate to external components. This is done after checking the
169 # lock, for consistent behavior.
169 # lock, for consistent behavior.
170 hook_response = pre_push_extension(
170 hook_response = pre_push_extension(
171 repo_store_path=Repository.base_path(), **extras)
171 repo_store_path=Repository.base_path(), **extras)
172 events.trigger(events.RepoPrePushEvent(
172 events.trigger(events.RepoPrePushEvent(
173 repo_name=extras.repository, extras=extras))
173 repo_name=extras.repository, extras=extras))
174
174
175 return HookResponse(0, output) + hook_response
175 return HookResponse(0, output) + hook_response
176
176
177
177
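
The branch-protection loop above reduces to a small decision table. The helper below is a hedged restatement of that logic for readability; no such function exists in this codebase.

    def is_push_allowed(branch_perm, is_forced):
        # branch.push_force: any push, forced or not, is allowed
        if branch_perm == 'branch.push_force':
            return True
        # branch.push: allowed only when the push is not forced
        if branch_perm == 'branch.push':
            return not is_forced
        # any other matching permission rejects the push; note that the loop
        # above skips branches with no matching rule at all (treated as allowed)
        return False
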
178 def pre_pull(extras):
178 def pre_pull(extras):
179 """
179 """
180 Hook executed before pulling the code.
180 Hook executed before pulling the code.
181
181
182 It bans pulling when the repository is locked.
182 It bans pulling when the repository is locked.
183 """
183 """
184
184
185 output = ''
185 output = ''
186 if extras.locked_by[0]:
186 if extras.locked_by[0]:
187 locked_by = User.get(extras.locked_by[0]).username
187 locked_by = User.get(extras.locked_by[0]).username
188 reason = extras.locked_by[2]
188 reason = extras.locked_by[2]
189 # this exception is interpreted in git/hg middlewares and, based
189 # this exception is interpreted in git/hg middlewares and, based
190 # on that, a proper return code is served to the client
190 # on that, a proper return code is served to the client
191 _http_ret = HTTPLockedRC(
191 _http_ret = HTTPLockedRC(
192 _locked_by_explanation(extras.repository, locked_by, reason))
192 _locked_by_explanation(extras.repository, locked_by, reason))
193 if str(_http_ret.code).startswith('2'):
193 if str(_http_ret.code).startswith('2'):
194 # 2xx codes don't raise exceptions
194 # 2xx codes don't raise exceptions
195 output = _http_ret.title
195 output = _http_ret.title
196 else:
196 else:
197 raise _http_ret
197 raise _http_ret
198
198
199 # Propagate to external components. This is done after checking the
199 # Propagate to external components. This is done after checking the
200 # lock, for consistent behavior.
200 # lock, for consistent behavior.
201 hook_response = ''
201 hook_response = ''
202 if not is_shadow_repo(extras):
202 if not is_shadow_repo(extras):
203 extras.hook_type = extras.hook_type or 'pre_pull'
203 extras.hook_type = extras.hook_type or 'pre_pull'
204 hook_response = pre_pull_extension(
204 hook_response = pre_pull_extension(
205 repo_store_path=Repository.base_path(), **extras)
205 repo_store_path=Repository.base_path(), **extras)
206 events.trigger(events.RepoPrePullEvent(
206 events.trigger(events.RepoPrePullEvent(
207 repo_name=extras.repository, extras=extras))
207 repo_name=extras.repository, extras=extras))
208
208
209 return HookResponse(0, output) + hook_response
209 return HookResponse(0, output) + hook_response
210
210
211
211
212 def post_pull(extras):
212 def post_pull(extras):
213 """Hook executed after a client pulls the code."""
213 """Hook executed after a client pulls the code."""
214
214
215 audit_user = audit_logger.UserWrap(
215 audit_user = audit_logger.UserWrap(
216 username=extras.username,
216 username=extras.username,
217 ip_addr=extras.ip)
217 ip_addr=extras.ip)
218 repo = audit_logger.RepoWrap(repo_name=extras.repository)
218 repo = audit_logger.RepoWrap(repo_name=extras.repository)
219 audit_logger.store(
219 audit_logger.store(
220 'user.pull', action_data={'user_agent': extras.user_agent},
220 'user.pull', action_data={'user_agent': extras.user_agent},
221 user=audit_user, repo=repo, commit=True)
221 user=audit_user, repo=repo, commit=True)
222
222
223 statsd = StatsdClient.statsd
223 statsd = StatsdClient.statsd
224 if statsd:
224 if statsd:
225 statsd.incr('rhodecode_pull_total', tags=[
225 statsd.incr('rhodecode_pull_total', tags=[
226 'user-agent:{}'.format(user_agent_normalizer(extras.user_agent)),
226 'user-agent:{}'.format(user_agent_normalizer(extras.user_agent)),
227 ])
227 ])
228 output = ''
228 output = ''
229 # make_lock is tri-state: False, True, None. We only make a lock on True
229 # make_lock is tri-state: False, True, None. We only make a lock on True
230 if extras.make_lock is True and not is_shadow_repo(extras):
230 if extras.make_lock is True and not is_shadow_repo(extras):
231 user = User.get_by_username(extras.username)
231 user = User.get_by_username(extras.username)
232 Repository.lock(Repository.get_by_repo_name(extras.repository),
232 Repository.lock(Repository.get_by_repo_name(extras.repository),
233 user.user_id,
233 user.user_id,
234 lock_reason=Repository.LOCK_PULL)
234 lock_reason=Repository.LOCK_PULL)
235 msg = 'Made lock on repo `%s`' % (extras.repository,)
235 msg = 'Made lock on repo `%s`' % (extras.repository,)
236 output += msg
236 output += msg
237
237
238 if extras.locked_by[0]:
238 if extras.locked_by[0]:
239 locked_by = User.get(extras.locked_by[0]).username
239 locked_by = User.get(extras.locked_by[0]).username
240 reason = extras.locked_by[2]
240 reason = extras.locked_by[2]
241 _http_ret = HTTPLockedRC(
241 _http_ret = HTTPLockedRC(
242 _locked_by_explanation(extras.repository, locked_by, reason))
242 _locked_by_explanation(extras.repository, locked_by, reason))
243 if str(_http_ret.code).startswith('2'):
243 if str(_http_ret.code).startswith('2'):
244 # 2xx codes don't raise exceptions
244 # 2xx codes don't raise exceptions
245 output += _http_ret.title
245 output += _http_ret.title
246
246
247 # Propagate to external components.
247 # Propagate to external components.
248 hook_response = ''
248 hook_response = ''
249 if not is_shadow_repo(extras):
249 if not is_shadow_repo(extras):
250 extras.hook_type = extras.hook_type or 'post_pull'
250 extras.hook_type = extras.hook_type or 'post_pull'
251 hook_response = post_pull_extension(
251 hook_response = post_pull_extension(
252 repo_store_path=Repository.base_path(), **extras)
252 repo_store_path=Repository.base_path(), **extras)
253 events.trigger(events.RepoPullEvent(
253 events.trigger(events.RepoPullEvent(
254 repo_name=extras.repository, extras=extras))
254 repo_name=extras.repository, extras=extras))
255
255
256 return HookResponse(0, output) + hook_response
256 return HookResponse(0, output) + hook_response
257
257
258
258
259 def post_push(extras):
259 def post_push(extras):
260 """Hook executed after a user pushes to the repository."""
260 """Hook executed after a user pushes to the repository."""
261 commit_ids = extras.commit_ids
261 commit_ids = extras.commit_ids
262
262
263 # log the push call
263 # log the push call
264 audit_user = audit_logger.UserWrap(
264 audit_user = audit_logger.UserWrap(
265 username=extras.username, ip_addr=extras.ip)
265 username=extras.username, ip_addr=extras.ip)
266 repo = audit_logger.RepoWrap(repo_name=extras.repository)
266 repo = audit_logger.RepoWrap(repo_name=extras.repository)
267 audit_logger.store(
267 audit_logger.store(
268 'user.push', action_data={
268 'user.push', action_data={
269 'user_agent': extras.user_agent,
269 'user_agent': extras.user_agent,
270 'commit_ids': commit_ids[:400]},
270 'commit_ids': commit_ids[:400]},
271 user=audit_user, repo=repo, commit=True)
271 user=audit_user, repo=repo, commit=True)
272
272
273 statsd = StatsdClient.statsd
273 statsd = StatsdClient.statsd
274 if statsd:
274 if statsd:
275 statsd.incr('rhodecode_push_total', tags=[
275 statsd.incr('rhodecode_push_total', tags=[
276 'user-agent:{}'.format(user_agent_normalizer(extras.user_agent)),
276 'user-agent:{}'.format(user_agent_normalizer(extras.user_agent)),
277 ])
277 ])
278
278
279 # Propagate to external components.
279 # Propagate to external components.
280 output = ''
280 output = ''
281 # make_lock is tri-state: False, True, None. We only release the lock on False
281 # make_lock is tri-state: False, True, None. We only release the lock on False
282 if extras.make_lock is False and not is_shadow_repo(extras):
282 if extras.make_lock is False and not is_shadow_repo(extras):
283 Repository.unlock(Repository.get_by_repo_name(extras.repository))
283 Repository.unlock(Repository.get_by_repo_name(extras.repository))
284 msg = 'Released lock on repo `{}`\n'.format(safe_str(extras.repository))
284 msg = 'Released lock on repo `{}`\n'.format(safe_str(extras.repository))
285 output += msg
285 output += msg
286
286
287 if extras.locked_by[0]:
287 if extras.locked_by[0]:
288 locked_by = User.get(extras.locked_by[0]).username
288 locked_by = User.get(extras.locked_by[0]).username
289 reason = extras.locked_by[2]
289 reason = extras.locked_by[2]
290 _http_ret = HTTPLockedRC(
290 _http_ret = HTTPLockedRC(
291 _locked_by_explanation(extras.repository, locked_by, reason))
291 _locked_by_explanation(extras.repository, locked_by, reason))
292 # TODO: johbo: if not?
292 # TODO: johbo: if not?
293 if str(_http_ret.code).startswith('2'):
293 if str(_http_ret.code).startswith('2'):
294 # 2xx codes don't raise exceptions
294 # 2xx codes don't raise exceptions
295 output += _http_ret.title
295 output += _http_ret.title
296
296
297 if extras.new_refs:
297 if extras.new_refs:
298 tmpl = '{}/{}/pull-request/new?{{ref_type}}={{ref_name}}'.format(
298 tmpl = '{}/{}/pull-request/new?{{ref_type}}={{ref_name}}'.format(
299 safe_str(extras.server_url), safe_str(extras.repository))
299 safe_str(extras.server_url), safe_str(extras.repository))
300
300
301 for branch_name in extras.new_refs['branches']:
301 for branch_name in extras.new_refs['branches']:
302 output += 'RhodeCode: open pull request link: {}\n'.format(
302 output += 'RhodeCode: open pull request link: {}\n'.format(
303 tmpl.format(ref_type='branch', ref_name=safe_str(branch_name)))
303 tmpl.format(ref_type='branch', ref_name=safe_str(branch_name)))
304
304
305 for book_name in extras.new_refs['bookmarks']:
305 for book_name in extras.new_refs['bookmarks']:
306 output += 'RhodeCode: open pull request link: {}\n'.format(
306 output += 'RhodeCode: open pull request link: {}\n'.format(
307 tmpl.format(ref_type='bookmark', ref_name=safe_str(book_name)))
307 tmpl.format(ref_type='bookmark', ref_name=safe_str(book_name)))
308
308
309 hook_response = ''
309 hook_response = ''
310 if not is_shadow_repo(extras):
310 if not is_shadow_repo(extras):
311 hook_response = post_push_extension(
311 hook_response = post_push_extension(
312 repo_store_path=Repository.base_path(),
312 repo_store_path=Repository.base_path(),
313 **extras)
313 **extras)
314 events.trigger(events.RepoPushEvent(
314 events.trigger(events.RepoPushEvent(
315 repo_name=extras.repository, pushed_commit_ids=commit_ids, extras=extras))
315 repo_name=extras.repository, pushed_commit_ids=commit_ids, extras=extras))
316
316
317 output += 'RhodeCode: push completed\n'
317 output += 'RhodeCode: push completed\n'
318 return HookResponse(0, output) + hook_response
318 return HookResponse(0, output) + hook_response
319
319
320
320
321 def _locked_by_explanation(repo_name, user_name, reason):
321 def _locked_by_explanation(repo_name, user_name, reason):
322 message = (
322 message = (
323 'Repository `%s` locked by user `%s`. Reason: `%s`'
323 'Repository `%s` locked by user `%s`. Reason: `%s`'
324 % (repo_name, user_name, reason))
324 % (repo_name, user_name, reason))
325 return message
325 return message
326
326
327
327
328 def check_allowed_create_user(user_dict, created_by, **kwargs):
328 def check_allowed_create_user(user_dict, created_by, **kwargs):
329 # pre create hooks
329 # pre create hooks
330 if pre_create_user.is_active():
330 if pre_create_user.is_active():
331 hook_result = pre_create_user(created_by=created_by, **user_dict)
331 hook_result = pre_create_user(created_by=created_by, **user_dict)
332 allowed = hook_result.status == 0
332 allowed = hook_result.status == 0
333 if not allowed:
333 if not allowed:
334 reason = hook_result.output
334 reason = hook_result.output
335 raise UserCreationError(reason)
335 raise UserCreationError(reason)
336
336
337
337
338 class ExtensionCallback(object):
338 class ExtensionCallback(object):
339 """
339 """
340 Forwards a given call to rcextensions, sanitizes keyword arguments.
340 Forwards a given call to rcextensions, sanitizes keyword arguments.
341
341
342 Checks whether there is an extension active for that hook. If there
342 Checks whether there is an extension active for that hook. If there
343 is, it forwards all `kwargs_keys` keyword arguments to the
343 is, it forwards all `kwargs_keys` keyword arguments to the
344 extension callback.
344 extension callback.
345 """
345 """
346
346
347 def __init__(self, hook_name, kwargs_keys):
347 def __init__(self, hook_name, kwargs_keys):
348 self._hook_name = hook_name
348 self._hook_name = hook_name
349 self._kwargs_keys = set(kwargs_keys)
349 self._kwargs_keys = set(kwargs_keys)
350
350
351 def __call__(self, *args, **kwargs):
351 def __call__(self, *args, **kwargs):
352 log.debug('Calling extension callback for `%s`', self._hook_name)
352 log.debug('Calling extension callback for `%s`', self._hook_name)
353 callback = self._get_callback()
353 callback = self._get_callback()
354 if not callback:
354 if not callback:
355 log.debug('extension callback `%s` not found, skipping...', self._hook_name)
355 log.debug('extension callback `%s` not found, skipping...', self._hook_name)
356 return
356 return
357
357
358 kwargs_to_pass = {}
358 kwargs_to_pass = {}
359 for key in self._kwargs_keys:
359 for key in self._kwargs_keys:
360 try:
360 try:
361 kwargs_to_pass[key] = kwargs[key]
361 kwargs_to_pass[key] = kwargs[key]
362 except KeyError:
362 except KeyError:
363 log.error('Failed to fetch %s key from given kwargs. '
363 log.error('Failed to fetch %s key from given kwargs. '
364 'Expected keys: %s', key, self._kwargs_keys)
364 'Expected keys: %s', key, self._kwargs_keys)
365 raise
365 raise
366
366
367 # backward compat for the removed api_key in old hooks. This way it works
367 # backward compat for the removed api_key in old hooks. This way it works
368 # with older rcextensions that require api_key to be present
368 # with older rcextensions that require api_key to be present
369 if self._hook_name in ['CREATE_USER_HOOK', 'DELETE_USER_HOOK']:
369 if self._hook_name in ['CREATE_USER_HOOK', 'DELETE_USER_HOOK']:
370 kwargs_to_pass['api_key'] = '_DEPRECATED_'
370 kwargs_to_pass['api_key'] = '_DEPRECATED_'
371 return callback(**kwargs_to_pass)
371 return callback(**kwargs_to_pass)
372
372
373 def is_active(self):
373 def is_active(self):
374 return hasattr(rhodecode.EXTENSIONS, self._hook_name)
374 return hasattr(rhodecode.EXTENSIONS, self._hook_name)
375
375
376 def _get_callback(self):
376 def _get_callback(self):
377 return getattr(rhodecode.EXTENSIONS, self._hook_name, None)
377 return getattr(rhodecode.EXTENSIONS, self._hook_name, None)
378
378
379
379
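
To make the forwarding contract concrete, here is a short sketch using a hypothetical hook name (MY_DEMO_HOOK is not one of the real hooks registered below):

    demo = ExtensionCallback(
        hook_name='MY_DEMO_HOOK',              # hypothetical name
        kwargs_keys=('username', 'repository'))

    # Inactive until rhodecode.EXTENSIONS grows a MY_DEMO_HOOK attribute;
    # calling an inactive callback logs a debug message and returns None.
    assert demo.is_active() is False
    demo(username='bob', repository='repo1', extra='ignored')  # returns None

    # When active, only the declared keys reach the callback, and a declared
    # key missing from the call site raises KeyError (after an error log).
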
380 pre_pull_extension = ExtensionCallback(
380 pre_pull_extension = ExtensionCallback(
381 hook_name='PRE_PULL_HOOK',
381 hook_name='PRE_PULL_HOOK',
382 kwargs_keys=(
382 kwargs_keys=(
383 'server_url', 'config', 'scm', 'username', 'ip', 'action',
383 'server_url', 'config', 'scm', 'username', 'ip', 'action',
384 'repository', 'hook_type', 'user_agent', 'repo_store_path',))
384 'repository', 'hook_type', 'user_agent', 'repo_store_path',))
385
385
386
386
387 post_pull_extension = ExtensionCallback(
387 post_pull_extension = ExtensionCallback(
388 hook_name='PULL_HOOK',
388 hook_name='PULL_HOOK',
389 kwargs_keys=(
389 kwargs_keys=(
390 'server_url', 'config', 'scm', 'username', 'ip', 'action',
390 'server_url', 'config', 'scm', 'username', 'ip', 'action',
391 'repository', 'hook_type', 'user_agent', 'repo_store_path',))
391 'repository', 'hook_type', 'user_agent', 'repo_store_path',))
392
392
393
393
394 pre_push_extension = ExtensionCallback(
394 pre_push_extension = ExtensionCallback(
395 hook_name='PRE_PUSH_HOOK',
395 hook_name='PRE_PUSH_HOOK',
396 kwargs_keys=(
396 kwargs_keys=(
397 'server_url', 'config', 'scm', 'username', 'ip', 'action',
397 'server_url', 'config', 'scm', 'username', 'ip', 'action',
398 'repository', 'repo_store_path', 'commit_ids', 'hook_type', 'user_agent',))
398 'repository', 'repo_store_path', 'commit_ids', 'hook_type', 'user_agent',))
399
399
400
400
401 post_push_extension = ExtensionCallback(
401 post_push_extension = ExtensionCallback(
402 hook_name='PUSH_HOOK',
402 hook_name='PUSH_HOOK',
403 kwargs_keys=(
403 kwargs_keys=(
404 'server_url', 'config', 'scm', 'username', 'ip', 'action',
404 'server_url', 'config', 'scm', 'username', 'ip', 'action',
405 'repository', 'repo_store_path', 'commit_ids', 'hook_type', 'user_agent',))
405 'repository', 'repo_store_path', 'commit_ids', 'hook_type', 'user_agent',))
406
406
407
407
408 pre_create_user = ExtensionCallback(
408 pre_create_user = ExtensionCallback(
409 hook_name='PRE_CREATE_USER_HOOK',
409 hook_name='PRE_CREATE_USER_HOOK',
410 kwargs_keys=(
410 kwargs_keys=(
411 'username', 'password', 'email', 'firstname', 'lastname', 'active',
411 'username', 'password', 'email', 'firstname', 'lastname', 'active',
412 'admin', 'created_by'))
412 'admin', 'created_by'))
413
413
414
414
415 create_pull_request = ExtensionCallback(
415 create_pull_request = ExtensionCallback(
416 hook_name='CREATE_PULL_REQUEST',
416 hook_name='CREATE_PULL_REQUEST',
417 kwargs_keys=(
417 kwargs_keys=(
418 'server_url', 'config', 'scm', 'username', 'ip', 'action',
418 'server_url', 'config', 'scm', 'username', 'ip', 'action',
419 'repository', 'pull_request_id', 'url', 'title', 'description',
419 'repository', 'pull_request_id', 'url', 'title', 'description',
420 'status', 'created_on', 'updated_on', 'commit_ids', 'review_status',
420 'status', 'created_on', 'updated_on', 'commit_ids', 'review_status',
421 'mergeable', 'source', 'target', 'author', 'reviewers'))
421 'mergeable', 'source', 'target', 'author', 'reviewers'))
422
422
423
423
424 merge_pull_request = ExtensionCallback(
424 merge_pull_request = ExtensionCallback(
425 hook_name='MERGE_PULL_REQUEST',
425 hook_name='MERGE_PULL_REQUEST',
426 kwargs_keys=(
426 kwargs_keys=(
427 'server_url', 'config', 'scm', 'username', 'ip', 'action',
427 'server_url', 'config', 'scm', 'username', 'ip', 'action',
428 'repository', 'pull_request_id', 'url', 'title', 'description',
428 'repository', 'pull_request_id', 'url', 'title', 'description',
429 'status', 'created_on', 'updated_on', 'commit_ids', 'review_status',
429 'status', 'created_on', 'updated_on', 'commit_ids', 'review_status',
430 'mergeable', 'source', 'target', 'author', 'reviewers'))
430 'mergeable', 'source', 'target', 'author', 'reviewers'))
431
431
432
432
433 close_pull_request = ExtensionCallback(
433 close_pull_request = ExtensionCallback(
434 hook_name='CLOSE_PULL_REQUEST',
434 hook_name='CLOSE_PULL_REQUEST',
435 kwargs_keys=(
435 kwargs_keys=(
436 'server_url', 'config', 'scm', 'username', 'ip', 'action',
436 'server_url', 'config', 'scm', 'username', 'ip', 'action',
437 'repository', 'pull_request_id', 'url', 'title', 'description',
437 'repository', 'pull_request_id', 'url', 'title', 'description',
438 'status', 'created_on', 'updated_on', 'commit_ids', 'review_status',
438 'status', 'created_on', 'updated_on', 'commit_ids', 'review_status',
439 'mergeable', 'source', 'target', 'author', 'reviewers'))
439 'mergeable', 'source', 'target', 'author', 'reviewers'))
440
440
441
441
442 review_pull_request = ExtensionCallback(
442 review_pull_request = ExtensionCallback(
443 hook_name='REVIEW_PULL_REQUEST',
443 hook_name='REVIEW_PULL_REQUEST',
444 kwargs_keys=(
444 kwargs_keys=(
445 'server_url', 'config', 'scm', 'username', 'ip', 'action',
445 'server_url', 'config', 'scm', 'username', 'ip', 'action',
446 'repository', 'pull_request_id', 'url', 'title', 'description',
446 'repository', 'pull_request_id', 'url', 'title', 'description',
447 'status', 'created_on', 'updated_on', 'commit_ids', 'review_status',
447 'status', 'created_on', 'updated_on', 'commit_ids', 'review_status',
448 'mergeable', 'source', 'target', 'author', 'reviewers'))
448 'mergeable', 'source', 'target', 'author', 'reviewers'))
449
449
450
450
451 comment_pull_request = ExtensionCallback(
451 comment_pull_request = ExtensionCallback(
452 hook_name='COMMENT_PULL_REQUEST',
452 hook_name='COMMENT_PULL_REQUEST',
453 kwargs_keys=(
453 kwargs_keys=(
454 'server_url', 'config', 'scm', 'username', 'ip', 'action',
454 'server_url', 'config', 'scm', 'username', 'ip', 'action',
455 'repository', 'pull_request_id', 'url', 'title', 'description',
455 'repository', 'pull_request_id', 'url', 'title', 'description',
456 'status', 'comment', 'created_on', 'updated_on', 'commit_ids', 'review_status',
456 'status', 'comment', 'created_on', 'updated_on', 'commit_ids', 'review_status',
457 'mergeable', 'source', 'target', 'author', 'reviewers'))
457 'mergeable', 'source', 'target', 'author', 'reviewers'))
458
458
459
459
460 comment_edit_pull_request = ExtensionCallback(
460 comment_edit_pull_request = ExtensionCallback(
461 hook_name='COMMENT_EDIT_PULL_REQUEST',
461 hook_name='COMMENT_EDIT_PULL_REQUEST',
462 kwargs_keys=(
462 kwargs_keys=(
463 'server_url', 'config', 'scm', 'username', 'ip', 'action',
463 'server_url', 'config', 'scm', 'username', 'ip', 'action',
464 'repository', 'pull_request_id', 'url', 'title', 'description',
464 'repository', 'pull_request_id', 'url', 'title', 'description',
465 'status', 'comment', 'created_on', 'updated_on', 'commit_ids', 'review_status',
465 'status', 'comment', 'created_on', 'updated_on', 'commit_ids', 'review_status',
466 'mergeable', 'source', 'target', 'author', 'reviewers'))
466 'mergeable', 'source', 'target', 'author', 'reviewers'))
467
467
468
468
469 update_pull_request = ExtensionCallback(
469 update_pull_request = ExtensionCallback(
470 hook_name='UPDATE_PULL_REQUEST',
470 hook_name='UPDATE_PULL_REQUEST',
471 kwargs_keys=(
471 kwargs_keys=(
472 'server_url', 'config', 'scm', 'username', 'ip', 'action',
472 'server_url', 'config', 'scm', 'username', 'ip', 'action',
473 'repository', 'pull_request_id', 'url', 'title', 'description',
473 'repository', 'pull_request_id', 'url', 'title', 'description',
474 'status', 'created_on', 'updated_on', 'commit_ids', 'review_status',
474 'status', 'created_on', 'updated_on', 'commit_ids', 'review_status',
475 'mergeable', 'source', 'target', 'author', 'reviewers'))
475 'mergeable', 'source', 'target', 'author', 'reviewers'))
476
476
477
477
478 create_user = ExtensionCallback(
478 create_user = ExtensionCallback(
479 hook_name='CREATE_USER_HOOK',
479 hook_name='CREATE_USER_HOOK',
480 kwargs_keys=(
480 kwargs_keys=(
481 'username', 'full_name_or_username', 'full_contact', 'user_id',
481 'username', 'full_name_or_username', 'full_contact', 'user_id',
482 'name', 'firstname', 'short_contact', 'admin', 'lastname',
482 'name', 'firstname', 'short_contact', 'admin', 'lastname',
483 'ip_addresses', 'extern_type', 'extern_name',
483 'ip_addresses', 'extern_type', 'extern_name',
484 'email', 'api_keys', 'last_login',
484 'email', 'api_keys', 'last_login',
485 'full_name', 'active', 'password', 'emails',
485 'full_name', 'active', 'password', 'emails',
486 'inherit_default_permissions', 'created_by', 'created_on'))
486 'inherit_default_permissions', 'created_by', 'created_on'))
487
487
488
488
489 delete_user = ExtensionCallback(
489 delete_user = ExtensionCallback(
490 hook_name='DELETE_USER_HOOK',
490 hook_name='DELETE_USER_HOOK',
491 kwargs_keys=(
491 kwargs_keys=(
492 'username', 'full_name_or_username', 'full_contact', 'user_id',
492 'username', 'full_name_or_username', 'full_contact', 'user_id',
493 'name', 'firstname', 'short_contact', 'admin', 'lastname',
493 'name', 'firstname', 'short_contact', 'admin', 'lastname',
494 'ip_addresses',
494 'ip_addresses',
495 'email', 'last_login',
495 'email', 'last_login',
496 'full_name', 'active', 'password', 'emails',
496 'full_name', 'active', 'password', 'emails',
497 'inherit_default_permissions', 'deleted_by'))
497 'inherit_default_permissions', 'deleted_by'))
498
498
499
499
500 create_repository = ExtensionCallback(
500 create_repository = ExtensionCallback(
501 hook_name='CREATE_REPO_HOOK',
501 hook_name='CREATE_REPO_HOOK',
502 kwargs_keys=(
502 kwargs_keys=(
503 'repo_name', 'repo_type', 'description', 'private', 'created_on',
503 'repo_name', 'repo_type', 'description', 'private', 'created_on',
504 'enable_downloads', 'repo_id', 'user_id', 'enable_statistics',
504 'enable_downloads', 'repo_id', 'user_id', 'enable_statistics',
505 'clone_uri', 'fork_id', 'group_id', 'created_by'))
505 'clone_uri', 'fork_id', 'group_id', 'created_by'))
506
506
507
507
508 delete_repository = ExtensionCallback(
508 delete_repository = ExtensionCallback(
509 hook_name='DELETE_REPO_HOOK',
509 hook_name='DELETE_REPO_HOOK',
510 kwargs_keys=(
510 kwargs_keys=(
511 'repo_name', 'repo_type', 'description', 'private', 'created_on',
511 'repo_name', 'repo_type', 'description', 'private', 'created_on',
512 'enable_downloads', 'repo_id', 'user_id', 'enable_statistics',
512 'enable_downloads', 'repo_id', 'user_id', 'enable_statistics',
513 'clone_uri', 'fork_id', 'group_id', 'deleted_by', 'deleted_on'))
513 'clone_uri', 'fork_id', 'group_id', 'deleted_by', 'deleted_on'))
514
514
515
515
516 comment_commit_repository = ExtensionCallback(
516 comment_commit_repository = ExtensionCallback(
517 hook_name='COMMENT_COMMIT_REPO_HOOK',
517 hook_name='COMMENT_COMMIT_REPO_HOOK',
518 kwargs_keys=(
518 kwargs_keys=(
519 'repo_name', 'repo_type', 'description', 'private', 'created_on',
519 'repo_name', 'repo_type', 'description', 'private', 'created_on',
520 'enable_downloads', 'repo_id', 'user_id', 'enable_statistics',
520 'enable_downloads', 'repo_id', 'user_id', 'enable_statistics',
521 'clone_uri', 'fork_id', 'group_id',
521 'clone_uri', 'fork_id', 'group_id',
522 'repository', 'created_by', 'comment', 'commit'))
522 'repository', 'created_by', 'comment', 'commit'))
523
523
524 comment_edit_commit_repository = ExtensionCallback(
524 comment_edit_commit_repository = ExtensionCallback(
525 hook_name='COMMENT_EDIT_COMMIT_REPO_HOOK',
525 hook_name='COMMENT_EDIT_COMMIT_REPO_HOOK',
526 kwargs_keys=(
526 kwargs_keys=(
527 'repo_name', 'repo_type', 'description', 'private', 'created_on',
527 'repo_name', 'repo_type', 'description', 'private', 'created_on',
528 'enable_downloads', 'repo_id', 'user_id', 'enable_statistics',
528 'enable_downloads', 'repo_id', 'user_id', 'enable_statistics',
529 'clone_uri', 'fork_id', 'group_id',
529 'clone_uri', 'fork_id', 'group_id',
530 'repository', 'created_by', 'comment', 'commit'))
530 'repository', 'created_by', 'comment', 'commit'))
531
531
532
532
533 create_repository_group = ExtensionCallback(
533 create_repository_group = ExtensionCallback(
534 hook_name='CREATE_REPO_GROUP_HOOK',
534 hook_name='CREATE_REPO_GROUP_HOOK',
535 kwargs_keys=(
535 kwargs_keys=(
536 'group_name', 'group_parent_id', 'group_description',
536 'group_name', 'group_parent_id', 'group_description',
537 'group_id', 'user_id', 'created_by', 'created_on',
537 'group_id', 'user_id', 'created_by', 'created_on',
538 'enable_locking'))
538 'enable_locking'))
@@ -1,187 +1,187 @@
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2
2
3 # Copyright (C) 2010-2020 RhodeCode GmbH
3 # Copyright (C) 2010-2020 RhodeCode GmbH
4 #
4 #
5 # This program is free software: you can redistribute it and/or modify
5 # This program is free software: you can redistribute it and/or modify
6 # it under the terms of the GNU Affero General Public License, version 3
6 # it under the terms of the GNU Affero General Public License, version 3
7 # (only), as published by the Free Software Foundation.
7 # (only), as published by the Free Software Foundation.
8 #
8 #
9 # This program is distributed in the hope that it will be useful,
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 # GNU General Public License for more details.
12 # GNU General Public License for more details.
13 #
13 #
14 # You should have received a copy of the GNU Affero General Public License
14 # You should have received a copy of the GNU Affero General Public License
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 #
16 #
17 # This program is dual-licensed. If you wish to learn more about the
17 # This program is dual-licensed. If you wish to learn more about the
18 # RhodeCode Enterprise Edition, including its added features, Support services,
18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20
20
21 import sys
21 import sys
22 import logging
22 import logging
23
23
24
24
25 BLACK, RED, GREEN, YELLOW, BLUE, MAGENTA, CYAN, WHITE = range(30, 38)
25 BLACK, RED, GREEN, YELLOW, BLUE, MAGENTA, CYAN, WHITE = list(range(30, 38))
26
26
27 # Sequences
27 # Sequences
28 RESET_SEQ = "\033[0m"
28 RESET_SEQ = "\033[0m"
29 COLOR_SEQ = "\033[0;%dm"
29 COLOR_SEQ = "\033[0;%dm"
30 BOLD_SEQ = "\033[1m"
30 BOLD_SEQ = "\033[1m"
31
31
32 COLORS = {
32 COLORS = {
33 'CRITICAL': MAGENTA,
33 'CRITICAL': MAGENTA,
34 'ERROR': RED,
34 'ERROR': RED,
35 'WARNING': CYAN,
35 'WARNING': CYAN,
36 'INFO': GREEN,
36 'INFO': GREEN,
37 'DEBUG': BLUE,
37 'DEBUG': BLUE,
38 'SQL': YELLOW
38 'SQL': YELLOW
39 }
39 }
40
40
41
41
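As a quick aside, a minimal sketch of how these sequences compose into a colored line; the `colorize` helper below is hypothetical and only illustrates the mechanism:

def colorize(levelname, message):
    # e.g. COLORS['WARNING'] is CYAN (36), so start becomes "\033[0;36m"
    start = COLOR_SEQ % COLORS[levelname]
    return ''.join([start, message, RESET_SEQ])

print(colorize('WARNING', 'something looks off'))  # renders in cyan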
42 def _inject_req_id(record, with_prefix=True):
42 def _inject_req_id(record, with_prefix=True):
43 from pyramid.threadlocal import get_current_request
43 from pyramid.threadlocal import get_current_request
44 dummy = '00000000-0000-0000-0000-000000000000'
44 dummy = '00000000-0000-0000-0000-000000000000'
45 req_id = None
45 req_id = None
46
46
47 req = get_current_request()
47 req = get_current_request()
48 if req:
48 if req:
49 req_id = getattr(req, 'req_id', None)
49 req_id = getattr(req, 'req_id', None)
50 if with_prefix:
50 if with_prefix:
51 req_id = 'req_id:%-36s' % (req_id or dummy)
51 req_id = 'req_id:%-36s' % (req_id or dummy)
52 else:
52 else:
53 req_id = (req_id or dummy)
53 req_id = (req_id or dummy)
54 record.req_id = req_id
54 record.req_id = req_id
55
55
56
56
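A hedged usage sketch: once _inject_req_id() has run for a record, any logging format string can reference the injected attribute; the pattern below is hypothetical, not taken from a shipped .ini:

import logging
# %(req_id)s resolves to 'req_id:<uuid>' padded to 36 chars (or the dummy id)
fmt = logging.Formatter('%(asctime)s %(req_id)s %(levelname)s %(message)s')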
57 def _add_log_to_debug_bucket(formatted_record):
57 def _add_log_to_debug_bucket(formatted_record):
58 from pyramid.threadlocal import get_current_request
58 from pyramid.threadlocal import get_current_request
59 req = get_current_request()
59 req = get_current_request()
60 if req:
60 if req:
61 req.req_id_bucket.append(formatted_record)
61 req.req_id_bucket.append(formatted_record)
62
62
63
63
64 def one_space_trim(s):
64 def one_space_trim(s):
65 if s.find(" ") == -1:
65 if s.find(" ") == -1:
66 return s
66 return s
67 else:
67 else:
68 s = s.replace(' ', ' ')
68 s = s.replace(' ', ' ')
69 return one_space_trim(s)
69 return one_space_trim(s)
70
70
71
71
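A quick illustration of the recursive collapse performed by one_space_trim, assuming the definition above:

assert one_space_trim('SELECT   *    FROM  users') == 'SELECT * FROM users'
assert one_space_trim('already single-spaced') == 'already single-spaced'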
72 def format_sql(sql):
72 def format_sql(sql):
73 sql = sql.replace('\n', '')
73 sql = sql.replace('\n', '')
74 sql = one_space_trim(sql)
74 sql = one_space_trim(sql)
75 sql = sql\
75 sql = sql\
76 .replace(',', ',\n\t')\
76 .replace(',', ',\n\t')\
77 .replace('SELECT', '\n\tSELECT \n\t')\
77 .replace('SELECT', '\n\tSELECT \n\t')\
78 .replace('UPDATE', '\n\tUPDATE \n\t')\
78 .replace('UPDATE', '\n\tUPDATE \n\t')\
79 .replace('DELETE', '\n\tDELETE \n\t')\
79 .replace('DELETE', '\n\tDELETE \n\t')\
80 .replace('FROM', '\n\tFROM')\
80 .replace('FROM', '\n\tFROM')\
81 .replace('ORDER BY', '\n\tORDER BY')\
81 .replace('ORDER BY', '\n\tORDER BY')\
82 .replace('LIMIT', '\n\tLIMIT')\
82 .replace('LIMIT', '\n\tLIMIT')\
83 .replace('WHERE', '\n\tWHERE')\
83 .replace('WHERE', '\n\tWHERE')\
84 .replace('AND', '\n\tAND')\
84 .replace('AND', '\n\tAND')\
85 .replace('LEFT', '\n\tLEFT')\
85 .replace('LEFT', '\n\tLEFT')\
86 .replace('INNER', '\n\tINNER')\
86 .replace('INNER', '\n\tINNER')\
87 .replace('INSERT', '\n\tINSERT')\
87 .replace('INSERT', '\n\tINSERT')\
88 .replace('DELETE', '\n\tDELETE')
88 .replace('DELETE', '\n\tDELETE')
89 return sql
89 return sql
90
90
91
91
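For orientation, a sketch of what format_sql yields for a simple statement; the exact whitespace follows only the literal replacements above:

print(format_sql("SELECT id, name FROM users WHERE id = 1"))
# roughly (leading tabs collapsed here):
#   SELECT
#    id,
#    name
#   FROM users
#   WHERE id = 1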
92 class ExceptionAwareFormatter(logging.Formatter):
92 class ExceptionAwareFormatter(logging.Formatter):
93 """
93 """
94 Extended logging formatter which prints out remote tracebacks.
94 Extended logging formatter which prints out remote tracebacks.
95 """
95 """
96
96
97 def formatException(self, ei):
97 def formatException(self, ei):
98 ex_type, ex_value, ex_tb = ei
98 ex_type, ex_value, ex_tb = ei
99
99
100 local_tb = logging.Formatter.formatException(self, ei)
100 local_tb = logging.Formatter.formatException(self, ei)
101 if hasattr(ex_value, '_vcs_server_traceback'):
101 if hasattr(ex_value, '_vcs_server_traceback'):
102
102
103 def formatRemoteTraceback(remote_tb_lines):
103 def formatRemoteTraceback(remote_tb_lines):
104 result = ["\n +--- This exception occurred remotely on VCSServer - Remote traceback:\n\n"]
104 result = ["\n +--- This exception occurred remotely on VCSServer - Remote traceback:\n\n"]
105 result.append(remote_tb_lines)
105 result.append(remote_tb_lines)
106 result.append("\n +--- End of remote traceback\n")
106 result.append("\n +--- End of remote traceback\n")
107 return result
107 return result
108
108
109 try:
109 try:
110 if ex_type is not None and ex_value is None and ex_tb is None:
110 if ex_type is not None and ex_value is None and ex_tb is None:
111 # possible old (3.x) call syntax where caller is only
111 # possible old (3.x) call syntax where caller is only
112 # providing exception object
112 # providing exception object
113 if type(ex_type) is not type:
113 if type(ex_type) is not type:
114 raise TypeError(
114 raise TypeError(
115 "invalid argument: ex_type should be an exception "
115 "invalid argument: ex_type should be an exception "
116 "type, or just supply no arguments at all")
116 "type, or just supply no arguments at all")
117 if ex_type is None and ex_tb is None:
117 if ex_type is None and ex_tb is None:
118 ex_type, ex_value, ex_tb = sys.exc_info()
118 ex_type, ex_value, ex_tb = sys.exc_info()
119
119
120 remote_tb = getattr(ex_value, "_vcs_server_traceback", None)
120 remote_tb = getattr(ex_value, "_vcs_server_traceback", None)
121
121
122 if remote_tb:
122 if remote_tb:
123 remote_tb = formatRemoteTraceback(remote_tb)
123 remote_tb = formatRemoteTraceback(remote_tb)
124 return local_tb + ''.join(remote_tb)
124 return local_tb + ''.join(remote_tb)
125 finally:
125 finally:
126 # clean up cycle to traceback, to allow proper GC
126 # clean up cycle to traceback, to allow proper GC
127 del ex_type, ex_value, ex_tb
127 del ex_type, ex_value, ex_tb
128
128
129 return local_tb
129 return local_tb
130
130
131
131
132 class RequestTrackingFormatter(ExceptionAwareFormatter):
132 class RequestTrackingFormatter(ExceptionAwareFormatter):
133 def format(self, record):
133 def format(self, record):
134 _inject_req_id(record)
134 _inject_req_id(record)
135 def_record = logging.Formatter.format(self, record)
135 def_record = logging.Formatter.format(self, record)
136 _add_log_to_debug_bucket(def_record)
136 _add_log_to_debug_bucket(def_record)
137 return def_record
137 return def_record
138
138
139
139
140 class ColorFormatter(ExceptionAwareFormatter):
140 class ColorFormatter(ExceptionAwareFormatter):
141
141
142 def format(self, record):
142 def format(self, record):
143 """
143 """
144 Wraps the formatted record in ANSI color sequences chosen from the COLORS map by the record's levelname
144 Wraps the formatted record in ANSI color sequences chosen from the COLORS map by the record's levelname
145 """
145 """
146 def_record = super(ColorFormatter, self).format(record)
146 def_record = super(ColorFormatter, self).format(record)
147
147
148 levelname = record.levelname
148 levelname = record.levelname
149 start = COLOR_SEQ % (COLORS[levelname])
149 start = COLOR_SEQ % (COLORS[levelname])
150 end = RESET_SEQ
150 end = RESET_SEQ
151
151
152 colored_record = ''.join([start, def_record, end])
152 colored_record = ''.join([start, def_record, end])
153 return colored_record
153 return colored_record
154
154
155
155
156 class ColorRequestTrackingFormatter(RequestTrackingFormatter):
156 class ColorRequestTrackingFormatter(RequestTrackingFormatter):
157
157
158 def format(self, record):
158 def format(self, record):
159 """
159 """
160 Wraps the formatted record in ANSI color sequences chosen from the COLORS map by the record's levelname
160 Wraps the formatted record in ANSI color sequences chosen from the COLORS map by the record's levelname
161 """
161 """
162 def_record = super(ColorRequestTrackingFormatter, self).format(record)
162 def_record = super(ColorRequestTrackingFormatter, self).format(record)
163
163
164 levelname = record.levelname
164 levelname = record.levelname
165 start = COLOR_SEQ % (COLORS[levelname])
165 start = COLOR_SEQ % (COLORS[levelname])
166 end = RESET_SEQ
166 end = RESET_SEQ
167
167
168 colored_record = ''.join([start, def_record, end])
168 colored_record = ''.join([start, def_record, end])
169 return colored_record
169 return colored_record
170
170
171
171
172 class ColorFormatterSql(logging.Formatter):
172 class ColorFormatterSql(logging.Formatter):
173
173
174 def format(self, record):
174 def format(self, record):
175 """
175 """
176 Wraps the formatted, SQL-prettified record in the ANSI color sequence of the COLORS['SQL'] entry
176 Wraps the formatted, SQL-prettified record in the ANSI color sequence of the COLORS['SQL'] entry
177 """
177 """
178
178
179 start = COLOR_SEQ % (COLORS['SQL'])
179 start = COLOR_SEQ % (COLORS['SQL'])
180 def_record = format_sql(logging.Formatter.format(self, record))
180 def_record = format_sql(logging.Formatter.format(self, record))
181 end = RESET_SEQ
181 end = RESET_SEQ
182
182
183 colored_record = ''.join([start, def_record, end])
183 colored_record = ''.join([start, def_record, end])
184 return colored_record
184 return colored_record
185
185
186 # marcink: needs to stay with this name for backward .ini compatibility
186 # marcink: needs to stay with this name for backward .ini compatibility
187 Pyro4AwareFormatter = ExceptionAwareFormatter
187 Pyro4AwareFormatter = ExceptionAwareFormatter
@@ -1,580 +1,580 @@
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2
2
3 # Copyright (C) 2011-2020 RhodeCode GmbH
3 # Copyright (C) 2011-2020 RhodeCode GmbH
4 #
4 #
5 # This program is free software: you can redistribute it and/or modify
5 # This program is free software: you can redistribute it and/or modify
6 # it under the terms of the GNU Affero General Public License, version 3
6 # it under the terms of the GNU Affero General Public License, version 3
7 # (only), as published by the Free Software Foundation.
7 # (only), as published by the Free Software Foundation.
8 #
8 #
9 # This program is distributed in the hope that it will be useful,
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 # GNU General Public License for more details.
12 # GNU General Public License for more details.
13 #
13 #
14 # You should have received a copy of the GNU Affero General Public License
14 # You should have received a copy of the GNU Affero General Public License
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 #
16 #
17 # This program is dual-licensed. If you wish to learn more about the
17 # This program is dual-licensed. If you wish to learn more about the
18 # RhodeCode Enterprise Edition, including its added features, Support services,
18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20
20
21
21
22 """
22 """
23 Renderer for markup languages with the ability to parse using rst or markdown
23 Renderer for markup languages with the ability to parse using rst or markdown
24 """
24 """
25
25
26 import re
26 import re
27 import os
27 import os
28 import lxml
28 import lxml
29 import logging
29 import logging
30 import urllib.parse
30 import urllib.parse
31 import bleach
31 import bleach
32
32
33 from mako.lookup import TemplateLookup
33 from mako.lookup import TemplateLookup
34 from mako.template import Template as MakoTemplate
34 from mako.template import Template as MakoTemplate
35
35
36 from docutils.core import publish_parts
36 from docutils.core import publish_parts
37 from docutils.parsers.rst import directives
37 from docutils.parsers.rst import directives
38 from docutils import writers
38 from docutils import writers
39 from docutils.writers import html4css1
39 from docutils.writers import html4css1
40 import markdown
40 import markdown
41
41
42 from rhodecode.lib.markdown_ext import GithubFlavoredMarkdownExtension
42 from rhodecode.lib.markdown_ext import GithubFlavoredMarkdownExtension
43 from rhodecode.lib.utils2 import (safe_unicode, md5_safe, MENTIONS_REGEX)
43 from rhodecode.lib.utils2 import (safe_unicode, md5_safe, MENTIONS_REGEX)
44
44
45 log = logging.getLogger(__name__)
45 log = logging.getLogger(__name__)
46
46
47 # default renderer used to generate automated comments
47 # default renderer used to generate automated comments
48 DEFAULT_COMMENTS_RENDERER = 'rst'
48 DEFAULT_COMMENTS_RENDERER = 'rst'
49
49
50 try:
50 try:
51 from lxml.html import fromstring
51 from lxml.html import fromstring
52 from lxml.html import tostring
52 from lxml.html import tostring
53 except ImportError:
53 except ImportError:
54 log.exception('Failed to import lxml')
54 log.exception('Failed to import lxml')
55 fromstring = None
55 fromstring = None
56 tostring = None
56 tostring = None
57
57
58
58
59 class CustomHTMLTranslator(writers.html4css1.HTMLTranslator):
59 class CustomHTMLTranslator(writers.html4css1.HTMLTranslator):
60 """
60 """
61 Custom HTML Translator used for sandboxing potential
61 Custom HTML Translator used for sandboxing potential
62 JS injections in ref links
62 JS injections in ref links
63 """
63 """
64 def visit_literal_block(self, node):
64 def visit_literal_block(self, node):
65 self.body.append(self.starttag(node, 'pre', CLASS='codehilite literal-block'))
65 self.body.append(self.starttag(node, 'pre', CLASS='codehilite literal-block'))
66
66
67 def visit_reference(self, node):
67 def visit_reference(self, node):
68 if 'refuri' in node.attributes:
68 if 'refuri' in node.attributes:
69 refuri = node['refuri']
69 refuri = node['refuri']
70 if ':' in refuri:
70 if ':' in refuri:
71 prefix, link = refuri.lstrip().split(':', 1)
71 prefix, link = refuri.lstrip().split(':', 1)
72 prefix = prefix or ''
72 prefix = prefix or ''
73
73
74 if prefix.lower() == 'javascript':
74 if prefix.lower() == 'javascript':
75 # we don't allow javascript type of refs...
75 # we don't allow javascript type of refs...
76 node['refuri'] = 'javascript:alert("SandBoxedJavascript")'
76 node['refuri'] = 'javascript:alert("SandBoxedJavascript")'
77
77
78 # old style class requires this...
78 # old style class requires this...
79 return html4css1.HTMLTranslator.visit_reference(self, node)
79 return html4css1.HTMLTranslator.visit_reference(self, node)
80
80
81
81
82 class RhodeCodeWriter(writers.html4css1.Writer):
82 class RhodeCodeWriter(writers.html4css1.Writer):
83 def __init__(self):
83 def __init__(self):
84 writers.Writer.__init__(self)
84 writers.Writer.__init__(self)
85 self.translator_class = CustomHTMLTranslator
85 self.translator_class = CustomHTMLTranslator
86
86
87
87
88 def relative_links(html_source, server_paths):
88 def relative_links(html_source, server_paths):
89 if not html_source:
89 if not html_source:
90 return html_source
90 return html_source
91
91
92 if not (fromstring and tostring):
92 if not (fromstring and tostring):
93 return html_source
93 return html_source
94
94
95 try:
95 try:
96 doc = lxml.html.fromstring(html_source)
96 doc = lxml.html.fromstring(html_source)
97 except Exception:
97 except Exception:
98 return html_source
98 return html_source
99
99
100 for el in doc.cssselect('img, video'):
100 for el in doc.cssselect('img, video'):
101 src = el.attrib.get('src')
101 src = el.attrib.get('src')
102 if src:
102 if src:
103 el.attrib['src'] = relative_path(src, server_paths['raw'])
103 el.attrib['src'] = relative_path(src, server_paths['raw'])
104
104
105 for el in doc.cssselect('a:not(.gfm)'):
105 for el in doc.cssselect('a:not(.gfm)'):
106 src = el.attrib.get('href')
106 src = el.attrib.get('href')
107 if src:
107 if src:
108 raw_mode = el.attrib['href'].endswith('?raw=1')
108 raw_mode = el.attrib['href'].endswith('?raw=1')
109 if raw_mode:
109 if raw_mode:
110 el.attrib['href'] = relative_path(src, server_paths['raw'])
110 el.attrib['href'] = relative_path(src, server_paths['raw'])
111 else:
111 else:
112 el.attrib['href'] = relative_path(src, server_paths['standard'])
112 el.attrib['href'] = relative_path(src, server_paths['standard'])
113
113
114 return lxml.html.tostring(doc)
114 return lxml.html.tostring(doc)
115
115
116
116
117 def relative_path(path, request_path, is_repo_file=None):
117 def relative_path(path, request_path, is_repo_file=None):
118 """
118 """
119 relative link support, path is a rel path, and request_path is current
119 relative link support, path is a rel path, and request_path is current
120 server path (not absolute)
120 server path (not absolute)
121
121
122 e.g.
122 e.g.
123
123
124 path = '../logo.png'
124 path = '../logo.png'
125 request_path = '/repo/files/path/file.md'
125 request_path = '/repo/files/path/file.md'
126 produces: '/repo/files/logo.png'
126 produces: '/repo/files/logo.png'
127 """
127 """
128 # TODO(marcink): unicode/str support ?
128 # TODO(marcink): unicode/str support ?
129 # maybe=> safe_unicode(urllib.quote(safe_str(final_path), '/:'))
129 # maybe=> safe_unicode(urllib.quote(safe_str(final_path), '/:'))
130
130
131 def dummy_check(p):
131 def dummy_check(p):
132 return True # assume default is a valid file path
132 return True # assume default is a valid file path
133
133
134 is_repo_file = is_repo_file or dummy_check
134 is_repo_file = is_repo_file or dummy_check
135 if not path:
135 if not path:
136 return request_path
136 return request_path
137
137
138 path = safe_unicode(path)
138 path = safe_unicode(path)
139 request_path = safe_unicode(request_path)
139 request_path = safe_unicode(request_path)
140
140
141 if path.startswith((u'data:', u'javascript:', u'#', u':')):
141 if path.startswith(('data:', 'javascript:', '#', ':')):
142 # skip data, anchor, invalid links
142 # skip data, anchor, invalid links
143 return path
143 return path
144
144
145 is_absolute = bool(urllib.parse.urlparse(path).netloc)
145 is_absolute = bool(urllib.parse.urlparse(path).netloc)
146 if is_absolute:
146 if is_absolute:
147 return path
147 return path
148
148
149 if not request_path:
149 if not request_path:
150 return path
150 return path
151
151
152 if path.startswith(u'/'):
152 if path.startswith('/'):
153 path = path[1:]
153 path = path[1:]
154
154
155 if path.startswith(u'./'):
155 if path.startswith('./'):
156 path = path[2:]
156 path = path[2:]
157
157
158 parts = request_path.split('/')
158 parts = request_path.split('/')
159 # compute how deep we need to traverse the request_path
159 # compute how deep we need to traverse the request_path
160 depth = 0
160 depth = 0
161
161
162 if is_repo_file(request_path):
162 if is_repo_file(request_path):
163 # if request path is a VALID file, we use a relative path with
163 # if request path is a VALID file, we use a relative path with
164 # one level up
164 # one level up
165 depth += 1
165 depth += 1
166
166
167 while path.startswith(u'../'):
167 while path.startswith('../'):
168 depth += 1
168 depth += 1
169 path = path[3:]
169 path = path[3:]
170
170
171 if depth > 0:
171 if depth > 0:
172 parts = parts[:-depth]
172 parts = parts[:-depth]
173
173
174 parts.append(path)
174 parts.append(path)
175 final_path = u'/'.join(parts).lstrip(u'/')
175 final_path = '/'.join(parts).lstrip('/')
176
176
177 return u'/' + final_path
177 return '/' + final_path
178
178
179
179
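A usage sketch mirroring the docstring example, assuming the definition above (the default is_repo_file check treats the request path as a valid file, which adds one level of traversal):

assert relative_path('../logo.png', '/repo/files/path/file.md') == '/repo/files/logo.png'
assert relative_path('#section', '/repo/files/file.md') == '#section'  # anchors pass through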
180 _cached_markdown_renderer = None
180 _cached_markdown_renderer = None
181
181
182
182
183 def get_markdown_renderer(extensions, output_format):
183 def get_markdown_renderer(extensions, output_format):
184 global _cached_markdown_renderer
184 global _cached_markdown_renderer
185
185
186 if _cached_markdown_renderer is None:
186 if _cached_markdown_renderer is None:
187 _cached_markdown_renderer = markdown.Markdown(
187 _cached_markdown_renderer = markdown.Markdown(
188 extensions=extensions,
188 extensions=extensions,
189 enable_attributes=False, output_format=output_format)
189 enable_attributes=False, output_format=output_format)
190 return _cached_markdown_renderer
190 return _cached_markdown_renderer
191
191
192
192
193 _cached_markdown_renderer_flavored = None
193 _cached_markdown_renderer_flavored = None
194
194
195
195
196 def get_markdown_renderer_flavored(extensions, output_format):
196 def get_markdown_renderer_flavored(extensions, output_format):
197 global _cached_markdown_renderer_flavored
197 global _cached_markdown_renderer_flavored
198
198
199 if _cached_markdown_renderer_flavored is None:
199 if _cached_markdown_renderer_flavored is None:
200 _cached_markdown_renderer_flavored = markdown.Markdown(
200 _cached_markdown_renderer_flavored = markdown.Markdown(
201 extensions=extensions + [GithubFlavoredMarkdownExtension()],
201 extensions=extensions + [GithubFlavoredMarkdownExtension()],
202 enable_attributes=False, output_format=output_format)
202 enable_attributes=False, output_format=output_format)
203 return _cached_markdown_renderer_flavored
203 return _cached_markdown_renderer_flavored
204
204
205
205
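Worth noting about the design: both caches are keyed on nothing, so whichever arguments arrive first win; a small sketch assuming the functions above:

r1 = get_markdown_renderer(['markdown.extensions.extra'], 'html4')
r2 = get_markdown_renderer([], 'xhtml')  # different args, but...
assert r1 is r2                          # ...the first cached instance is returned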
206 class MarkupRenderer(object):
206 class MarkupRenderer(object):
207 RESTRUCTUREDTEXT_DISALLOWED_DIRECTIVES = ['include', 'meta', 'raw']
207 RESTRUCTUREDTEXT_DISALLOWED_DIRECTIVES = ['include', 'meta', 'raw']
208
208
209 MARKDOWN_PAT = re.compile(r'\.(md|mkdn?|mdown|markdown)$', re.IGNORECASE)
209 MARKDOWN_PAT = re.compile(r'\.(md|mkdn?|mdown|markdown)$', re.IGNORECASE)
210 RST_PAT = re.compile(r'\.re?st$', re.IGNORECASE)
210 RST_PAT = re.compile(r'\.re?st$', re.IGNORECASE)
211 JUPYTER_PAT = re.compile(r'\.(ipynb)$', re.IGNORECASE)
211 JUPYTER_PAT = re.compile(r'\.(ipynb)$', re.IGNORECASE)
212 PLAIN_PAT = re.compile(r'^readme$', re.IGNORECASE)
212 PLAIN_PAT = re.compile(r'^readme$', re.IGNORECASE)
213
213
214 URL_PAT = re.compile(r'(http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]'
214 URL_PAT = re.compile(r'(http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]'
215 r'|[!*\(\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+)')
215 r'|[!*\(\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+)')
216
216
217 MENTION_PAT = re.compile(MENTIONS_REGEX)
217 MENTION_PAT = re.compile(MENTIONS_REGEX)
218
218
219 extensions = ['markdown.extensions.codehilite', 'markdown.extensions.extra',
219 extensions = ['markdown.extensions.codehilite', 'markdown.extensions.extra',
220 'markdown.extensions.def_list', 'markdown.extensions.sane_lists']
220 'markdown.extensions.def_list', 'markdown.extensions.sane_lists']
221
221
222 output_format = 'html4'
222 output_format = 'html4'
223
223
224 # extensions together with weights; lower sorts first, which lets us control
224 # extensions together with weights; lower sorts first, which lets us control
225 # how extensions are matched against readme names.
225 # how extensions are matched against readme names.
226 PLAIN_EXTS = [
226 PLAIN_EXTS = [
227 # prefer no extension
227 # prefer no extension
228 ('', 0), # special case that renders README names without an extension
228 ('', 0), # special case that renders README names without an extension
229 ('.text', 2), ('.TEXT', 2),
229 ('.text', 2), ('.TEXT', 2),
230 ('.txt', 3), ('.TXT', 3)
230 ('.txt', 3), ('.TXT', 3)
231 ]
231 ]
232
232
233 RST_EXTS = [
233 RST_EXTS = [
234 ('.rst', 1), ('.rest', 1),
234 ('.rst', 1), ('.rest', 1),
235 ('.RST', 2), ('.REST', 2)
235 ('.RST', 2), ('.REST', 2)
236 ]
236 ]
237
237
238 MARKDOWN_EXTS = [
238 MARKDOWN_EXTS = [
239 ('.md', 1), ('.MD', 1),
239 ('.md', 1), ('.MD', 1),
240 ('.mkdn', 2), ('.MKDN', 2),
240 ('.mkdn', 2), ('.MKDN', 2),
241 ('.mdown', 3), ('.MDOWN', 3),
241 ('.mdown', 3), ('.MDOWN', 3),
242 ('.markdown', 4), ('.MARKDOWN', 4)
242 ('.markdown', 4), ('.MARKDOWN', 4)
243 ]
243 ]
244
244
245 def _detect_renderer(self, source, filename=None):
245 def _detect_renderer(self, source, filename=None):
246 """
246 """
247 runs detection of what renderer should be used for generating html
247 runs detection of what renderer should be used for generating html
248 from a markup language
248 from a markup language
249
249
250 filename can also explicitly be a renderer name
250 filename can also explicitly be a renderer name
251
251
252 :param source:
252 :param source:
253 :param filename:
253 :param filename:
254 """
254 """
255
255
256 if MarkupRenderer.MARKDOWN_PAT.findall(filename):
256 if MarkupRenderer.MARKDOWN_PAT.findall(filename):
257 detected_renderer = 'markdown'
257 detected_renderer = 'markdown'
258 elif MarkupRenderer.RST_PAT.findall(filename):
258 elif MarkupRenderer.RST_PAT.findall(filename):
259 detected_renderer = 'rst'
259 detected_renderer = 'rst'
260 elif MarkupRenderer.JUPYTER_PAT.findall(filename):
260 elif MarkupRenderer.JUPYTER_PAT.findall(filename):
261 detected_renderer = 'jupyter'
261 detected_renderer = 'jupyter'
262 elif MarkupRenderer.PLAIN_PAT.findall(filename):
262 elif MarkupRenderer.PLAIN_PAT.findall(filename):
263 detected_renderer = 'plain'
263 detected_renderer = 'plain'
264 else:
264 else:
265 detected_renderer = 'plain'
265 detected_renderer = 'plain'
266
266
267 return getattr(MarkupRenderer, detected_renderer)
267 return getattr(MarkupRenderer, detected_renderer)
268
268
269 @classmethod
269 @classmethod
270 def bleach_clean(cls, text):
270 def bleach_clean(cls, text):
271 from .bleach_whitelist import markdown_attrs, markdown_tags
271 from .bleach_whitelist import markdown_attrs, markdown_tags
272 allowed_tags = markdown_tags
272 allowed_tags = markdown_tags
273 allowed_attrs = markdown_attrs
273 allowed_attrs = markdown_attrs
274
274
275 try:
275 try:
276 return bleach.clean(text, tags=allowed_tags, attributes=allowed_attrs)
276 return bleach.clean(text, tags=allowed_tags, attributes=allowed_attrs)
277 except Exception:
277 except Exception:
278 return 'UNPARSEABLE TEXT'
278 return 'UNPARSEABLE TEXT'
279
279
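A hedged sketch of the sanitization; the exact whitelist lives in bleach_whitelist (defined elsewhere), so the output shape is indicative only:

cleaned = MarkupRenderer.bleach_clean('<p>hi</p><script>alert(1)</script>')
# <p> is typical whitelisted markdown output; <script> is not, so bleach
# escapes/strips it rather than letting it through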
280 @classmethod
280 @classmethod
281 def renderer_from_filename(cls, filename, exclude):
281 def renderer_from_filename(cls, filename, exclude):
282 """
282 """
283 Detect renderer markdown/rst from filename and optionally use exclude
283 Detect renderer markdown/rst from filename and optionally use exclude
284 list to remove some options. This is mostly used in helpers.
284 list to remove some options. This is mostly used in helpers.
285 Returns None when no renderer can be detected.
285 Returns None when no renderer can be detected.
286 """
286 """
287 def _filter(elements):
287 def _filter(elements):
288 if isinstance(exclude, (list, tuple)):
288 if isinstance(exclude, (list, tuple)):
289 return [x for x in elements if x not in exclude]
289 return [x for x in elements if x not in exclude]
290 return elements
290 return elements
291
291
292 if filename.endswith(
292 if filename.endswith(
293 tuple(_filter([x[0] for x in cls.MARKDOWN_EXTS if x[0]]))):
293 tuple(_filter([x[0] for x in cls.MARKDOWN_EXTS if x[0]]))):
294 return 'markdown'
294 return 'markdown'
295 if filename.endswith(tuple(_filter([x[0] for x in cls.RST_EXTS if x[0]]))):
295 if filename.endswith(tuple(_filter([x[0] for x in cls.RST_EXTS if x[0]]))):
296 return 'rst'
296 return 'rst'
297
297
298 return None
298 return None
299
299
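A brief sketch of the detection, assuming the class above (exclude may be None or a list of extensions to skip):

assert MarkupRenderer.renderer_from_filename('README.md', exclude=None) == 'markdown'
assert MarkupRenderer.renderer_from_filename('docs/index.rst', exclude=None) == 'rst'
assert MarkupRenderer.renderer_from_filename('notes.txt', exclude=None) is None
# excluding '.md' forces a miss for lowercase markdown files
assert MarkupRenderer.renderer_from_filename('README.md', exclude=['.md']) is None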
300 def render(self, source, filename=None):
300 def render(self, source, filename=None):
301 """
301 """
302 Renders the given source using a detected renderer;
302 Renders the given source using a detected renderer;
303 renderers are detected based on file extension or mimetype.
303 renderers are detected based on file extension or mimetype.
304 As a last resort it falls back to simple html, replacing new lines with <br/>
304 As a last resort it falls back to simple html, replacing new lines with <br/>
305
305
306 :param filename:
306 :param filename:
307 :param source:
307 :param source:
308 """
308 """
309
309
310 renderer = self._detect_renderer(source, filename)
310 renderer = self._detect_renderer(source, filename)
311 readme_data = renderer(source)
311 readme_data = renderer(source)
312 return readme_data
312 return readme_data
313
313
314 @classmethod
314 @classmethod
315 def _flavored_markdown(cls, text):
315 def _flavored_markdown(cls, text):
316 """
316 """
317 Github style flavored markdown
317 Github style flavored markdown
318
318
319 :param text:
319 :param text:
320 """
320 """
321
321
322 # Extract pre blocks.
322 # Extract pre blocks.
323 extractions = {}
323 extractions = {}
324
324
325 def pre_extraction_callback(matchobj):
325 def pre_extraction_callback(matchobj):
326 digest = md5_safe(matchobj.group(0))
326 digest = md5_safe(matchobj.group(0))
327 extractions[digest] = matchobj.group(0)
327 extractions[digest] = matchobj.group(0)
328 return "{gfm-extraction-%s}" % digest
328 return "{gfm-extraction-%s}" % digest
329 pattern = re.compile(r'<pre>.*?</pre>', re.MULTILINE | re.DOTALL)
329 pattern = re.compile(r'<pre>.*?</pre>', re.MULTILINE | re.DOTALL)
330 text = re.sub(pattern, pre_extraction_callback, text)
330 text = re.sub(pattern, pre_extraction_callback, text)
331
331
332 # Prevent foo_bar_baz from ending up with an italic word in the middle.
332 # Prevent foo_bar_baz from ending up with an italic word in the middle.
333 def italic_callback(matchobj):
333 def italic_callback(matchobj):
334 s = matchobj.group(0)
334 s = matchobj.group(0)
335 if list(s).count('_') >= 2:
335 if list(s).count('_') >= 2:
336 return s.replace('_', r'\_')
336 return s.replace('_', r'\_')
337 return s
337 return s
338 text = re.sub(r'^(?! {4}|\t)\w+_\w+_\w[\w_]*', italic_callback, text)
338 text = re.sub(r'^(?! {4}|\t)\w+_\w+_\w[\w_]*', italic_callback, text)
339
339
340 # Insert pre block extractions.
340 # Insert pre block extractions.
341 def pre_insert_callback(matchobj):
341 def pre_insert_callback(matchobj):
342 return '\n\n' + extractions[matchobj.group(1)]
342 return '\n\n' + extractions[matchobj.group(1)]
343 text = re.sub(r'\{gfm-extraction-([0-9a-f]{32})\}',
343 text = re.sub(r'\{gfm-extraction-([0-9a-f]{32})\}',
344 pre_insert_callback, text)
344 pre_insert_callback, text)
345
345
346 return text
346 return text
347
347
348 @classmethod
348 @classmethod
349 def urlify_text(cls, text):
349 def urlify_text(cls, text):
350 def url_func(match_obj):
350 def url_func(match_obj):
351 url_full = match_obj.groups()[0]
351 url_full = match_obj.groups()[0]
352 return '<a href="%(url)s">%(url)s</a>' % ({'url': url_full})
352 return '<a href="%(url)s">%(url)s</a>' % ({'url': url_full})
353
353
354 return cls.URL_PAT.sub(url_func, text)
354 return cls.URL_PAT.sub(url_func, text)
355
355
356 @classmethod
356 @classmethod
357 def convert_mentions(cls, text, mode):
357 def convert_mentions(cls, text, mode):
358 mention_pat = cls.MENTION_PAT
358 mention_pat = cls.MENTION_PAT
359
359
360 def wrapp(match_obj):
360 def wrapp(match_obj):
361 uname = match_obj.groups()[0]
361 uname = match_obj.groups()[0]
362 hovercard_url = "pyroutes.url('hovercard_username', {'username': '%s'});" % uname
362 hovercard_url = "pyroutes.url('hovercard_username', {'username': '%s'});" % uname
363
363
364 if mode == 'markdown':
364 if mode == 'markdown':
365 tmpl = '<strong class="tooltip-hovercard" data-hovercard-alt="{uname}" data-hovercard-url="{hovercard_url}">@{uname}</strong>'
365 tmpl = '<strong class="tooltip-hovercard" data-hovercard-alt="{uname}" data-hovercard-url="{hovercard_url}">@{uname}</strong>'
366 elif mode == 'rst':
366 elif mode == 'rst':
367 tmpl = ' **@{uname}** '
367 tmpl = ' **@{uname}** '
368 else:
368 else:
369 raise ValueError('mode must be rst or markdown')
369 raise ValueError('mode must be rst or markdown')
370
370
371 return tmpl.format(**{'uname': uname,
371 return tmpl.format(**{'uname': uname,
372 'hovercard_url': hovercard_url})
372 'hovercard_url': hovercard_url})
373
373
374 return mention_pat.sub(wrapp, text).strip()
374 return mention_pat.sub(wrapp, text).strip()
375
375
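A hedged sketch of rst-mode conversion; the exact match set depends on MENTIONS_REGEX from utils2, so the output is indicative:

out = MarkupRenderer.convert_mentions('ping @marcin about this', mode='rst')
# expected shape: 'ping  **@marcin**  about this' (stripped at the ends);
# an unknown mode raises ValueError as soon as a mention matches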
376 @classmethod
376 @classmethod
377 def plain(cls, source, universal_newline=True, leading_newline=True):
377 def plain(cls, source, universal_newline=True, leading_newline=True):
378 source = safe_unicode(source)
378 source = safe_unicode(source)
379 if universal_newline:
379 if universal_newline:
380 newline = '\n'
380 newline = '\n'
381 source = newline.join(source.splitlines())
381 source = newline.join(source.splitlines())
382
382
383 rendered_source = cls.urlify_text(source)
383 rendered_source = cls.urlify_text(source)
384 source = ''
384 source = ''
385 if leading_newline:
385 if leading_newline:
386 source += '<br />'
386 source += '<br />'
387 source += rendered_source.replace("\n", '<br />')
387 source += rendered_source.replace("\n", '<br />')
388
388
389 rendered = cls.bleach_clean(source)
389 rendered = cls.bleach_clean(source)
390 return rendered
390 return rendered
391
391
392 @classmethod
392 @classmethod
393 def markdown(cls, source, safe=True, flavored=True, mentions=False,
393 def markdown(cls, source, safe=True, flavored=True, mentions=False,
394 clean_html=True):
394 clean_html=True):
395 """
395 """
396 returns markdown rendered code cleaned by the bleach library
396 returns markdown rendered code cleaned by the bleach library
397 """
397 """
398
398
399 if flavored:
399 if flavored:
400 markdown_renderer = get_markdown_renderer_flavored(
400 markdown_renderer = get_markdown_renderer_flavored(
401 cls.extensions, cls.output_format)
401 cls.extensions, cls.output_format)
402 else:
402 else:
403 markdown_renderer = get_markdown_renderer(
403 markdown_renderer = get_markdown_renderer(
404 cls.extensions, cls.output_format)
404 cls.extensions, cls.output_format)
405
405
406 if mentions:
406 if mentions:
407 mention_hl = cls.convert_mentions(source, mode='markdown')
407 mention_hl = cls.convert_mentions(source, mode='markdown')
408 # mentions are already converted above; render again with mentions=False
408 # mentions are already converted above; render again with mentions=False
409 return cls.markdown(mention_hl, safe=safe, flavored=flavored,
409 return cls.markdown(mention_hl, safe=safe, flavored=flavored,
410 mentions=False)
410 mentions=False)
411
411
412 source = safe_unicode(source)
412 source = safe_unicode(source)
413
413
414 try:
414 try:
415 if flavored:
415 if flavored:
416 source = cls._flavored_markdown(source)
416 source = cls._flavored_markdown(source)
417 rendered = markdown_renderer.convert(source)
417 rendered = markdown_renderer.convert(source)
418 except Exception:
418 except Exception:
419 log.exception('Error when rendering Markdown')
419 log.exception('Error when rendering Markdown')
420 if safe:
420 if safe:
421 log.debug('Fallback to render in plain mode')
421 log.debug('Fallback to render in plain mode')
422 rendered = cls.plain(source)
422 rendered = cls.plain(source)
423 else:
423 else:
424 raise
424 raise
425
425
426 if clean_html:
426 if clean_html:
427 rendered = cls.bleach_clean(rendered)
427 rendered = cls.bleach_clean(rendered)
428 return rendered
428 return rendered
429
429
430 @classmethod
430 @classmethod
431 def rst(cls, source, safe=True, mentions=False, clean_html=False):
431 def rst(cls, source, safe=True, mentions=False, clean_html=False):
432 if mentions:
432 if mentions:
433 mention_hl = cls.convert_mentions(source, mode='rst')
433 mention_hl = cls.convert_mentions(source, mode='rst')
434 # mentions are already converted above; render again with mentions=False
434 # mentions are already converted above; render again with mentions=False
435 return cls.rst(mention_hl, safe=safe, mentions=False)
435 return cls.rst(mention_hl, safe=safe, mentions=False)
436
436
437 source = safe_unicode(source)
437 source = safe_unicode(source)
438 try:
438 try:
439 docutils_settings = dict(
439 docutils_settings = dict(
440 [(alias, None) for alias in
440 [(alias, None) for alias in
441 cls.RESTRUCTUREDTEXT_DISALLOWED_DIRECTIVES])
441 cls.RESTRUCTUREDTEXT_DISALLOWED_DIRECTIVES])
442
442
443 docutils_settings.update({
443 docutils_settings.update({
444 'input_encoding': 'unicode',
444 'input_encoding': 'unicode',
445 'report_level': 4,
445 'report_level': 4,
446 'syntax_highlight': 'short',
446 'syntax_highlight': 'short',
447 })
447 })
448
448
449 for k, v in docutils_settings.items():
449 for k, v in docutils_settings.items():
450 directives.register_directive(k, v)
450 directives.register_directive(k, v)
451
451
452 parts = publish_parts(source=source,
452 parts = publish_parts(source=source,
453 writer=RhodeCodeWriter(),
453 writer=RhodeCodeWriter(),
454 settings_overrides=docutils_settings)
454 settings_overrides=docutils_settings)
455 rendered = parts["fragment"]
455 rendered = parts["fragment"]
456 if clean_html:
456 if clean_html:
457 rendered = cls.bleach_clean(rendered)
457 rendered = cls.bleach_clean(rendered)
458 return parts['html_title'] + rendered
458 return parts['html_title'] + rendered
459 except Exception:
459 except Exception:
460 log.exception('Error when rendering RST')
460 log.exception('Error when rendering RST')
461 if safe:
461 if safe:
462 log.debug('Fallback to render in plain mode')
462 log.debug('Fallback to render in plain mode')
463 return cls.plain(source)
463 return cls.plain(source)
464 else:
464 else:
465 raise
465 raise
466
466
467 @classmethod
467 @classmethod
468 def jupyter(cls, source, safe=True):
468 def jupyter(cls, source, safe=True):
469 from rhodecode.lib import helpers
469 from rhodecode.lib import helpers
470
470
471 from traitlets.config import Config
471 from traitlets.config import Config
472 import nbformat
472 import nbformat
473 from nbconvert import HTMLExporter
473 from nbconvert import HTMLExporter
474 from nbconvert.preprocessors import Preprocessor
474 from nbconvert.preprocessors import Preprocessor
475
475
476 class CustomHTMLExporter(HTMLExporter):
476 class CustomHTMLExporter(HTMLExporter):
477 def _template_file_default(self):
477 def _template_file_default(self):
478 return 'basic'
478 return 'basic'
479
479
480 class Sandbox(Preprocessor):
480 class Sandbox(Preprocessor):
481
481
482 def preprocess(self, nb, resources):
482 def preprocess(self, nb, resources):
483 sandbox_text = 'SandBoxed(IPython.core.display.Javascript object)'
483 sandbox_text = 'SandBoxed(IPython.core.display.Javascript object)'
484 for cell in nb['cells']:
484 for cell in nb['cells']:
485 if not safe:
485 if not safe:
486 continue
486 continue
487
487
488 if 'outputs' in cell:
488 if 'outputs' in cell:
489 for cell_output in cell['outputs']:
489 for cell_output in cell['outputs']:
490 if 'data' in cell_output:
490 if 'data' in cell_output:
491 if 'application/javascript' in cell_output['data']:
491 if 'application/javascript' in cell_output['data']:
492 cell_output['data']['text/plain'] = sandbox_text
492 cell_output['data']['text/plain'] = sandbox_text
493 cell_output['data'].pop('application/javascript', None)
493 cell_output['data'].pop('application/javascript', None)
494
494
495 if 'source' in cell and cell['cell_type'] == 'markdown':
495 if 'source' in cell and cell['cell_type'] == 'markdown':
496 # sanitize similarly to markdown
496 # sanitize similarly to markdown
497 cell['source'] = cls.bleach_clean(cell['source'])
497 cell['source'] = cls.bleach_clean(cell['source'])
498
498
499 return nb, resources
499 return nb, resources
500
500
501 def _sanitize_resources(input_resources):
501 def _sanitize_resources(input_resources):
502 """
502 """
503 Skip/sanitize some of the CSS generated and included in jupyter
503 Skip/sanitize some of the CSS generated and included in jupyter
504 so it doesn't mess up the UI so much
504 so it doesn't mess up the UI so much
505 """
505 """
506
506
507 # TODO(marcink): we should probably replace this with a fully custom
507 # TODO(marcink): we should probably replace this with a fully custom
508 # CSS set that doesn't break the UI, but jupyter-generated html has some
508 # CSS set that doesn't break the UI, but jupyter-generated html has some
509 # special markers, so achieving that requires a custom HTML exporter
509 # special markers, so achieving that requires a custom HTML exporter
510 # template with _default_template_path_default
510 # template with _default_template_path_default
511
511
512 # strip the reset CSS
512 # strip the reset CSS
513 input_resources[0] = input_resources[0][input_resources[0].find('/*! Source'):]
513 input_resources[0] = input_resources[0][input_resources[0].find('/*! Source'):]
514 return input_resources
514 return input_resources
515
515
516 def as_html(notebook):
516 def as_html(notebook):
517 conf = Config()
517 conf = Config()
518 conf.CustomHTMLExporter.preprocessors = [Sandbox]
518 conf.CustomHTMLExporter.preprocessors = [Sandbox]
519 html_exporter = CustomHTMLExporter(config=conf)
519 html_exporter = CustomHTMLExporter(config=conf)
520
520
521 (body, resources) = html_exporter.from_notebook_node(notebook)
521 (body, resources) = html_exporter.from_notebook_node(notebook)
522 header = '<!-- ## IPYTHON NOTEBOOK RENDERING ## -->'
522 header = '<!-- ## IPYTHON NOTEBOOK RENDERING ## -->'
523 js = MakoTemplate(r'''
523 js = MakoTemplate(r'''
524 <!-- MathJax configuration -->
524 <!-- MathJax configuration -->
525 <script type="text/x-mathjax-config">
525 <script type="text/x-mathjax-config">
526 MathJax.Hub.Config({
526 MathJax.Hub.Config({
527 jax: ["input/TeX","output/HTML-CSS", "output/PreviewHTML"],
527 jax: ["input/TeX","output/HTML-CSS", "output/PreviewHTML"],
528 extensions: ["tex2jax.js","MathMenu.js","MathZoom.js", "fast-preview.js", "AssistiveMML.js", "[Contrib]/a11y/accessibility-menu.js"],
528 extensions: ["tex2jax.js","MathMenu.js","MathZoom.js", "fast-preview.js", "AssistiveMML.js", "[Contrib]/a11y/accessibility-menu.js"],
529 TeX: {
529 TeX: {
530 extensions: ["AMSmath.js","AMSsymbols.js","noErrors.js","noUndefined.js"]
530 extensions: ["AMSmath.js","AMSsymbols.js","noErrors.js","noUndefined.js"]
531 },
531 },
532 tex2jax: {
532 tex2jax: {
533 inlineMath: [ ['$','$'], ["\\(","\\)"] ],
533 inlineMath: [ ['$','$'], ["\\(","\\)"] ],
534 displayMath: [ ['$$','$$'], ["\\[","\\]"] ],
534 displayMath: [ ['$$','$$'], ["\\[","\\]"] ],
535 processEscapes: true,
535 processEscapes: true,
536 processEnvironments: true
536 processEnvironments: true
537 },
537 },
538 // Center justify equations in code and markdown cells. Elsewhere
538 // Center justify equations in code and markdown cells. Elsewhere
539 // we use CSS to left justify single line equations in code cells.
539 // we use CSS to left justify single line equations in code cells.
540 displayAlign: 'center',
540 displayAlign: 'center',
541 "HTML-CSS": {
541 "HTML-CSS": {
542 styles: {'.MathJax_Display': {"margin": 0}},
542 styles: {'.MathJax_Display': {"margin": 0}},
543 linebreaks: { automatic: true },
543 linebreaks: { automatic: true },
544 availableFonts: ["STIX", "TeX"]
544 availableFonts: ["STIX", "TeX"]
545 },
545 },
546 showMathMenu: false
546 showMathMenu: false
547 });
547 });
548 </script>
548 </script>
549 <!-- End of MathJax configuration -->
549 <!-- End of MathJax configuration -->
550 <script src="${h.asset('js/src/math_jax/MathJax.js')}"></script>
550 <script src="${h.asset('js/src/math_jax/MathJax.js')}"></script>
551 ''').render(h=helpers)
551 ''').render(h=helpers)
552
552
553 css = MakoTemplate(r'''
553 css = MakoTemplate(r'''
554 <link rel="stylesheet" type="text/css" href="${h.asset('css/style-ipython.css', ver=ver)}" media="screen"/>
554 <link rel="stylesheet" type="text/css" href="${h.asset('css/style-ipython.css', ver=ver)}" media="screen"/>
555 ''').render(h=helpers, ver='ver1')
555 ''').render(h=helpers, ver='ver1')
556
556
557 body = '\n'.join([header, css, js, body])
557 body = '\n'.join([header, css, js, body])
558 return body, resources
558 return body, resources
559
559
560 notebook = nbformat.reads(source, as_version=4)
560 notebook = nbformat.reads(source, as_version=4)
561 (body, resources) = as_html(notebook)
561 (body, resources) = as_html(notebook)
562 return body
562 return body
563
563
564
564
565 class RstTemplateRenderer(object):
565 class RstTemplateRenderer(object):
566
566
567 def __init__(self):
567 def __init__(self):
568 base = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
568 base = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
569 rst_template_dirs = [os.path.join(base, 'templates', 'rst_templates')]
569 rst_template_dirs = [os.path.join(base, 'templates', 'rst_templates')]
570 self.template_store = TemplateLookup(
570 self.template_store = TemplateLookup(
571 directories=rst_template_dirs,
571 directories=rst_template_dirs,
572 input_encoding='utf-8',
572 input_encoding='utf-8',
573 imports=['from rhodecode.lib import helpers as h'])
573 imports=['from rhodecode.lib import helpers as h'])
574
574
575 def _get_template(self, templatename):
575 def _get_template(self, templatename):
576 return self.template_store.get_template(templatename)
576 return self.template_store.get_template(templatename)
577
577
578 def render(self, template_name, **kwargs):
578 def render(self, template_name, **kwargs):
579 template = self._get_template(template_name)
579 template = self._get_template(template_name)
580 return template.render(**kwargs)
580 return template.render(**kwargs)
@@ -1,156 +1,156 @@
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2
2
3 # Copyright (C) 2010-2020 RhodeCode GmbH
3 # Copyright (C) 2010-2020 RhodeCode GmbH
4 #
4 #
5 # This program is free software: you can redistribute it and/or modify
5 # This program is free software: you can redistribute it and/or modify
6 # it under the terms of the GNU Affero General Public License, version 3
6 # it under the terms of the GNU Affero General Public License, version 3
7 # (only), as published by the Free Software Foundation.
7 # (only), as published by the Free Software Foundation.
8 #
8 #
9 # This program is distributed in the hope that it will be useful,
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 # GNU General Public License for more details.
12 # GNU General Public License for more details.
13 #
13 #
14 # You should have received a copy of the GNU Affero General Public License
14 # You should have received a copy of the GNU Affero General Public License
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 #
16 #
17 # This program is dual-licensed. If you wish to learn more about the
17 # This program is dual-licensed. If you wish to learn more about the
18 # RhodeCode Enterprise Edition, including its added features, Support services,
18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20
20
21 """
21 """
22 SimpleGit middleware for handling git protocol requests (push/clone etc.)
22 SimpleGit middleware for handling git protocol requests (push/clone etc.)
23 It's implemented with a basic auth function
23 It's implemented with a basic auth function
24 """
24 """
25 import os
25 import os
26 import re
26 import re
27 import logging
27 import logging
28 import urllib.parse
28 import urllib.parse
29
29
30 import rhodecode
30 import rhodecode
31 from rhodecode.lib import utils
31 from rhodecode.lib import utils
32 from rhodecode.lib import utils2
32 from rhodecode.lib import utils2
33 from rhodecode.lib.middleware import simplevcs
33 from rhodecode.lib.middleware import simplevcs
34
34
35 log = logging.getLogger(__name__)
35 log = logging.getLogger(__name__)
36
36
37
37
38 GIT_PROTO_PAT = re.compile(
38 GIT_PROTO_PAT = re.compile(
39 r'^/(.+)/(info/refs|info/lfs/(.+)|git-upload-pack|git-receive-pack)')
39 r'^/(.+)/(info/refs|info/lfs/(.+)|git-upload-pack|git-receive-pack)')
40 GIT_LFS_PROTO_PAT = re.compile(r'^/(.+)/(info/lfs/(.+))')
40 GIT_LFS_PROTO_PAT = re.compile(r'^/(.+)/(info/lfs/(.+))')
41
41
42
42
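A quick illustration of what these patterns capture (paths hypothetical):

m = GIT_PROTO_PAT.match('/group/myrepo.git/git-upload-pack')
# m.group(1) -> 'group/myrepo.git', m.group(2) -> 'git-upload-pack'
lfs = GIT_LFS_PROTO_PAT.match('/myrepo/info/lfs/objects/batch')
# lfs.group(1) -> 'myrepo', lfs.group(3) -> 'objects/batch'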
43 def default_lfs_store():
43 def default_lfs_store():
44 """
44 """
45 Default lfs store location; it's consistent with Mercurial's large file
45 Default lfs store location; it's consistent with Mercurial's large file
46 store which is in .cache/largefiles
46 store which is in .cache/largefiles
47 """
47 """
48 from rhodecode.lib.vcs.backends.git import lfs_store
48 from rhodecode.lib.vcs.backends.git import lfs_store
49 user_home = os.path.expanduser("~")
49 user_home = os.path.expanduser("~")
50 return lfs_store(user_home)
50 return lfs_store(user_home)
51
51
52
52
53 class SimpleGit(simplevcs.SimpleVCS):
53 class SimpleGit(simplevcs.SimpleVCS):
54
54
55 SCM = 'git'
55 SCM = 'git'
56
56
57 def _get_repository_name(self, environ):
57 def _get_repository_name(self, environ):
58 """
58 """
59 Gets repository name out of PATH_INFO header
59 Gets repository name out of PATH_INFO header
60
60
61 :param environ: environ where PATH_INFO is stored
61 :param environ: environ where PATH_INFO is stored
62 """
62 """
63 repo_name = GIT_PROTO_PAT.match(environ['PATH_INFO']).group(1)
63 repo_name = GIT_PROTO_PAT.match(environ['PATH_INFO']).group(1)
64 # for GIT LFS and bare format, strip the .git suffix from names
64 # for GIT LFS and bare format, strip the .git suffix from names
65 if repo_name.endswith('.git'):
65 if repo_name.endswith('.git'):
66 repo_name = repo_name[:-4]
66 repo_name = repo_name[:-4]
67 return repo_name
67 return repo_name
68
68
69 def _get_lfs_action(self, path, request_method):
69 def _get_lfs_action(self, path, request_method):
70 """
70 """
71 return an action based on LFS requests type.
71 return an action based on LFS requests type.
72 Those routes are handled inside vcsserver app.
72 Those routes are handled inside vcsserver app.
73
73
74 batch -> POST to /info/lfs/objects/batch => PUSH/PULL
74 batch -> POST to /info/lfs/objects/batch => PUSH/PULL
75 batch is based on the `operation` key,
75 batch is based on the `operation` key,
76 which could be download or upload, but those are only
76 which could be download or upload, but those are only
77 instructions to fetch, so we always return pull
77 instructions to fetch, so we always return pull
78
78
79 download -> GET to /info/lfs/{oid} => PULL
79 download -> GET to /info/lfs/{oid} => PULL
80 upload -> PUT to /info/lfs/{oid} => PUSH
80 upload -> PUT to /info/lfs/{oid} => PUSH
81
81
82 verification -> POST to /info/lfs/verify => PULL
82 verification -> POST to /info/lfs/verify => PULL
83
83
84 """
84 """
85
85
86 match_obj = GIT_LFS_PROTO_PAT.match(path)
86 match_obj = GIT_LFS_PROTO_PAT.match(path)
87 _parts = match_obj.groups()
87 _parts = match_obj.groups()
88 repo_name, path, operation = _parts
88 repo_name, path, operation = _parts
89 log.debug(
89 log.debug(
90 'LFS: detecting operation based on following '
90 'LFS: detecting operation based on following '
91 'data: %s, req_method:%s', _parts, request_method)
91 'data: %s, req_method:%s', _parts, request_method)
92
92
93 if operation == 'verify':
93 if operation == 'verify':
94 return 'pull'
94 return 'pull'
95 elif operation == 'objects/batch':
95 elif operation == 'objects/batch':
96 # batch sends back instructions for the API to download/upload;
96 # batch sends back instructions for the API to download/upload;
97 # we report it as pull
97 # we report it as pull
98 if request_method == 'POST':
98 if request_method == 'POST':
99 return 'pull'
99 return 'pull'
100
100
101 elif operation:
101 elif operation:
102 # probably an OID; upload is PUT, download is GET
102 # probably an OID; upload is PUT, download is GET
103 if request_method == 'GET':
103 if request_method == 'GET':
104 return 'pull'
104 return 'pull'
105 else:
105 else:
106 return 'push'
106 return 'push'
107
107
108 # if no mapping is found, require push as the action
108 # if no mapping is found, require push as the action
109 return 'push'
109 return 'push'
110
110
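An illustrative mapping of LFS requests to actions; _get_lfs_action never touches self, so for a sketch it can be called unbound (paths hypothetical):

SimpleGit._get_lfs_action(None, '/repo/info/lfs/verify', 'POST')         # -> 'pull'
SimpleGit._get_lfs_action(None, '/repo/info/lfs/objects/batch', 'POST')  # -> 'pull'
SimpleGit._get_lfs_action(None, '/repo/info/lfs/a1b2c3', 'GET')          # -> 'pull'
SimpleGit._get_lfs_action(None, '/repo/info/lfs/a1b2c3', 'PUT')          # -> 'push'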
111 _ACTION_MAPPING = {
111 _ACTION_MAPPING = {
112 'git-receive-pack': 'push',
112 'git-receive-pack': 'push',
113 'git-upload-pack': 'pull',
113 'git-upload-pack': 'pull',
114 }
114 }
115
115
116 def _get_action(self, environ):
116 def _get_action(self, environ):
117 """
117 """
118 Maps git request commands into a pull or push command.
118 Maps git request commands into a pull or push command.
119 In case of unknown/unexpected data, it returns 'pull' to be safe.
119 In case of unknown/unexpected data, it returns 'pull' to be safe.
120
120
121 :param environ:
121 :param environ:
122 """
122 """
123 path = environ['PATH_INFO']
123 path = environ['PATH_INFO']
124
124
125 if path.endswith('/info/refs'):
125 if path.endswith('/info/refs'):
126 query = urllib.parse.urlparse.parse_qs(environ['QUERY_STRING'])
126 query = urllib.parse.parse_qs(environ['QUERY_STRING'])
127 service_cmd = query.get('service', [''])[0]
127 service_cmd = query.get('service', [''])[0]
128 return self._ACTION_MAPPING.get(service_cmd, 'pull')
128 return self._ACTION_MAPPING.get(service_cmd, 'pull')
129
129
130 elif GIT_LFS_PROTO_PAT.match(environ['PATH_INFO']):
130 elif GIT_LFS_PROTO_PAT.match(environ['PATH_INFO']):
131 return self._get_lfs_action(
131 return self._get_lfs_action(
132 environ['PATH_INFO'], environ['REQUEST_METHOD'])
132 environ['PATH_INFO'], environ['REQUEST_METHOD'])
133
133
134 elif path.endswith('/git-receive-pack'):
134 elif path.endswith('/git-receive-pack'):
135 return 'push'
135 return 'push'
136 elif path.endswith('/git-upload-pack'):
136 elif path.endswith('/git-upload-pack'):
137 return 'pull'
137 return 'pull'
138
138
139 return 'pull'
139 return 'pull'
140
140
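A small sketch of the smart-HTTP branch; environ values are hypothetical, and __new__ skips __init__ purely for illustration:

sg = SimpleGit.__new__(SimpleGit)
environ = {'PATH_INFO': '/myrepo/info/refs',
           'QUERY_STRING': 'service=git-receive-pack',
           'REQUEST_METHOD': 'GET'}
sg._get_action(environ)  # -> 'push' via _ACTION_MAPPING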
141 def _create_wsgi_app(self, repo_path, repo_name, config):
141 def _create_wsgi_app(self, repo_path, repo_name, config):
142 return self.scm_app.create_git_wsgi_app(
142 return self.scm_app.create_git_wsgi_app(
143 repo_path, repo_name, config)
143 repo_path, repo_name, config)
144
144
145 def _create_config(self, extras, repo_name, scheme='http'):
145 def _create_config(self, extras, repo_name, scheme='http'):
146 extras['git_update_server_info'] = utils2.str2bool(
146 extras['git_update_server_info'] = utils2.str2bool(
147 rhodecode.CONFIG.get('git_update_server_info'))
147 rhodecode.CONFIG.get('git_update_server_info'))
148
148
149 config = utils.make_db_config(repo=repo_name)
149 config = utils.make_db_config(repo=repo_name)
150 custom_store = config.get('vcs_git_lfs', 'store_location')
150 custom_store = config.get('vcs_git_lfs', 'store_location')
151
151
152 extras['git_lfs_enabled'] = utils2.str2bool(
152 extras['git_lfs_enabled'] = utils2.str2bool(
153 config.get('vcs_git_lfs', 'enabled'))
153 config.get('vcs_git_lfs', 'enabled'))
154 extras['git_lfs_store_path'] = custom_store or default_lfs_store()
154 extras['git_lfs_store_path'] = custom_store or default_lfs_store()
155 extras['git_lfs_http_scheme'] = scheme
155 extras['git_lfs_http_scheme'] = scheme
156 return extras
156 return extras
@@ -1,160 +1,159 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2
2
3 # Copyright (C) 2010-2020 RhodeCode GmbH
3 # Copyright (C) 2010-2020 RhodeCode GmbH
4 #
4 #
5 # This program is free software: you can redistribute it and/or modify
5 # This program is free software: you can redistribute it and/or modify
6 # it under the terms of the GNU Affero General Public License, version 3
6 # it under the terms of the GNU Affero General Public License, version 3
7 # (only), as published by the Free Software Foundation.
7 # (only), as published by the Free Software Foundation.
8 #
8 #
9 # This program is distributed in the hope that it will be useful,
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 # GNU General Public License for more details.
12 # GNU General Public License for more details.
13 #
13 #
14 # You should have received a copy of the GNU Affero General Public License
14 # You should have received a copy of the GNU Affero General Public License
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 #
16 #
17 # This program is dual-licensed. If you wish to learn more about the
17 # This program is dual-licensed. If you wish to learn more about the
18 # RhodeCode Enterprise Edition, including its added features, Support services,
18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20
20
21 """
21 """
22 SimpleHg middleware for handling Mercurial protocol requests
22 SimpleHg middleware for handling Mercurial protocol requests
23 (push/clone etc.). It's implemented with a basic auth function.
23 (push/clone etc.). It's implemented with a basic auth function.
24 """
24 """
25
25
26 import logging
26 import logging
27 import urllib.parse
27 import urllib.parse
28 import urllib.request, urllib.parse, urllib.error
28 import urllib.request, urllib.parse, urllib.error
29
29
30 from rhodecode.lib import utils
30 from rhodecode.lib import utils
31 from rhodecode.lib.ext_json import json
31 from rhodecode.lib.ext_json import json
32 from rhodecode.lib.middleware import simplevcs
32 from rhodecode.lib.middleware import simplevcs
33
33
34 log = logging.getLogger(__name__)
34 log = logging.getLogger(__name__)
35
35
36
36
37 class SimpleHg(simplevcs.SimpleVCS):
37 class SimpleHg(simplevcs.SimpleVCS):
38
38
39 SCM = 'hg'
39 SCM = 'hg'
40
40
41 def _get_repository_name(self, environ):
41 def _get_repository_name(self, environ):
42 """
42 """
43 Gets the repository name out of the WSGI PATH_INFO value
43 Gets the repository name out of the WSGI PATH_INFO value
44
44
45 :param environ: environ where PATH_INFO is stored
45 :param environ: environ where PATH_INFO is stored
46 """
46 """
47 repo_name = environ['PATH_INFO']
47 repo_name = environ['PATH_INFO']
48 if repo_name and repo_name.startswith('/'):
48 if repo_name and repo_name.startswith('/'):
49 # remove only the first leading /
49 # remove only the first leading /
50 repo_name = repo_name[1:]
50 repo_name = repo_name[1:]
51 return repo_name.rstrip('/')
51 return repo_name.rstrip('/')
52
52
53 _ACTION_MAPPING = {
53 _ACTION_MAPPING = {
54 'changegroup': 'pull',
54 'changegroup': 'pull',
55 'changegroupsubset': 'pull',
55 'changegroupsubset': 'pull',
56 'getbundle': 'pull',
56 'getbundle': 'pull',
57 'stream_out': 'pull',
57 'stream_out': 'pull',
58 'listkeys': 'pull',
58 'listkeys': 'pull',
59 'between': 'pull',
59 'between': 'pull',
60 'branchmap': 'pull',
60 'branchmap': 'pull',
61 'branches': 'pull',
61 'branches': 'pull',
62 'clonebundles': 'pull',
62 'clonebundles': 'pull',
63 'capabilities': 'pull',
63 'capabilities': 'pull',
64 'debugwireargs': 'pull',
64 'debugwireargs': 'pull',
65 'heads': 'pull',
65 'heads': 'pull',
66 'lookup': 'pull',
66 'lookup': 'pull',
67 'hello': 'pull',
67 'hello': 'pull',
68 'known': 'pull',
68 'known': 'pull',
69
69
70 # largefiles
70 # largefiles
71 'putlfile': 'push',
71 'putlfile': 'push',
72 'getlfile': 'pull',
72 'getlfile': 'pull',
73 'statlfile': 'pull',
73 'statlfile': 'pull',
74 'lheads': 'pull',
74 'lheads': 'pull',
75
75
76 # evolve
76 # evolve
77 'evoext_obshashrange_v1': 'pull',
77 'evoext_obshashrange_v1': 'pull',
78 'evoext_obshash': 'pull',
78 'evoext_obshash': 'pull',
79 'evoext_obshash1': 'pull',
79 'evoext_obshash1': 'pull',
80
80
81 'unbundle': 'push',
81 'unbundle': 'push',
82 'pushkey': 'push',
82 'pushkey': 'push',
83 }
83 }
84
84
85 @classmethod
85 @classmethod
86 def _get_xarg_headers(cls, environ):
86 def _get_xarg_headers(cls, environ):
87 i = 1
87 i = 1
88 chunks = [] # gather chunks stored in multiple 'hgarg_N'
88 chunks = [] # gather chunks stored in multiple 'hgarg_N'
89 while True:
89 while True:
90 head = environ.get('HTTP_X_HGARG_{}'.format(i))
90 head = environ.get('HTTP_X_HGARG_{}'.format(i))
91 if not head:
91 if not head:
92 break
92 break
93 i += 1
93 i += 1
94 chunks.append(urllib.parse.unquote_plus(head))
94 chunks.append(urllib.parse.unquote_plus(head))
95 full_arg = ''.join(chunks)
95 full_arg = ''.join(chunks)
96 pref = 'cmds='
96 pref = 'cmds='
97 if full_arg.startswith(pref):
97 if full_arg.startswith(pref):
98 # strip the cmds= header defining our batch commands
98 # strip the cmds= header defining our batch commands
99 full_arg = full_arg[len(pref):]
99 full_arg = full_arg[len(pref):]
100 cmds = full_arg.split(';')
100 cmds = full_arg.split(';')
101 return cmds
101 return cmds
102
102
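A standalone sketch of the header reassembly above, with a made-up environ whose batch arguments are split across two X-HgArg-N headers (the header values are illustrative):

import urllib.parse

environ = {
    'HTTP_X_HGARG_1': 'cmds=heads+%3Bknown+',  # first url-quoted chunk
    'HTTP_X_HGARG_2': 'nodes%3D',              # second chunk
}
chunks, i = [], 1
while environ.get('HTTP_X_HGARG_{}'.format(i)):
    chunks.append(urllib.parse.unquote_plus(environ['HTTP_X_HGARG_{}'.format(i)]))
    i += 1
full_arg = ''.join(chunks)
if full_arg.startswith('cmds='):
    full_arg = full_arg[len('cmds='):]
assert full_arg.split(';') == ['heads ', 'known nodes=']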
103 @classmethod
103 @classmethod
104 def _get_batch_cmd(cls, environ):
104 def _get_batch_cmd(cls, environ):
105 """
105 """
106 Handle commands sent via the batch command. Those are ';'-separated
106 Handle commands sent via the batch command. Those are ';'-separated
107 commands that the server needs to execute. We need to extract those,
107 commands that the server needs to execute. We need to extract those,
108 and map them through our ACTION_MAPPING to get all push/pull commands
108 and map them through our ACTION_MAPPING to get all push/pull commands
109 specified in the batch
109 specified in the batch
110 """
110 """
111 default = 'push'
111 default = 'push'
112 batch_cmds = []
112 batch_cmds = []
113 try:
113 try:
114 cmds = cls._get_xarg_headers(environ)
114 cmds = cls._get_xarg_headers(environ)
115 for pair in cmds:
115 for pair in cmds:
116 parts = pair.split(' ', 1)
116 parts = pair.split(' ', 1)
117 if len(parts) != 2:
117 if len(parts) != 2:
118 continue
118 continue
119 # each entry should be in the format `key ARGS`
119 # each entry should be in the format `key ARGS`
120 cmd, args = parts
120 cmd, args = parts
121 action = cls._ACTION_MAPPING.get(cmd, default)
121 action = cls._ACTION_MAPPING.get(cmd, default)
122 batch_cmds.append(action)
122 batch_cmds.append(action)
123 except Exception:
123 except Exception:
124 log.exception('Failed to extract batch commands operations')
124 log.exception('Failed to extract batch commands operations')
125
125
126 # in case we failed (e.g. malformed data), assume it's a PUSH
126 # in case we failed (e.g. malformed data), assume it's a PUSH
127 # sub-command, for safety
127 # sub-command, for safety
128 return batch_cmds or [default]
128 return batch_cmds or [default]
129
129
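Continuing that sketch, each extracted entry maps through the action table with 'push' as the safe default (a trimmed copy of _ACTION_MAPPING is used here):

action_mapping = {'heads': 'pull', 'known': 'pull', 'unbundle': 'push'}
default = 'push'
batch_cmds = []
for pair in ['heads ', 'known nodes=']:
    parts = pair.split(' ', 1)
    if len(parts) != 2:
        continue  # entries not shaped like `key ARGS` are skipped
    cmd, args = parts
    batch_cmds.append(action_mapping.get(cmd, default))
assert batch_cmds == ['pull', 'pull']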
130 def _get_action(self, environ):
130 def _get_action(self, environ):
131 """
131 """
132 Maps mercurial request commands into a pull or push command.
132 Maps mercurial request commands into a pull or push command.
133 In case of unknown/unexpected data, it returns 'push' to be safe.
133 In case of unknown/unexpected data, it returns 'push' to be safe.
134
134
135 :param environ:
135 :param environ:
136 """
136 """
137 default = 'push'
137 default = 'push'
138 query = urllib.parse.urlparse.parse_qs(environ['QUERY_STRING'],
138 query = urllib.parse.parse_qs(environ['QUERY_STRING'], keep_blank_values=True)
139 keep_blank_values=True)
140
139
141 if 'cmd' in query:
140 if 'cmd' in query:
142 cmd = query['cmd'][0]
141 cmd = query['cmd'][0]
143 if cmd == 'batch':
142 if cmd == 'batch':
144 cmds = self._get_batch_cmd(environ)
143 cmds = self._get_batch_cmd(environ)
145 if 'push' in cmds:
144 if 'push' in cmds:
146 return 'push'
145 return 'push'
147 else:
146 else:
148 return 'pull'
147 return 'pull'
149 return self._ACTION_MAPPING.get(cmd, default)
148 return self._ACTION_MAPPING.get(cmd, default)
150
149
151 return default
150 return default
152
151
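One detail worth calling out above is keep_blank_values=True: by default parse_qs silently drops parameters with empty values, which would hide a bare cmd= argument:

import urllib.parse

qs = 'cmd=&other=1'
assert 'cmd' not in urllib.parse.parse_qs(qs)
assert urllib.parse.parse_qs(qs, keep_blank_values=True)['cmd'] == ['']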
153 def _create_wsgi_app(self, repo_path, repo_name, config):
152 def _create_wsgi_app(self, repo_path, repo_name, config):
154 return self.scm_app.create_hg_wsgi_app(repo_path, repo_name, config)
153 return self.scm_app.create_hg_wsgi_app(repo_path, repo_name, config)
155
154
156 def _create_config(self, extras, repo_name, scheme='http'):
155 def _create_config(self, extras, repo_name, scheme='http'):
157 config = utils.make_db_config(repo=repo_name)
156 config = utils.make_db_config(repo=repo_name)
158 config.set('rhodecode', 'RC_SCM_DATA', json.dumps(extras))
157 config.set('rhodecode', 'RC_SCM_DATA', json.dumps(extras))
159
158
160 return config.serialize()
159 return config.serialize()
@@ -1,679 +1,679 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2
2
3 # Copyright (C) 2014-2020 RhodeCode GmbH
3 # Copyright (C) 2014-2020 RhodeCode GmbH
4 #
4 #
5 # This program is free software: you can redistribute it and/or modify
5 # This program is free software: you can redistribute it and/or modify
6 # it under the terms of the GNU Affero General Public License, version 3
6 # it under the terms of the GNU Affero General Public License, version 3
7 # (only), as published by the Free Software Foundation.
7 # (only), as published by the Free Software Foundation.
8 #
8 #
9 # This program is distributed in the hope that it will be useful,
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 # GNU General Public License for more details.
12 # GNU General Public License for more details.
13 #
13 #
14 # You should have received a copy of the GNU Affero General Public License
14 # You should have received a copy of the GNU Affero General Public License
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 #
16 #
17 # This program is dual-licensed. If you wish to learn more about the
17 # This program is dual-licensed. If you wish to learn more about the
18 # RhodeCode Enterprise Edition, including its added features, Support services,
18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20
20
21 """
21 """
22 SimpleVCS middleware for handling protocol requests (push/clone etc.).
22 SimpleVCS middleware for handling protocol requests (push/clone etc.).
23 It's implemented with a basic auth function.
23 It's implemented with a basic auth function.
24 """
24 """
25
25
26 import os
26 import os
27 import re
27 import re
28 import io
28 import logging
29 import logging
29 import importlib
30 import importlib
30 from functools import wraps
31 from functools import wraps
31 from io import StringIO
32 from lxml import etree
32 from lxml import etree
33
33
34 import time
34 import time
35 from paste.httpheaders import REMOTE_USER, AUTH_TYPE
35 from paste.httpheaders import REMOTE_USER, AUTH_TYPE
36
36
37 from pyramid.httpexceptions import (
37 from pyramid.httpexceptions import (
38 HTTPNotFound, HTTPForbidden, HTTPNotAcceptable, HTTPInternalServerError)
38 HTTPNotFound, HTTPForbidden, HTTPNotAcceptable, HTTPInternalServerError)
39 from zope.cachedescriptors.property import Lazy as LazyProperty
39 from zope.cachedescriptors.property import Lazy as LazyProperty
40
40
41 import rhodecode
41 import rhodecode
42 from rhodecode.authentication.base import authenticate, VCS_TYPE, loadplugin
42 from rhodecode.authentication.base import authenticate, VCS_TYPE, loadplugin
43 from rhodecode.lib import rc_cache
43 from rhodecode.lib import rc_cache
44 from rhodecode.lib.auth import AuthUser, HasPermissionAnyMiddleware
44 from rhodecode.lib.auth import AuthUser, HasPermissionAnyMiddleware
45 from rhodecode.lib.base import (
45 from rhodecode.lib.base import (
46 BasicAuth, get_ip_addr, get_user_agent, vcs_operation_context)
46 BasicAuth, get_ip_addr, get_user_agent, vcs_operation_context)
47 from rhodecode.lib.exceptions import (UserCreationError, NotAllowedToCreateUserError)
47 from rhodecode.lib.exceptions import (UserCreationError, NotAllowedToCreateUserError)
48 from rhodecode.lib.hooks_daemon import prepare_callback_daemon
48 from rhodecode.lib.hooks_daemon import prepare_callback_daemon
49 from rhodecode.lib.middleware import appenlight
49 from rhodecode.lib.middleware import appenlight
50 from rhodecode.lib.middleware.utils import scm_app_http
50 from rhodecode.lib.middleware.utils import scm_app_http
51 from rhodecode.lib.utils import is_valid_repo, SLUG_RE
51 from rhodecode.lib.utils import is_valid_repo, SLUG_RE
52 from rhodecode.lib.utils2 import safe_str, fix_PATH, str2bool, safe_unicode
52 from rhodecode.lib.utils2 import safe_str, fix_PATH, str2bool, safe_unicode
53 from rhodecode.lib.vcs.conf import settings as vcs_settings
53 from rhodecode.lib.vcs.conf import settings as vcs_settings
54 from rhodecode.lib.vcs.backends import base
54 from rhodecode.lib.vcs.backends import base
55
55
56 from rhodecode.model import meta
56 from rhodecode.model import meta
57 from rhodecode.model.db import User, Repository, PullRequest
57 from rhodecode.model.db import User, Repository, PullRequest
58 from rhodecode.model.scm import ScmModel
58 from rhodecode.model.scm import ScmModel
59 from rhodecode.model.pull_request import PullRequestModel
59 from rhodecode.model.pull_request import PullRequestModel
60 from rhodecode.model.settings import SettingsModel, VcsSettingsModel
60 from rhodecode.model.settings import SettingsModel, VcsSettingsModel
61
61
62 log = logging.getLogger(__name__)
62 log = logging.getLogger(__name__)
63
63
64
64
65 def extract_svn_txn_id(acl_repo_name, data):
65 def extract_svn_txn_id(acl_repo_name, data):
66 """
66 """
67 Helper method for extraction of svn txn_id from submitted XML data during
67 Helper method for extraction of svn txn_id from submitted XML data during
68 POST operations
68 POST operations
69 """
69 """
70 try:
70 try:
71 root = etree.fromstring(data)
71 root = etree.fromstring(data)
72 pat = re.compile(r'/txn/(?P<txn_id>.*)')
72 pat = re.compile(r'/txn/(?P<txn_id>.*)')
73 for el in root:
73 for el in root:
74 if el.tag == '{DAV:}source':
74 if el.tag == '{DAV:}source':
75 for sub_el in el:
75 for sub_el in el:
76 if sub_el.tag == '{DAV:}href':
76 if sub_el.tag == '{DAV:}href':
77 match = pat.search(sub_el.text)
77 match = pat.search(sub_el.text)
78 if match:
78 if match:
79 svn_tx_id = match.groupdict()['txn_id']
79 svn_tx_id = match.groupdict()['txn_id']
80 txn_id = rc_cache.utils.compute_key_from_params(
80 txn_id = rc_cache.utils.compute_key_from_params(
81 acl_repo_name, svn_tx_id)
81 acl_repo_name, svn_tx_id)
82 return txn_id
82 return txn_id
83 except Exception:
83 except Exception:
84 log.exception('Failed to extract txn_id')
84 log.exception('Failed to extract txn_id')
85
85
86
86
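A self-contained sketch of the extraction above against an illustrative SVN MERGE body (the href layout follows Subversion's DAV protocol, but the repo path and txn id are made up, and the cache-key step is omitted):

import re
from lxml import etree

data = b'''<?xml version="1.0" encoding="utf-8"?>
<D:merge xmlns:D="DAV:">
  <D:source><D:href>/svn/myrepo/!svn/txn/120-5f</D:href></D:source>
</D:merge>'''

root = etree.fromstring(data)
pat = re.compile(r'/txn/(?P<txn_id>.*)')
for el in root:
    if el.tag == '{DAV:}source':
        for sub_el in el:
            if sub_el.tag == '{DAV:}href':
                match = pat.search(sub_el.text)
                assert match.groupdict()['txn_id'] == '120-5f'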
87 def initialize_generator(factory):
87 def initialize_generator(factory):
88 """
88 """
89 Initializes the returned generator by draining its first element.
89 Initializes the returned generator by draining its first element.
90
90
91 This can be used to give a generator an initializer, which is the code
91 This can be used to give a generator an initializer, which is the code
92 up to the first yield statement. This decorator enforces that the first
92 up to the first yield statement. This decorator enforces that the first
93 produced element has the value ``"__init__"`` to make its special
93 produced element has the value ``"__init__"`` to make its special
94 purpose very explicit in the calling code.
94 purpose very explicit in the calling code.
95 """
95 """
96
96
97 @wraps(factory)
97 @wraps(factory)
98 def wrapper(*args, **kwargs):
98 def wrapper(*args, **kwargs):
99 gen = factory(*args, **kwargs)
99 gen = factory(*args, **kwargs)
100 try:
100 try:
101 init = next(gen)
101 init = next(gen)
102 except StopIteration:
102 except StopIteration:
103 raise ValueError('Generator must yield at least one element.')
103 raise ValueError('Generator must yield at least one element.')
104 if init != "__init__":
104 if init != "__init__":
105 raise ValueError('First yielded element must be "__init__".')
105 raise ValueError('First yielded element must be "__init__".')
106 return gen
106 return gen
107 return wrapper
107 return wrapper
108
108
109
109
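A tiny usage sketch of the decorator above (the generator name is illustrative): everything before the first yield runs eagerly when the factory is called, and the "__init__" marker never reaches the caller:

@initialize_generator
def stream_chunks():
    # setup code placed here runs before the wrapped call returns
    yield "__init__"
    yield "chunk-1"
    yield "chunk-2"

gen = stream_chunks()
assert list(gen) == ["chunk-1", "chunk-2"]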
110 class SimpleVCS(object):
110 class SimpleVCS(object):
111 """Common functionality for SCM HTTP handlers."""
111 """Common functionality for SCM HTTP handlers."""
112
112
113 SCM = 'unknown'
113 SCM = 'unknown'
114
114
115 acl_repo_name = None
115 acl_repo_name = None
116 url_repo_name = None
116 url_repo_name = None
117 vcs_repo_name = None
117 vcs_repo_name = None
118 rc_extras = {}
118 rc_extras = {}
119
119
120 # We have to handle requests to shadow repositories differently than requests
120 # We have to handle requests to shadow repositories differently than requests
121 # to normal repositories. Therefore we have to distinguish them. To do this
121 # to normal repositories. Therefore we have to distinguish them. To do this
122 # we use this regex which will match only on URLs pointing to shadow
122 # we use this regex which will match only on URLs pointing to shadow
123 # repositories.
123 # repositories.
124 shadow_repo_re = re.compile(
124 shadow_repo_re = re.compile(
125 '(?P<groups>(?:{slug_pat}/)*)' # repo groups
125 '(?P<groups>(?:{slug_pat}/)*)' # repo groups
126 '(?P<target>{slug_pat})/' # target repo
126 '(?P<target>{slug_pat})/' # target repo
127 'pull-request/(?P<pr_id>\d+)/' # pull request
127 'pull-request/(?P<pr_id>\\d+)/' # pull request
128 'repository$' # shadow repo
128 'repository$' # shadow repo
129 .format(slug_pat=SLUG_RE.pattern))
129 .format(slug_pat=SLUG_RE.pattern))
130
130
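To make the regex concrete, a hedged sketch using a simplified [^/]+ slug in place of the internal SLUG_RE (so the accepted slug characters here are only approximate):

import re

slug_pat = r'[^/]+'  # simplified stand-in for SLUG_RE.pattern
shadow_re = re.compile(
    r'(?P<groups>(?:{slug_pat}/)*)'
    r'(?P<target>{slug_pat})/'
    r'pull-request/(?P<pr_id>\d+)/'
    r'repository$'.format(slug_pat=slug_pat))

m = shadow_re.match('RepoGroup/MyRepo/pull-request/3/repository')
assert m.groupdict() == {'groups': 'RepoGroup/', 'target': 'MyRepo', 'pr_id': '3'}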
131 def __init__(self, config, registry):
131 def __init__(self, config, registry):
132 self.registry = registry
132 self.registry = registry
133 self.config = config
133 self.config = config
134 # re-populated by specialized middleware
134 # re-populated by specialized middleware
135 self.repo_vcs_config = base.Config()
135 self.repo_vcs_config = base.Config()
136
136
137 rc_settings = SettingsModel().get_all_settings(cache=True, from_request=False)
137 rc_settings = SettingsModel().get_all_settings(cache=True, from_request=False)
138 realm = rc_settings.get('rhodecode_realm') or 'RhodeCode AUTH'
138 realm = rc_settings.get('rhodecode_realm') or 'RhodeCode AUTH'
139
139
140 # authenticate this VCS request using authfunc
140 # authenticate this VCS request using authfunc
141 auth_ret_code_detection = \
141 auth_ret_code_detection = \
142 str2bool(self.config.get('auth_ret_code_detection', False))
142 str2bool(self.config.get('auth_ret_code_detection', False))
143 self.authenticate = BasicAuth(
143 self.authenticate = BasicAuth(
144 '', authenticate, registry, config.get('auth_ret_code'),
144 '', authenticate, registry, config.get('auth_ret_code'),
145 auth_ret_code_detection, rc_realm=realm)
145 auth_ret_code_detection, rc_realm=realm)
146 self.ip_addr = '0.0.0.0'
146 self.ip_addr = '0.0.0.0'
147
147
148 @LazyProperty
148 @LazyProperty
149 def global_vcs_config(self):
149 def global_vcs_config(self):
150 try:
150 try:
151 return VcsSettingsModel().get_ui_settings_as_config_obj()
151 return VcsSettingsModel().get_ui_settings_as_config_obj()
152 except Exception:
152 except Exception:
153 return base.Config()
153 return base.Config()
154
154
155 @property
155 @property
156 def base_path(self):
156 def base_path(self):
157 settings_path = self.repo_vcs_config.get(*VcsSettingsModel.PATH_SETTING)
157 settings_path = self.repo_vcs_config.get(*VcsSettingsModel.PATH_SETTING)
158
158
159 if not settings_path:
159 if not settings_path:
160 settings_path = self.global_vcs_config.get(*VcsSettingsModel.PATH_SETTING)
160 settings_path = self.global_vcs_config.get(*VcsSettingsModel.PATH_SETTING)
161
161
162 if not settings_path:
162 if not settings_path:
163 # try, maybe it was passed in explicitly as a config option
163 # try, maybe it was passed in explicitly as a config option
164 settings_path = self.config.get('base_path')
164 settings_path = self.config.get('base_path')
165
165
166 if not settings_path:
166 if not settings_path:
167 raise ValueError('FATAL: base_path is empty')
167 raise ValueError('FATAL: base_path is empty')
168 return settings_path
168 return settings_path
169
169
170 def set_repo_names(self, environ):
170 def set_repo_names(self, environ):
171 """
171 """
172 This will populate the attributes acl_repo_name, url_repo_name,
172 This will populate the attributes acl_repo_name, url_repo_name,
173 vcs_repo_name and is_shadow_repo. In case of requests to normal (non
173 vcs_repo_name and is_shadow_repo. In case of requests to normal (non
174 shadow) repositories all names are equal. In case of requests to a
174 shadow) repositories all names are equal. In case of requests to a
175 shadow repository the acl-name points to the target repo of the pull
175 shadow repository the acl-name points to the target repo of the pull
176 request and the vcs-name points to the shadow repo file system path.
176 request and the vcs-name points to the shadow repo file system path.
177 The url-name is always the URL used by the vcs client program.
177 The url-name is always the URL used by the vcs client program.
178
178
179 Example in case of a shadow repo:
179 Example in case of a shadow repo:
180 acl_repo_name = RepoGroup/MyRepo
180 acl_repo_name = RepoGroup/MyRepo
181 url_repo_name = RepoGroup/MyRepo/pull-request/3/repository
181 url_repo_name = RepoGroup/MyRepo/pull-request/3/repository
182 vcs_repo_name = /repo/base/path/RepoGroup/.__shadow_MyRepo_pr-3
182 vcs_repo_name = /repo/base/path/RepoGroup/.__shadow_MyRepo_pr-3
183 """
183 """
184 # First we set the repo name from URL for all attributes. This is the
184 # First we set the repo name from URL for all attributes. This is the
185 # default if handling normal (non shadow) repo requests.
185 # default if handling normal (non shadow) repo requests.
186 self.url_repo_name = self._get_repository_name(environ)
186 self.url_repo_name = self._get_repository_name(environ)
187 self.acl_repo_name = self.vcs_repo_name = self.url_repo_name
187 self.acl_repo_name = self.vcs_repo_name = self.url_repo_name
188 self.is_shadow_repo = False
188 self.is_shadow_repo = False
189
189
190 # Check if this is a request to a shadow repository.
190 # Check if this is a request to a shadow repository.
191 match = self.shadow_repo_re.match(self.url_repo_name)
191 match = self.shadow_repo_re.match(self.url_repo_name)
192 if match:
192 if match:
193 match_dict = match.groupdict()
193 match_dict = match.groupdict()
194
194
195 # Build acl repo name from regex match.
195 # Build acl repo name from regex match.
196 acl_repo_name = safe_unicode('{groups}{target}'.format(
196 acl_repo_name = safe_unicode('{groups}{target}'.format(
197 groups=match_dict['groups'] or '',
197 groups=match_dict['groups'] or '',
198 target=match_dict['target']))
198 target=match_dict['target']))
199
199
200 # Retrieve pull request instance by ID from regex match.
200 # Retrieve pull request instance by ID from regex match.
201 pull_request = PullRequest.get(match_dict['pr_id'])
201 pull_request = PullRequest.get(match_dict['pr_id'])
202
202
203 # Only proceed if we got a pull request and if acl repo name from
203 # Only proceed if we got a pull request and if acl repo name from
204 # URL equals the target repo name of the pull request.
204 # URL equals the target repo name of the pull request.
205 if pull_request and (acl_repo_name == pull_request.target_repo.repo_name):
205 if pull_request and (acl_repo_name == pull_request.target_repo.repo_name):
206
206
207 # Get file system path to shadow repository.
207 # Get file system path to shadow repository.
208 workspace_id = PullRequestModel()._workspace_id(pull_request)
208 workspace_id = PullRequestModel()._workspace_id(pull_request)
209 vcs_repo_name = pull_request.target_repo.get_shadow_repository_path(workspace_id)
209 vcs_repo_name = pull_request.target_repo.get_shadow_repository_path(workspace_id)
210
210
211 # Store names for later usage.
211 # Store names for later usage.
212 self.vcs_repo_name = vcs_repo_name
212 self.vcs_repo_name = vcs_repo_name
213 self.acl_repo_name = acl_repo_name
213 self.acl_repo_name = acl_repo_name
214 self.is_shadow_repo = True
214 self.is_shadow_repo = True
215
215
216 log.debug('Setting all VCS repository names: %s', {
216 log.debug('Setting all VCS repository names: %s', {
217 'acl_repo_name': self.acl_repo_name,
217 'acl_repo_name': self.acl_repo_name,
218 'url_repo_name': self.url_repo_name,
218 'url_repo_name': self.url_repo_name,
219 'vcs_repo_name': self.vcs_repo_name,
219 'vcs_repo_name': self.vcs_repo_name,
220 })
220 })
221
221
222 @property
222 @property
223 def scm_app(self):
223 def scm_app(self):
224 custom_implementation = self.config['vcs.scm_app_implementation']
224 custom_implementation = self.config['vcs.scm_app_implementation']
225 if custom_implementation == 'http':
225 if custom_implementation == 'http':
226 log.debug('Using HTTP implementation of scm app.')
226 log.debug('Using HTTP implementation of scm app.')
227 scm_app_impl = scm_app_http
227 scm_app_impl = scm_app_http
228 else:
228 else:
229 log.debug('Using custom implementation of scm_app: "{}"'.format(
229 log.debug('Using custom implementation of scm_app: "{}"'.format(
230 custom_implementation))
230 custom_implementation))
231 scm_app_impl = importlib.import_module(custom_implementation)
231 scm_app_impl = importlib.import_module(custom_implementation)
232 return scm_app_impl
232 return scm_app_impl
233
233
234 def _get_by_id(self, repo_name):
234 def _get_by_id(self, repo_name):
235 """
235 """
236 Gets the special pattern _<ID> from the clone URL and tries to replace it
236 Gets the special pattern _<ID> from the clone URL and tries to replace it
237 with a repository name, to support non-changeable _<ID> URLs
237 with a repository name, to support non-changeable _<ID> URLs
238 """
238 """
239
239
240 data = repo_name.split('/')
240 data = repo_name.split('/')
241 if len(data) >= 2:
241 if len(data) >= 2:
242 from rhodecode.model.repo import RepoModel
242 from rhodecode.model.repo import RepoModel
243 by_id_match = RepoModel().get_repo_by_id(repo_name)
243 by_id_match = RepoModel().get_repo_by_id(repo_name)
244 if by_id_match:
244 if by_id_match:
245 data[1] = by_id_match.repo_name
245 data[1] = by_id_match.repo_name
246
246
247 return safe_str('/'.join(data))
247 return safe_str('/'.join(data))
248
248
249 def _invalidate_cache(self, repo_name):
249 def _invalidate_cache(self, repo_name):
250 """
250 """
251 Sets the cache for this repository to be invalidated on next access
251 Sets the cache for this repository to be invalidated on next access
252
252
253 :param repo_name: full repo name, also a cache key
253 :param repo_name: full repo name, also a cache key
254 """
254 """
255 ScmModel().mark_for_invalidation(repo_name)
255 ScmModel().mark_for_invalidation(repo_name)
256
256
257 def is_valid_and_existing_repo(self, repo_name, base_path, scm_type):
257 def is_valid_and_existing_repo(self, repo_name, base_path, scm_type):
258 db_repo = Repository.get_by_repo_name(repo_name)
258 db_repo = Repository.get_by_repo_name(repo_name)
259 if not db_repo:
259 if not db_repo:
260 log.debug('Repository `%s` not found inside the database.',
260 log.debug('Repository `%s` not found inside the database.',
261 repo_name)
261 repo_name)
262 return False
262 return False
263
263
264 if db_repo.repo_type != scm_type:
264 if db_repo.repo_type != scm_type:
265 log.warning(
265 log.warning(
266 'Repository `%s` has incorrect scm_type, expected %s got %s',
266 'Repository `%s` has incorrect scm_type, expected %s got %s',
267 repo_name, db_repo.repo_type, scm_type)
267 repo_name, db_repo.repo_type, scm_type)
268 return False
268 return False
269
269
270 config = db_repo._config
270 config = db_repo._config
271 config.set('extensions', 'largefiles', '')
271 config.set('extensions', 'largefiles', '')
272 return is_valid_repo(
272 return is_valid_repo(
273 repo_name, base_path,
273 repo_name, base_path,
274 explicit_scm=scm_type, expect_scm=scm_type, config=config)
274 explicit_scm=scm_type, expect_scm=scm_type, config=config)
275
275
276 def valid_and_active_user(self, user):
276 def valid_and_active_user(self, user):
277 """
277 """
278 Checks that the user is not empty, and if it's actually a user object,
278 Checks that the user is not empty, and if it's actually a user object,
279 checks whether it's active.
279 checks whether it's active.
280
280
281 :param user: user object or None
281 :param user: user object or None
282 :return: boolean
282 :return: boolean
283 """
283 """
284 if user is None:
284 if user is None:
285 return False
285 return False
286
286
287 elif user.active:
287 elif user.active:
288 return True
288 return True
289
289
290 return False
290 return False
291
291
292 @property
292 @property
293 def is_shadow_repo_dir(self):
293 def is_shadow_repo_dir(self):
294 return os.path.isdir(self.vcs_repo_name)
294 return os.path.isdir(self.vcs_repo_name)
295
295
296 def _check_permission(self, action, user, auth_user, repo_name, ip_addr=None,
296 def _check_permission(self, action, user, auth_user, repo_name, ip_addr=None,
297 plugin_id='', plugin_cache_active=False, cache_ttl=0):
297 plugin_id='', plugin_cache_active=False, cache_ttl=0):
298 """
298 """
299 Checks permissions using action (push/pull) user and repository
299 Checks permissions using action (push/pull) user and repository
300 name. If plugin_cache and ttl is set it will use the plugin which
300 name. If plugin_cache and ttl is set it will use the plugin which
301 authenticated the user to store the cached permissions result for N
301 authenticated the user to store the cached permissions result for N
302 amount of seconds as in cache_ttl
302 amount of seconds as in cache_ttl
303
303
304 :param action: push or pull action
304 :param action: push or pull action
305 :param user: user instance
305 :param user: user instance
306 :param repo_name: repository name
306 :param repo_name: repository name
307 """
307 """
308
308
309 log.debug('AUTH_CACHE_TTL for permissions `%s` active: %s (TTL: %s)',
309 log.debug('AUTH_CACHE_TTL for permissions `%s` active: %s (TTL: %s)',
310 plugin_id, plugin_cache_active, cache_ttl)
310 plugin_id, plugin_cache_active, cache_ttl)
311
311
312 user_id = user.user_id
312 user_id = user.user_id
313 cache_namespace_uid = 'cache_user_auth.{}'.format(user_id)
313 cache_namespace_uid = 'cache_user_auth.{}'.format(user_id)
314 region = rc_cache.get_or_create_region('cache_perms', cache_namespace_uid)
314 region = rc_cache.get_or_create_region('cache_perms', cache_namespace_uid)
315
315
316 @region.conditional_cache_on_arguments(namespace=cache_namespace_uid,
316 @region.conditional_cache_on_arguments(namespace=cache_namespace_uid,
317 expiration_time=cache_ttl,
317 expiration_time=cache_ttl,
318 condition=plugin_cache_active)
318 condition=plugin_cache_active)
319 def compute_perm_vcs(
319 def compute_perm_vcs(
320 cache_name, plugin_id, action, user_id, repo_name, ip_addr):
320 cache_name, plugin_id, action, user_id, repo_name, ip_addr):
321
321
322 log.debug('auth: calculating permission access now...')
322 log.debug('auth: calculating permission access now...')
323 # check IP
323 # check IP
324 inherit = user.inherit_default_permissions
324 inherit = user.inherit_default_permissions
325 ip_allowed = AuthUser.check_ip_allowed(
325 ip_allowed = AuthUser.check_ip_allowed(
326 user_id, ip_addr, inherit_from_default=inherit)
326 user_id, ip_addr, inherit_from_default=inherit)
327 if ip_allowed:
327 if ip_allowed:
328 log.info('Access for IP:%s allowed', ip_addr)
328 log.info('Access for IP:%s allowed', ip_addr)
329 else:
329 else:
330 return False
330 return False
331
331
332 if action == 'push':
332 if action == 'push':
333 perms = ('repository.write', 'repository.admin')
333 perms = ('repository.write', 'repository.admin')
334 if not HasPermissionAnyMiddleware(*perms)(auth_user, repo_name):
334 if not HasPermissionAnyMiddleware(*perms)(auth_user, repo_name):
335 return False
335 return False
336
336
337 else:
337 else:
338 # any other action need at least read permission
338 # any other action need at least read permission
339 perms = (
339 perms = (
340 'repository.read', 'repository.write', 'repository.admin')
340 'repository.read', 'repository.write', 'repository.admin')
341 if not HasPermissionAnyMiddleware(*perms)(auth_user, repo_name):
341 if not HasPermissionAnyMiddleware(*perms)(auth_user, repo_name):
342 return False
342 return False
343
343
344 return True
344 return True
345
345
346 start = time.time()
346 start = time.time()
347 log.debug('Running plugin `%s` permissions check', plugin_id)
347 log.debug('Running plugin `%s` permissions check', plugin_id)
348
348
349 # for environ based auth, password can be empty, but then the validation is
349 # for environ based auth, password can be empty, but then the validation is
350 # on the server that fills in the env data needed for authentication
350 # on the server that fills in the env data needed for authentication
351 perm_result = compute_perm_vcs(
351 perm_result = compute_perm_vcs(
352 'vcs_permissions', plugin_id, action, user.user_id, repo_name, ip_addr)
352 'vcs_permissions', plugin_id, action, user.user_id, repo_name, ip_addr)
353
353
354 auth_time = time.time() - start
354 auth_time = time.time() - start
355 log.debug('Permissions for plugin `%s` completed in %.4fs, '
355 log.debug('Permissions for plugin `%s` completed in %.4fs, '
356 'expiration time of fetched cache %.1fs.',
356 'expiration time of fetched cache %.1fs.',
357 plugin_id, auth_time, cache_ttl)
357 plugin_id, auth_time, cache_ttl)
358
358
359 return perm_result
359 return perm_result
360
360
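The permission tuples inside compute_perm_vcs above boil down to a simple rule; restated standalone (the helper name is made up):

def required_repo_perms(action):
    # push needs write access; any other action needs at least read
    if action == 'push':
        return ('repository.write', 'repository.admin')
    return ('repository.read', 'repository.write', 'repository.admin')

assert 'repository.read' not in required_repo_perms('push')
assert 'repository.read' in required_repo_perms('pull')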
361 def _get_http_scheme(self, environ):
361 def _get_http_scheme(self, environ):
362 try:
362 try:
363 return environ['wsgi.url_scheme']
363 return environ['wsgi.url_scheme']
364 except Exception:
364 except Exception:
365 log.exception('Failed to read http scheme')
365 log.exception('Failed to read http scheme')
366 return 'http'
366 return 'http'
367
367
368 def _check_ssl(self, environ, start_response):
368 def _check_ssl(self, environ, start_response):
369 """
369 """
370 Checks the SSL check flag and returns False if SSL is required
370 Checks the SSL check flag and returns False if SSL is required
371 but not present, True otherwise
371 but not present, True otherwise
372 """
372 """
373 org_proto = environ['wsgi._org_proto']
373 org_proto = environ['wsgi._org_proto']
374 # check if SSL is required; if it's missing this is a bad request
374 # check if SSL is required; if it's missing this is a bad request
375 require_ssl = str2bool(self.repo_vcs_config.get('web', 'push_ssl'))
375 require_ssl = str2bool(self.repo_vcs_config.get('web', 'push_ssl'))
376 if require_ssl and org_proto == 'http':
376 if require_ssl and org_proto == 'http':
377 log.debug(
377 log.debug(
378 'Bad request: detected protocol is `%s` and '
378 'Bad request: detected protocol is `%s` and '
379 'SSL/HTTPS is required.', org_proto)
379 'SSL/HTTPS is required.', org_proto)
380 return False
380 return False
381 return True
381 return True
382
382
383 def _get_default_cache_ttl(self):
383 def _get_default_cache_ttl(self):
384 # take AUTH_CACHE_TTL from the `rhodecode` auth plugin
384 # take AUTH_CACHE_TTL from the `rhodecode` auth plugin
385 plugin = loadplugin('egg:rhodecode-enterprise-ce#rhodecode')
385 plugin = loadplugin('egg:rhodecode-enterprise-ce#rhodecode')
386 plugin_settings = plugin.get_settings()
386 plugin_settings = plugin.get_settings()
387 plugin_cache_active, cache_ttl = plugin.get_ttl_cache(
387 plugin_cache_active, cache_ttl = plugin.get_ttl_cache(
388 plugin_settings) or (False, 0)
388 plugin_settings) or (False, 0)
389 return plugin_cache_active, cache_ttl
389 return plugin_cache_active, cache_ttl
390
390
391 def __call__(self, environ, start_response):
391 def __call__(self, environ, start_response):
392 try:
392 try:
393 return self._handle_request(environ, start_response)
393 return self._handle_request(environ, start_response)
394 except Exception:
394 except Exception:
395 log.exception("Exception while handling request")
395 log.exception("Exception while handling request")
396 appenlight.track_exception(environ)
396 appenlight.track_exception(environ)
397 return HTTPInternalServerError()(environ, start_response)
397 return HTTPInternalServerError()(environ, start_response)
398 finally:
398 finally:
399 meta.Session.remove()
399 meta.Session.remove()
400
400
401 def _handle_request(self, environ, start_response):
401 def _handle_request(self, environ, start_response):
402 if not self._check_ssl(environ, start_response):
402 if not self._check_ssl(environ, start_response):
403 reason = ('SSL required, while RhodeCode was unable '
403 reason = ('SSL required, while RhodeCode was unable '
404 'to detect this as an SSL request')
404 'to detect this as an SSL request')
405 log.debug('User not allowed to proceed, %s', reason)
405 log.debug('User not allowed to proceed, %s', reason)
406 return HTTPNotAcceptable(reason)(environ, start_response)
406 return HTTPNotAcceptable(reason)(environ, start_response)
407
407
408 if not self.url_repo_name:
408 if not self.url_repo_name:
409 log.warning('Repository name is empty: %s', self.url_repo_name)
409 log.warning('Repository name is empty: %s', self.url_repo_name)
410 # failed to get repo name, we fail now
410 # failed to get repo name, we fail now
411 return HTTPNotFound()(environ, start_response)
411 return HTTPNotFound()(environ, start_response)
412 log.debug('Extracted repo name is %s', self.url_repo_name)
412 log.debug('Extracted repo name is %s', self.url_repo_name)
413
413
414 ip_addr = get_ip_addr(environ)
414 ip_addr = get_ip_addr(environ)
415 user_agent = get_user_agent(environ)
415 user_agent = get_user_agent(environ)
416 username = None
416 username = None
417
417
418 # skip passing error to error controller
418 # skip passing error to error controller
419 environ['pylons.status_code_redirect'] = True
419 environ['pylons.status_code_redirect'] = True
420
420
421 # ======================================================================
421 # ======================================================================
422 # GET ACTION PULL or PUSH
422 # GET ACTION PULL or PUSH
423 # ======================================================================
423 # ======================================================================
424 action = self._get_action(environ)
424 action = self._get_action(environ)
425
425
426 # ======================================================================
426 # ======================================================================
427 # Check if this is a request to a shadow repository of a pull request.
427 # Check if this is a request to a shadow repository of a pull request.
428 # In this case only pull action is allowed.
428 # In this case only pull action is allowed.
429 # ======================================================================
429 # ======================================================================
430 if self.is_shadow_repo and action != 'pull':
430 if self.is_shadow_repo and action != 'pull':
431 reason = 'Only pull action is allowed for shadow repositories.'
431 reason = 'Only pull action is allowed for shadow repositories.'
432 log.debug('User not allowed to proceed, %s', reason)
432 log.debug('User not allowed to proceed, %s', reason)
433 return HTTPNotAcceptable(reason)(environ, start_response)
433 return HTTPNotAcceptable(reason)(environ, start_response)
434
434
435 # Check if the shadow repo actually exists, in case someone refers
435 # Check if the shadow repo actually exists, in case someone refers
436 # to it after it has been deleted because of a successful merge.
436 # to it after it has been deleted because of a successful merge.
437 if self.is_shadow_repo and not self.is_shadow_repo_dir:
437 if self.is_shadow_repo and not self.is_shadow_repo_dir:
438 log.debug(
438 log.debug(
439 'Shadow repo detected, and shadow repo dir `%s` is missing',
439 'Shadow repo detected, and shadow repo dir `%s` is missing',
440 self.vcs_repo_name)
440 self.vcs_repo_name)
441 return HTTPNotFound()(environ, start_response)
441 return HTTPNotFound()(environ, start_response)
442
442
443 # ======================================================================
443 # ======================================================================
444 # CHECK ANONYMOUS PERMISSION
444 # CHECK ANONYMOUS PERMISSION
445 # ======================================================================
445 # ======================================================================
446 detect_force_push = False
446 detect_force_push = False
447 check_branch_perms = False
447 check_branch_perms = False
448 if action in ['pull', 'push']:
448 if action in ['pull', 'push']:
449 user_obj = anonymous_user = User.get_default_user()
449 user_obj = anonymous_user = User.get_default_user()
450 auth_user = user_obj.AuthUser()
450 auth_user = user_obj.AuthUser()
451 username = anonymous_user.username
451 username = anonymous_user.username
452 if anonymous_user.active:
452 if anonymous_user.active:
453 plugin_cache_active, cache_ttl = self._get_default_cache_ttl()
453 plugin_cache_active, cache_ttl = self._get_default_cache_ttl()
454 # ONLY check permissions if the user is activated
454 # ONLY check permissions if the user is activated
455 anonymous_perm = self._check_permission(
455 anonymous_perm = self._check_permission(
456 action, anonymous_user, auth_user, self.acl_repo_name, ip_addr,
456 action, anonymous_user, auth_user, self.acl_repo_name, ip_addr,
457 plugin_id='anonymous_access',
457 plugin_id='anonymous_access',
458 plugin_cache_active=plugin_cache_active,
458 plugin_cache_active=plugin_cache_active,
459 cache_ttl=cache_ttl,
459 cache_ttl=cache_ttl,
460 )
460 )
461 else:
461 else:
462 anonymous_perm = False
462 anonymous_perm = False
463
463
464 if not anonymous_user.active or not anonymous_perm:
464 if not anonymous_user.active or not anonymous_perm:
465 if not anonymous_user.active:
465 if not anonymous_user.active:
466 log.debug('Anonymous access is disabled, running '
466 log.debug('Anonymous access is disabled, running '
467 'authentication')
467 'authentication')
468
468
469 if not anonymous_perm:
469 if not anonymous_perm:
470 log.debug('Not enough credentials to access this '
470 log.debug('Not enough credentials to access this '
471 'repository as anonymous user')
471 'repository as anonymous user')
472
472
473 username = None
473 username = None
474 # ==============================================================
474 # ==============================================================
475 # DEFAULT PERM FAILED OR ANONYMOUS ACCESS IS DISABLED SO WE
475 # DEFAULT PERM FAILED OR ANONYMOUS ACCESS IS DISABLED SO WE
476 # NEED TO AUTHENTICATE AND ASK FOR AUTH USER PERMISSIONS
476 # NEED TO AUTHENTICATE AND ASK FOR AUTH USER PERMISSIONS
477 # ==============================================================
477 # ==============================================================
478
478
479 # try to auth based on environ, container auth methods
479 # try to auth based on environ, container auth methods
480 log.debug('Running PRE-AUTH for container based authentication')
480 log.debug('Running PRE-AUTH for container based authentication')
481 pre_auth = authenticate(
481 pre_auth = authenticate(
482 '', '', environ, VCS_TYPE, registry=self.registry,
482 '', '', environ, VCS_TYPE, registry=self.registry,
483 acl_repo_name=self.acl_repo_name)
483 acl_repo_name=self.acl_repo_name)
484 if pre_auth and pre_auth.get('username'):
484 if pre_auth and pre_auth.get('username'):
485 username = pre_auth['username']
485 username = pre_auth['username']
486 log.debug('PRE-AUTH got %s as username', username)
486 log.debug('PRE-AUTH got %s as username', username)
487 if pre_auth:
487 if pre_auth:
488 log.debug('PRE-AUTH successful from %s',
488 log.debug('PRE-AUTH successful from %s',
489 pre_auth.get('auth_data', {}).get('_plugin'))
489 pre_auth.get('auth_data', {}).get('_plugin'))
490
490
491 # If not authenticated by the container, run basic auth;
491 # If not authenticated by the container, run basic auth;
492 # before that, inject the calling repo_name for special scope checks
492 # before that, inject the calling repo_name for special scope checks
493 self.authenticate.acl_repo_name = self.acl_repo_name
493 self.authenticate.acl_repo_name = self.acl_repo_name
494
494
495 plugin_cache_active, cache_ttl = False, 0
495 plugin_cache_active, cache_ttl = False, 0
496 plugin = None
496 plugin = None
497 if not username:
497 if not username:
498 self.authenticate.realm = self.authenticate.get_rc_realm()
498 self.authenticate.realm = self.authenticate.get_rc_realm()
499
499
500 try:
500 try:
501 auth_result = self.authenticate(environ)
501 auth_result = self.authenticate(environ)
502 except (UserCreationError, NotAllowedToCreateUserError) as e:
502 except (UserCreationError, NotAllowedToCreateUserError) as e:
503 log.error(e)
503 log.error(e)
504 reason = safe_str(e)
504 reason = safe_str(e)
505 return HTTPNotAcceptable(reason)(environ, start_response)
505 return HTTPNotAcceptable(reason)(environ, start_response)
506
506
507 if isinstance(auth_result, dict):
507 if isinstance(auth_result, dict):
508 AUTH_TYPE.update(environ, 'basic')
508 AUTH_TYPE.update(environ, 'basic')
509 REMOTE_USER.update(environ, auth_result['username'])
509 REMOTE_USER.update(environ, auth_result['username'])
510 username = auth_result['username']
510 username = auth_result['username']
511 plugin = auth_result.get('auth_data', {}).get('_plugin')
511 plugin = auth_result.get('auth_data', {}).get('_plugin')
512 log.info(
512 log.info(
513 'MAIN-AUTH successful for user `%s` from %s plugin',
513 'MAIN-AUTH successful for user `%s` from %s plugin',
514 username, plugin)
514 username, plugin)
515
515
516 plugin_cache_active, cache_ttl = auth_result.get(
516 plugin_cache_active, cache_ttl = auth_result.get(
517 'auth_data', {}).get('_ttl_cache') or (False, 0)
517 'auth_data', {}).get('_ttl_cache') or (False, 0)
518 else:
518 else:
519 return auth_result.wsgi_application(environ, start_response)
519 return auth_result.wsgi_application(environ, start_response)
520
520
521 # ==============================================================
521 # ==============================================================
522 # CHECK PERMISSIONS FOR THIS REQUEST USING GIVEN USERNAME
522 # CHECK PERMISSIONS FOR THIS REQUEST USING GIVEN USERNAME
523 # ==============================================================
523 # ==============================================================
524 user = User.get_by_username(username)
524 user = User.get_by_username(username)
525 if not self.valid_and_active_user(user):
525 if not self.valid_and_active_user(user):
526 return HTTPForbidden()(environ, start_response)
526 return HTTPForbidden()(environ, start_response)
527 username = user.username
527 username = user.username
528 user_id = user.user_id
528 user_id = user.user_id
529
529
530 # check user attributes for password change flag
530 # check user attributes for password change flag
531 user_obj = user
531 user_obj = user
532 auth_user = user_obj.AuthUser()
532 auth_user = user_obj.AuthUser()
533 if user_obj and user_obj.username != User.DEFAULT_USER and \
533 if user_obj and user_obj.username != User.DEFAULT_USER and \
534 user_obj.user_data.get('force_password_change'):
534 user_obj.user_data.get('force_password_change'):
535 reason = 'password change required'
535 reason = 'password change required'
536 log.debug('User not allowed to authenticate, %s', reason)
536 log.debug('User not allowed to authenticate, %s', reason)
537 return HTTPNotAcceptable(reason)(environ, start_response)
537 return HTTPNotAcceptable(reason)(environ, start_response)
538
538
539 # check permissions for this repository
539 # check permissions for this repository
540 perm = self._check_permission(
540 perm = self._check_permission(
541 action, user, auth_user, self.acl_repo_name, ip_addr,
541 action, user, auth_user, self.acl_repo_name, ip_addr,
542 plugin, plugin_cache_active, cache_ttl)
542 plugin, plugin_cache_active, cache_ttl)
543 if not perm:
543 if not perm:
544 return HTTPForbidden()(environ, start_response)
544 return HTTPForbidden()(environ, start_response)
545 environ['rc_auth_user_id'] = user_id
545 environ['rc_auth_user_id'] = user_id
546
546
547 if action == 'push':
547 if action == 'push':
548 perms = auth_user.get_branch_permissions(self.acl_repo_name)
548 perms = auth_user.get_branch_permissions(self.acl_repo_name)
549 if perms:
549 if perms:
550 check_branch_perms = True
550 check_branch_perms = True
551 detect_force_push = True
551 detect_force_push = True
552
552
553 # extras are injected into the UI object and are later available
553 # extras are injected into the UI object and are later available
554 # in hooks executed by RhodeCode
554 # in hooks executed by RhodeCode
555 check_locking = _should_check_locking(environ.get('QUERY_STRING'))
555 check_locking = _should_check_locking(environ.get('QUERY_STRING'))
556
556
557 extras = vcs_operation_context(
557 extras = vcs_operation_context(
558 environ, repo_name=self.acl_repo_name, username=username,
558 environ, repo_name=self.acl_repo_name, username=username,
559 action=action, scm=self.SCM, check_locking=check_locking,
559 action=action, scm=self.SCM, check_locking=check_locking,
560 is_shadow_repo=self.is_shadow_repo, check_branch_perms=check_branch_perms,
560 is_shadow_repo=self.is_shadow_repo, check_branch_perms=check_branch_perms,
561 detect_force_push=detect_force_push
561 detect_force_push=detect_force_push
562 )
562 )
563
563
564 # ======================================================================
564 # ======================================================================
565 # REQUEST HANDLING
565 # REQUEST HANDLING
566 # ======================================================================
566 # ======================================================================
567 repo_path = os.path.join(
567 repo_path = os.path.join(
568 safe_str(self.base_path), safe_str(self.vcs_repo_name))
568 safe_str(self.base_path), safe_str(self.vcs_repo_name))
569 log.debug('Repository path is %s', repo_path)
569 log.debug('Repository path is %s', repo_path)
570
570
571 fix_PATH()
571 fix_PATH()
572
572
573 log.info(
573 log.info(
574 '%s action on %s repo "%s" by "%s" from %s %s',
574 '%s action on %s repo "%s" by "%s" from %s %s',
575 action, self.SCM, safe_str(self.url_repo_name),
575 action, self.SCM, safe_str(self.url_repo_name),
576 safe_str(username), ip_addr, user_agent)
576 safe_str(username), ip_addr, user_agent)
577
577
578 return self._generate_vcs_response(
578 return self._generate_vcs_response(
579 environ, start_response, repo_path, extras, action)
579 environ, start_response, repo_path, extras, action)
580
580
581 @initialize_generator
581 @initialize_generator
582 def _generate_vcs_response(
582 def _generate_vcs_response(
583 self, environ, start_response, repo_path, extras, action):
583 self, environ, start_response, repo_path, extras, action):
584 """
584 """
585 Returns a generator for the response content.
585 Returns a generator for the response content.
586
586
587 This method is implemented as a generator, so that it can trigger
587 This method is implemented as a generator, so that it can trigger
588 the cache validation after all content sent back to the client. It
588 the cache validation after all content sent back to the client. It
589 also handles the locking exceptions which will be triggered when
589 also handles the locking exceptions which will be triggered when
590 the first chunk is produced by the underlying WSGI application.
590 the first chunk is produced by the underlying WSGI application.
591 """
591 """
592 txn_id = ''
592 txn_id = ''
593 if 'CONTENT_LENGTH' in environ and environ['REQUEST_METHOD'] == 'MERGE':
593 if 'CONTENT_LENGTH' in environ and environ['REQUEST_METHOD'] == 'MERGE':
594 # SVN case: we want to re-use the callback daemon port, so we use
594 # SVN case: we want to re-use the callback daemon port, so we use
595 # the txn_id; for this we peek at the body and still save it back
595 # the txn_id; for this we peek at the body and still save it back
596 # as wsgi.input
596 # as wsgi.input
597 data = environ['wsgi.input'].read()
597 data = environ['wsgi.input'].read()
598 environ['wsgi.input'] = StringIO(data)
598 environ['wsgi.input'] = io.BytesIO(data)
599 txn_id = extract_svn_txn_id(self.acl_repo_name, data)
599 txn_id = extract_svn_txn_id(self.acl_repo_name, data)
600
600
601 callback_daemon, extras = self._prepare_callback_daemon(
601 callback_daemon, extras = self._prepare_callback_daemon(
602 extras, environ, action, txn_id=txn_id)
602 extras, environ, action, txn_id=txn_id)
603 log.debug('HOOKS extras is %s', extras)
603 log.debug('HOOKS extras is %s', extras)
604
604
605 http_scheme = self._get_http_scheme(environ)
605 http_scheme = self._get_http_scheme(environ)
606
606
607 config = self._create_config(extras, self.acl_repo_name, scheme=http_scheme)
607 config = self._create_config(extras, self.acl_repo_name, scheme=http_scheme)
608 app = self._create_wsgi_app(repo_path, self.url_repo_name, config)
608 app = self._create_wsgi_app(repo_path, self.url_repo_name, config)
609 with callback_daemon:
609 with callback_daemon:
610 app.rc_extras = extras
610 app.rc_extras = extras
611
611
612 try:
612 try:
613 response = app(environ, start_response)
613 response = app(environ, start_response)
614 finally:
614 finally:
615 # This statement works together with the decorator
615 # This statement works together with the decorator
616 # "initialize_generator" above. The decorator ensures that
616 # "initialize_generator" above. The decorator ensures that
617 # we hit the first yield statement before the generator is
617 # we hit the first yield statement before the generator is
618 # returned back to the WSGI server. This is needed to
618 # returned back to the WSGI server. This is needed to
619 # ensure that the call to "app" above triggers the
619 # ensure that the call to "app" above triggers the
620 # needed callback to "start_response" before the
620 # needed callback to "start_response" before the
621 # generator is actually used.
621 # generator is actually used.
622 yield "__init__"
622 yield "__init__"
623
623
624 # iter content
624 # iter content
625 for chunk in response:
625 for chunk in response:
626 yield chunk
626 yield chunk
627
627
628 try:
628 try:
629 # invalidate cache on push
629 # invalidate cache on push
630 if action == 'push':
630 if action == 'push':
631 self._invalidate_cache(self.url_repo_name)
631 self._invalidate_cache(self.url_repo_name)
632 finally:
632 finally:
633 meta.Session.remove()
633 meta.Session.remove()
634
634
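A self-contained sketch of the peek-and-restore step above; WSGI exposes the request body as a byte stream, hence io.BytesIO (the function name and sample body are illustrative):

import io

def peek_body(environ):
    # read the body once, then put back a replayable copy for the real app
    data = environ['wsgi.input'].read()
    environ['wsgi.input'] = io.BytesIO(data)
    return data

environ = {'wsgi.input': io.BytesIO(b'<D:merge xmlns:D="DAV:"/>')}
data = peek_body(environ)
assert environ['wsgi.input'].read() == data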
635 def _get_repository_name(self, environ):
635 def _get_repository_name(self, environ):
636 """Get repository name out of the environmnent
636 """Get repository name out of the environmnent
637
637
638 :param environ: WSGI environment
638 :param environ: WSGI environment
639 """
639 """
640 raise NotImplementedError()
640 raise NotImplementedError()
641
641
642 def _get_action(self, environ):
642 def _get_action(self, environ):
643 """Map request commands into a pull or push command.
643 """Map request commands into a pull or push command.
644
644
645 :param environ: WSGI environment
645 :param environ: WSGI environment
646 """
646 """
647 raise NotImplementedError()
647 raise NotImplementedError()
648
648
649 def _create_wsgi_app(self, repo_path, repo_name, config):
649 def _create_wsgi_app(self, repo_path, repo_name, config):
650 """Return the WSGI app that will finally handle the request."""
650 """Return the WSGI app that will finally handle the request."""
651 raise NotImplementedError()
651 raise NotImplementedError()
652
652
653 def _create_config(self, extras, repo_name, scheme='http'):
653 def _create_config(self, extras, repo_name, scheme='http'):
654 """Create a safe config representation."""
654 """Create a safe config representation."""
655 raise NotImplementedError()
655 raise NotImplementedError()
656
656
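    # The four stubs above form a template-method interface that each backend
    # middleware fills in. A hypothetical minimal subclass (illustrative
    # names only; in RhodeCode the real implementations are the
    # Git/Mercurial/SVN middlewares) might look like:
    #
    #     class DummyVCS(SimpleVCS):  # assuming SimpleVCS is this base class
    #         def _get_repository_name(self, environ):
    #             return environ['PATH_INFO'].strip('/')
    #
    #         def _get_action(self, environ):
    #             meth = environ['REQUEST_METHOD']
    #             return 'pull' if meth in ('GET', 'HEAD') else 'push'
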
    def _should_use_callback_daemon(self, extras, environ, action):
        if extras.get('is_shadow_repo'):
            # we don't want to execute hooks or run the callback daemon for
            # shadow repos
            return False
        return True

    def _prepare_callback_daemon(self, extras, environ, action, txn_id=None):
        direct_calls = vcs_settings.HOOKS_DIRECT_CALLS
        if not self._should_use_callback_daemon(extras, environ, action):
            # disable callback daemon for actions that don't require it
            direct_calls = True

        return prepare_callback_daemon(
            extras, protocol=vcs_settings.HOOKS_PROTOCOL,
            host=vcs_settings.HOOKS_HOST, use_direct_calls=direct_calls, txn_id=txn_id)


def _should_check_locking(query_string):
    # This is kind of hacky, but because Mercurial handles bookmarks, phases
    # and obsolescence markers for a commit in separate client-server
    # transactions, we don't want to check locking on those operations.
    return query_string not in ['cmd=listkeys']
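
# Hypothetical usage of the guard above (illustration only, not part of the
# original file; in practice this would live in a test module):
assert _should_check_locking('cmd=unbundle') is True
assert _should_check_locking('cmd=listkeys') is False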