python3: fixed various code issues...
super-admin
r4973:5e52ba1a default

The requested changes are too big and content was truncated.

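The diff below is part of a Python 2 → 3 migration: `itertools.izip_longest` becomes `itertools.zip_longest`, `urlparse`-style helpers move to `urllib.parse`, a regex pattern gains a raw-string prefix, and `StringIO` is replaced by `io`. A minimal sketch of the central rename (the values below are illustrative, not from the diff):

    import itertools

    # Python 2 spelled this itertools.izip_longest; Python 3 renamed it
    pairs = dict(itertools.zip_longest(['a', 'b', 'c'], [1], fillvalue=None))
    assert pairs == {'a': 1, 'b': None, 'c': None}
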
@@ -1,578 +1,578 b''
1 1 # -*- coding: utf-8 -*-
2 2
3 3 # Copyright (C) 2011-2020 RhodeCode GmbH
4 4 #
5 5 # This program is free software: you can redistribute it and/or modify
6 6 # it under the terms of the GNU Affero General Public License, version 3
7 7 # (only), as published by the Free Software Foundation.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU Affero General Public License
15 15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 16 #
17 17 # This program is dual-licensed. If you wish to learn more about the
18 18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20 20
21 21 import itertools
22 22 import logging
23 23 import sys
24 24 import types
25 25 import fnmatch
26 26
27 27 import decorator
28 28 import venusian
29 29 from collections import OrderedDict
30 30
31 31 from pyramid.exceptions import ConfigurationError
32 32 from pyramid.renderers import render
33 33 from pyramid.response import Response
34 34 from pyramid.httpexceptions import HTTPNotFound
35 35
36 36 from rhodecode.api.exc import (
37 37 JSONRPCBaseError, JSONRPCError, JSONRPCForbidden, JSONRPCValidationError)
38 38 from rhodecode.apps._base import TemplateArgs
39 39 from rhodecode.lib.auth import AuthUser
40 40 from rhodecode.lib.base import get_ip_addr, attach_context_attributes
41 41 from rhodecode.lib.exc_tracking import store_exception
42 42 from rhodecode.lib.ext_json import json
43 43 from rhodecode.lib.utils2 import safe_str
44 44 from rhodecode.lib.plugins.utils import get_plugin_settings
45 45 from rhodecode.model.db import User, UserApiKeys
46 46
47 47 log = logging.getLogger(__name__)
48 48
49 49 DEFAULT_RENDERER = 'jsonrpc_renderer'
50 50 DEFAULT_URL = '/_admin/apiv2'
51 51
52 52
53 53 def find_methods(jsonrpc_methods, pattern):
54 54 matches = OrderedDict()
55 55 if not isinstance(pattern, (list, tuple)):
56 56 pattern = [pattern]
57 57
58 58 for single_pattern in pattern:
59 59 for method_name, method in jsonrpc_methods.items():
60 60 if fnmatch.fnmatch(method_name, single_pattern):
61 61 matches[method_name] = method
62 62 return matches
63 63
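find_methods does simple fnmatch-style globbing over the registered method names; a small illustration (the method names here are hypothetical):

    import fnmatch
    from collections import OrderedDict

    methods = OrderedDict([
        ('comment_commit', None), ('comment_pull_request', None), ('get_repo', None)])
    # '*comment*' selects every method with "comment" in its name
    matched = [name for name in methods if fnmatch.fnmatch(name, '*comment*')]
    assert matched == ['comment_commit', 'comment_pull_request']
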
64 64
65 65 class ExtJsonRenderer(object):
66 66 """
 67 67 	 Custom renderer that makes use of our ext_json lib
68 68
69 69 """
70 70
71 71 def __init__(self, serializer=json.dumps, **kw):
72 72 """ Any keyword arguments will be passed to the ``serializer``
73 73 function."""
74 74 self.serializer = serializer
75 75 self.kw = kw
76 76
77 77 def __call__(self, info):
78 78 """ Returns a plain JSON-encoded string with content-type
79 79 ``application/json``. The content-type may be overridden by
80 80 setting ``request.response.content_type``."""
81 81
82 82 def _render(value, system):
83 83 request = system.get('request')
84 84 if request is not None:
85 85 response = request.response
86 86 ct = response.content_type
87 87 if ct == response.default_content_type:
88 88 response.content_type = 'application/json'
89 89
90 90 return self.serializer(value, **self.kw)
91 91
92 92 return _render
93 93
94 94
95 95 def jsonrpc_response(request, result):
96 96 rpc_id = getattr(request, 'rpc_id', None)
97 97 response = request.response
98 98
99 99 # store content_type before render is called
100 100 ct = response.content_type
101 101
102 102 ret_value = ''
103 103 if rpc_id:
104 104 ret_value = {
105 105 'id': rpc_id,
106 106 'result': result,
107 107 'error': None,
108 108 }
109 109
 110 110 	 # fetch deprecation warnings, and store them inside the result
111 111 deprecation = getattr(request, 'rpc_deprecation', None)
112 112 if deprecation:
113 113 ret_value['DEPRECATION_WARNING'] = deprecation
114 114
115 115 raw_body = render(DEFAULT_RENDERER, ret_value, request=request)
116 116 response.body = safe_str(raw_body, response.charset)
117 117
118 118 if ct == response.default_content_type:
119 119 response.content_type = 'application/json'
120 120
121 121 return response
122 122
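For reference, jsonrpc_response and jsonrpc_error (below) produce the classic JSON-RPC envelope, where exactly one of result/error is non-null (the values here are illustrative):

    # successful call -- "id" echoes the request id
    success = {'id': 1, 'result': {'msg': 'ok'}, 'error': None}
    # failed call, as built by jsonrpc_error
    failure = {'id': 1, 'result': None, 'error': 'Invalid API KEY'}
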
123 123
124 124 def jsonrpc_error(request, message, retid=None, code=None, headers=None):
125 125 """
126 126 Generate a Response object with a JSON-RPC error body
127 127
128 128 :param code:
129 129 :param retid:
130 130 :param message:
131 131 """
132 132 err_dict = {'id': retid, 'result': None, 'error': message}
133 133 body = render(DEFAULT_RENDERER, err_dict, request=request).encode('utf-8')
134 134
135 135 return Response(
136 136 body=body,
137 137 status=code,
138 138 content_type='application/json',
139 139 headerlist=headers
140 140 )
141 141
142 142
143 143 def exception_view(exc, request):
144 144 rpc_id = getattr(request, 'rpc_id', None)
145 145
146 146 if isinstance(exc, JSONRPCError):
147 147 fault_message = safe_str(exc.message)
148 148 log.debug('json-rpc error rpc_id:%s "%s"', rpc_id, fault_message)
149 149 elif isinstance(exc, JSONRPCValidationError):
150 150 colander_exc = exc.colander_exception
 151 151 	 # TODO(marcink): think of a nicer way to serialize errors?
152 152 fault_message = colander_exc.asdict()
153 153 log.debug('json-rpc colander error rpc_id:%s "%s"', rpc_id, fault_message)
154 154 elif isinstance(exc, JSONRPCForbidden):
155 155 fault_message = 'Access was denied to this resource.'
156 156 log.warning('json-rpc forbidden call rpc_id:%s "%s"', rpc_id, fault_message)
157 157 elif isinstance(exc, HTTPNotFound):
158 158 method = request.rpc_method
159 159 log.debug('json-rpc method `%s` not found in list of '
160 160 'api calls: %s, rpc_id:%s',
161 161 method, request.registry.jsonrpc_methods.keys(), rpc_id)
162 162
163 163 similar = 'none'
164 164 try:
 165 165 	 similar_patterns = ['*{}*'.format(x) for x in method.split('_')]
 166 166 	 similar_found = find_methods(
 167 167 	 request.registry.jsonrpc_methods, similar_patterns)
168 168 similar = ', '.join(similar_found.keys()) or similar
169 169 except Exception:
170 170 # make the whole above block safe
171 171 pass
172 172
173 173 fault_message = "No such method: {}. Similar methods: {}".format(
174 174 method, similar)
175 175 else:
176 176 fault_message = 'undefined error'
177 177 exc_info = exc.exc_info()
178 178 store_exception(id(exc_info), exc_info, prefix='rhodecode-api')
179 179
180 180 statsd = request.registry.statsd
181 181 if statsd:
182 182 exc_type = "{}.{}".format(exc.__class__.__module__, exc.__class__.__name__)
183 183 statsd.incr('rhodecode_exception_total',
184 184 tags=["exc_source:api", "type:{}".format(exc_type)])
185 185
186 186 return jsonrpc_error(request, fault_message, rpc_id)
187 187
188 188
189 189 def request_view(request):
190 190 """
191 191 Main request handling method. It handles all logic to call a specific
192 192 exposed method
193 193 """
194 194 # cython compatible inspect
195 195 from rhodecode.config.patches import inspect_getargspec
196 196 inspect = inspect_getargspec()
197 197
198 198 # check if we can find this session using api_key, get_by_auth_token
199 199 # search not expired tokens only
200 200 try:
201 201 api_user = User.get_by_auth_token(request.rpc_api_key)
202 202
203 203 if api_user is None:
204 204 return jsonrpc_error(
205 205 request, retid=request.rpc_id, message='Invalid API KEY')
206 206
207 207 if not api_user.active:
208 208 return jsonrpc_error(
209 209 request, retid=request.rpc_id,
210 210 message='Request from this user not allowed')
211 211
212 212 # check if we are allowed to use this IP
213 213 auth_u = AuthUser(
214 214 api_user.user_id, request.rpc_api_key, ip_addr=request.rpc_ip_addr)
215 215 if not auth_u.ip_allowed:
216 216 return jsonrpc_error(
217 217 request, retid=request.rpc_id,
218 218 message='Request from IP:%s not allowed' % (
219 219 request.rpc_ip_addr,))
220 220 else:
221 221 log.info('Access for IP:%s allowed', request.rpc_ip_addr)
222 222
223 223 # register our auth-user
224 224 request.rpc_user = auth_u
225 225 request.environ['rc_auth_user_id'] = auth_u.user_id
226 226
227 227 # now check if token is valid for API
228 228 auth_token = request.rpc_api_key
229 229 token_match = api_user.authenticate_by_token(
230 230 auth_token, roles=[UserApiKeys.ROLE_API])
231 231 invalid_token = not token_match
232 232
233 233 log.debug('Checking if API KEY is valid with proper role')
234 234 if invalid_token:
235 235 return jsonrpc_error(
236 236 request, retid=request.rpc_id,
 237 237 	 message='API KEY invalid or has a bad role for an API call')
238 238
239 239 except Exception:
240 240 log.exception('Error on API AUTH')
241 241 return jsonrpc_error(
242 242 request, retid=request.rpc_id, message='Invalid API KEY')
243 243
244 244 method = request.rpc_method
245 245 func = request.registry.jsonrpc_methods[method]
246 246
247 247 # now that we have a method, add request._req_params to
 248 248 	 # self.kwargs and dispatch control to WSGIController
249 249 argspec = inspect.getargspec(func)
250 250 arglist = argspec[0]
251 251 defaults = map(type, argspec[3] or [])
252 252 default_empty = types.NotImplementedType
253 253
254 254 # kw arguments required by this method
255 func_kwargs = dict(itertools.izip_longest(
255 func_kwargs = dict(itertools.zip_longest(
256 256 reversed(arglist), reversed(defaults), fillvalue=default_empty))
257 257
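The right-aligned zip_longest pairs each argument name with its default, marking arguments without a default as required. Note that in Python 3 map() returns an iterator, so the reversed(defaults) call above relies on a sequence; list(map(type, ...)) would be the py3-safe form. A sketch of the pairing, using a hypothetical signature:

    import itertools

    arglist = ['request', 'apiuser', 'repoid', 'userid']  # from getargspec
    defaults = [type(None)]                               # only `userid` has a default
    default_empty = object()                              # marker for "required"

    # align names and defaults from the right; unmatched names get the marker
    func_kwargs = dict(itertools.zip_longest(
        reversed(arglist), reversed(defaults), fillvalue=default_empty))
    assert func_kwargs['repoid'] is default_empty   # required argument
    assert func_kwargs['userid'] is type(None)      # optional argument
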
258 258 # This attribute will need to be first param of a method that uses
259 259 # api_key, which is translated to instance of user at that name
260 260 user_var = 'apiuser'
261 261 request_var = 'request'
262 262
263 263 for arg in [user_var, request_var]:
264 264 if arg not in arglist:
265 265 return jsonrpc_error(
266 266 request,
267 267 retid=request.rpc_id,
268 268 message='This method [%s] does not support '
269 269 'required parameter `%s`' % (func.__name__, arg))
270 270
271 271 # get our arglist and check if we provided them as args
272 272 for arg, default in func_kwargs.items():
273 273 if arg in [user_var, request_var]:
274 274 # user_var and request_var are pre-hardcoded parameters and we
275 275 # don't need to do any translation
276 276 continue
277 277
 278 278 	 # skip the required param check if its default value is
279 279 # NotImplementedType (default_empty)
280 280 if default == default_empty and arg not in request.rpc_params:
281 281 return jsonrpc_error(
282 282 request,
283 283 retid=request.rpc_id,
284 284 message=('Missing non optional `%s` arg in JSON DATA' % arg)
285 285 )
286 286
287 287 # sanitize extra passed arguments
288 288 for k in request.rpc_params.keys()[:]:
289 289 if k not in func_kwargs:
290 290 del request.rpc_params[k]
291 291
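One py3 pitfall survives in the sanitizing loop above: dict.keys() returns a view in Python 3 and views cannot be sliced, so request.rpc_params.keys()[:] raises TypeError. A self-contained sketch of the py3-safe equivalent:

    rpc_params = {'repoid': 'x', 'bogus': 1}
    func_kwargs = {'repoid': None}
    # copy the keys into a list so entries can be deleted while iterating;
    # .keys()[:] would raise TypeError under Python 3
    for k in list(rpc_params.keys()):
        if k not in func_kwargs:
            del rpc_params[k]
    assert rpc_params == {'repoid': 'x'}
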
292 292 call_params = request.rpc_params
293 293 call_params.update({
294 294 'request': request,
295 295 'apiuser': auth_u
296 296 })
297 297
298 298 # register some common functions for usage
299 299 attach_context_attributes(TemplateArgs(), request, request.rpc_user.user_id)
300 300
301 301 statsd = request.registry.statsd
302 302
303 303 try:
304 304 ret_value = func(**call_params)
305 305 resp = jsonrpc_response(request, ret_value)
306 306 if statsd:
307 307 statsd.incr('rhodecode_api_call_success_total')
308 308 return resp
309 309 except JSONRPCBaseError:
310 310 raise
311 311 except Exception:
312 312 log.exception('Unhandled exception occurred on api call: %s', func)
313 313 exc_info = sys.exc_info()
314 314 exc_id, exc_type_name = store_exception(
315 315 id(exc_info), exc_info, prefix='rhodecode-api')
316 316 error_headers = [('RhodeCode-Exception-Id', str(exc_id)),
317 317 ('RhodeCode-Exception-Type', str(exc_type_name))]
318 318 err_resp = jsonrpc_error(
319 319 request, retid=request.rpc_id, message='Internal server error',
320 320 headers=error_headers)
321 321 if statsd:
322 322 statsd.incr('rhodecode_api_call_fail_total')
323 323 return err_resp
324 324
325 325
326 326 def setup_request(request):
327 327 """
328 328 Parse a JSON-RPC request body. It's used inside the predicates method
329 329 to validate and bootstrap requests for usage in rpc calls.
330 330
331 331 We need to raise JSONRPCError here if we want to return some errors back to
332 332 user.
333 333 """
334 334
335 335 log.debug('Executing setup request: %r', request)
336 336 request.rpc_ip_addr = get_ip_addr(request.environ)
337 337 # TODO(marcink): deprecate GET at some point
338 338 if request.method not in ['POST', 'GET']:
339 339 log.debug('unsupported request method "%s"', request.method)
340 340 raise JSONRPCError(
341 341 'unsupported request method "%s". Please use POST' % request.method)
342 342
343 343 if 'CONTENT_LENGTH' not in request.environ:
344 344 log.debug("No Content-Length")
345 345 raise JSONRPCError("Empty body, No Content-Length in request")
346 346
347 347 else:
348 348 length = request.environ['CONTENT_LENGTH']
349 349 log.debug('Content-Length: %s', length)
350 350
351 351 if length == 0:
352 352 log.debug("Content-Length is 0")
353 353 raise JSONRPCError("Content-Length is 0")
354 354
355 355 raw_body = request.body
356 356 log.debug("Loading JSON body now")
357 357 try:
358 358 json_body = json.loads(raw_body)
359 359 except ValueError as e:
 360 360 	 # catch JSON errors here
361 361 raise JSONRPCError("JSON parse error ERR:%s RAW:%r" % (e, raw_body))
362 362
363 363 request.rpc_id = json_body.get('id')
364 364 request.rpc_method = json_body.get('method')
365 365
366 366 # check required base parameters
367 367 try:
368 368 api_key = json_body.get('api_key')
369 369 if not api_key:
370 370 api_key = json_body.get('auth_token')
371 371
372 372 if not api_key:
373 373 raise KeyError('api_key or auth_token')
374 374
375 375 # TODO(marcink): support passing in token in request header
376 376
377 377 request.rpc_api_key = api_key
378 378 request.rpc_id = json_body['id']
379 379 request.rpc_method = json_body['method']
380 380 request.rpc_params = json_body['args'] \
381 381 if isinstance(json_body['args'], dict) else {}
382 382
383 383 log.debug('method: %s, params: %.10240r', request.rpc_method, request.rpc_params)
384 384 except KeyError as e:
385 385 raise JSONRPCError('Incorrect JSON data. Missing %s' % e)
386 386
387 387 log.debug('setup complete, now handling method:%s rpcid:%s',
388 388 request.rpc_method, request.rpc_id, )
389 389
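Putting setup_request together, a request body it accepts looks like this (the token value is a placeholder; auth_token is the preferred alias of api_key, and args must be a dict or it is replaced with {}):

    import json

    body = json.dumps({
        'id': 1,
        'auth_token': '<secret-token>',   # placeholder, not a real token
        'method': 'get_server_info',
        'args': {},
    })
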
390 390
391 391 class RoutePredicate(object):
392 392 def __init__(self, val, config):
393 393 self.val = val
394 394
395 395 def text(self):
396 396 return 'jsonrpc route = %s' % self.val
397 397
398 398 phash = text
399 399
400 400 def __call__(self, info, request):
401 401 if self.val:
402 402 # potentially setup and bootstrap our call
403 403 setup_request(request)
404 404
405 405 # Always return True so that even if it isn't a valid RPC it
 406 406 	 # will fall through to the underlying handlers like notfound_view
407 407 return True
408 408
409 409
410 410 class NotFoundPredicate(object):
411 411 def __init__(self, val, config):
412 412 self.val = val
413 413 self.methods = config.registry.jsonrpc_methods
414 414
415 415 def text(self):
416 416 return 'jsonrpc method not found = {}.'.format(self.val)
417 417
418 418 phash = text
419 419
420 420 def __call__(self, info, request):
421 421 return hasattr(request, 'rpc_method')
422 422
423 423
424 424 class MethodPredicate(object):
425 425 def __init__(self, val, config):
426 426 self.method = val
427 427
428 428 def text(self):
429 429 return 'jsonrpc method = %s' % self.method
430 430
431 431 phash = text
432 432
433 433 def __call__(self, context, request):
434 434 # we need to explicitly return False here, so pyramid doesn't try to
435 435 # execute our view directly. We need our main handler to execute things
436 436 return getattr(request, 'rpc_method') == self.method
437 437
438 438
439 439 def add_jsonrpc_method(config, view, **kwargs):
440 440 # pop the method name
441 441 method = kwargs.pop('method', None)
442 442
443 443 if method is None:
444 444 raise ConfigurationError(
445 445 'Cannot register a JSON-RPC method without specifying the "method"')
446 446
 447 447 	 # we define a custom predicate to detect conflicting methods; these
 448 448 	 # predicates are a kind of "translation" from the decorator variables
 449 449 	 # to internal predicate names
450 450
451 451 kwargs['jsonrpc_method'] = method
452 452
453 453 # register our view into global view store for validation
454 454 config.registry.jsonrpc_methods[method] = view
455 455
 456 456 	 # we're using our main request_view handler here, so each method
457 457 # has a unified handler for itself
458 458 config.add_view(request_view, route_name='apiv2', **kwargs)
459 459
460 460
461 461 class jsonrpc_method(object):
462 462 """
 463 463 	 decorator that works similarly to the @add_view_config decorator,
464 464 but tailored for our JSON RPC
465 465 """
466 466
467 467 venusian = venusian # for testing injection
468 468
469 469 def __init__(self, method=None, **kwargs):
470 470 self.method = method
471 471 self.kwargs = kwargs
472 472
473 473 def __call__(self, wrapped):
474 474 kwargs = self.kwargs.copy()
475 475 kwargs['method'] = self.method or wrapped.__name__
476 476 depth = kwargs.pop('_depth', 0)
477 477
478 478 def callback(context, name, ob):
479 479 config = context.config.with_package(info.module)
480 480 config.add_jsonrpc_method(view=ob, **kwargs)
481 481
482 482 info = venusian.attach(wrapped, callback, category='pyramid',
483 483 depth=depth + 1)
484 484 if info.scope == 'class':
485 485 # ensure that attr is set if decorating a class method
486 486 kwargs.setdefault('attr', wrapped.__name__)
487 487
488 488 kwargs['_info'] = info.codeinfo # fbo action_method
489 489 return wrapped
490 490
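A minimal usage sketch, assuming the rhodecode.api export used by the API views further down; request_view enforces that the first two parameters are request and apiuser:

    from rhodecode.api import jsonrpc_method

    @jsonrpc_method()   # the method name defaults to the function name
    def get_example(request, apiuser):
        return {'ok': True}
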
491 491
492 492 class jsonrpc_deprecated_method(object):
493 493 """
 494 494 	 Marks a method as deprecated: adds a log.warning and injects a special
 495 495 	 key into the request to mark the method as deprecated.
 496 496 	 Also injects a special docstring that extract_docs will catch to flag
 497 497 	 the method as deprecated.
498 498
499 499 :param use_method: specify which method should be used instead of
500 500 the decorated one
501 501
502 502 Use like::
503 503
504 504 @jsonrpc_method()
505 505 @jsonrpc_deprecated_method(use_method='new_func', deprecated_at_version='3.0.0')
506 506 def old_func(request, apiuser, arg1, arg2):
507 507 ...
508 508 """
509 509
510 510 def __init__(self, use_method, deprecated_at_version):
511 511 self.use_method = use_method
512 512 self.deprecated_at_version = deprecated_at_version
513 513 self.deprecated_msg = ''
514 514
515 515 def __call__(self, func):
516 516 self.deprecated_msg = 'Please use method `{method}` instead.'.format(
517 517 method=self.use_method)
518 518
519 519 docstring = """\n
520 520 .. deprecated:: {version}
521 521
522 522 {deprecation_message}
523 523
524 524 {original_docstring}
525 525 """
526 526 func.__doc__ = docstring.format(
527 527 version=self.deprecated_at_version,
528 528 deprecation_message=self.deprecated_msg,
529 529 original_docstring=func.__doc__)
530 530 return decorator.decorator(self.__wrapper, func)
531 531
532 532 def __wrapper(self, func, *fargs, **fkwargs):
533 533 log.warning('DEPRECATED API CALL on function %s, please '
534 534 'use `%s` instead', func, self.use_method)
 535 535 	 # alter the function docstring to mark it as deprecated; this is
 536 536 	 # picked up by the fabric file that generates the API docs.
537 537 result = func(*fargs, **fkwargs)
538 538
539 539 request = fargs[0]
540 540 request.rpc_deprecation = 'DEPRECATED METHOD ' + self.deprecated_msg
541 541 return result
542 542
543 543
544 544 def add_api_methods(config):
545 545 from rhodecode.api.views import (
546 546 deprecated_api, gist_api, pull_request_api, repo_api, repo_group_api,
547 547 server_api, search_api, testing_api, user_api, user_group_api)
548 548
549 549 config.scan('rhodecode.api.views')
550 550
551 551
552 552 def includeme(config):
553 553 plugin_module = 'rhodecode.api'
554 554 plugin_settings = get_plugin_settings(
555 555 plugin_module, config.registry.settings)
556 556
557 557 if not hasattr(config.registry, 'jsonrpc_methods'):
558 558 config.registry.jsonrpc_methods = OrderedDict()
559 559
560 560 # match filter by given method only
561 561 config.add_view_predicate('jsonrpc_method', MethodPredicate)
562 562 config.add_view_predicate('jsonrpc_method_not_found', NotFoundPredicate)
563 563
564 564 config.add_renderer(DEFAULT_RENDERER, ExtJsonRenderer(
565 565 serializer=json.dumps, indent=4))
566 566 config.add_directive('add_jsonrpc_method', add_jsonrpc_method)
567 567
568 568 config.add_route_predicate(
569 569 'jsonrpc_call', RoutePredicate)
570 570
571 571 config.add_route(
572 572 'apiv2', plugin_settings.get('url', DEFAULT_URL), jsonrpc_call=True)
573 573
574 574 # register some exception handling view
575 575 config.add_view(exception_view, context=JSONRPCBaseError)
576 576 config.add_notfound_view(exception_view, jsonrpc_method_not_found=True)
577 577
578 578 add_api_methods(config)
@@ -1,419 +1,419 b''
1 1 # -*- coding: utf-8 -*-
2 2
3 3 # Copyright (C) 2011-2020 RhodeCode GmbH
4 4 #
5 5 # This program is free software: you can redistribute it and/or modify
6 6 # it under the terms of the GNU Affero General Public License, version 3
7 7 # (only), as published by the Free Software Foundation.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU Affero General Public License
15 15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 16 #
17 17 # This program is dual-licensed. If you wish to learn more about the
18 18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20 20
21 21 import logging
22 22 import itertools
23 23 import base64
24 24
25 25 from rhodecode.api import (
26 26 jsonrpc_method, JSONRPCError, JSONRPCForbidden, find_methods)
27 27
28 28 from rhodecode.api.utils import (
29 29 Optional, OAttr, has_superadmin_permission, get_user_or_error)
30 30 from rhodecode.lib.utils import repo2db_mapper
31 31 from rhodecode.lib import system_info
32 32 from rhodecode.lib import user_sessions
33 33 from rhodecode.lib import exc_tracking
34 34 from rhodecode.lib.ext_json import json
35 35 from rhodecode.lib.utils2 import safe_int
36 36 from rhodecode.model.db import UserIpMap
37 37 from rhodecode.model.scm import ScmModel
38 38 from rhodecode.model.settings import VcsSettingsModel
39 39 from rhodecode.apps.file_store import utils
40 40 from rhodecode.apps.file_store.exceptions import FileNotAllowedException, \
41 41 FileOverSizeException
42 42
43 43 log = logging.getLogger(__name__)
44 44
45 45
46 46 @jsonrpc_method()
47 47 def get_server_info(request, apiuser):
48 48 """
49 49 Returns the |RCE| server information.
50 50
51 51 This includes the running version of |RCE| and all installed
52 52 packages. This command takes the following options:
53 53
54 54 :param apiuser: This is filled automatically from the |authtoken|.
55 55 :type apiuser: AuthUser
56 56
57 57 Example output:
58 58
59 59 .. code-block:: bash
60 60
61 61 id : <id_given_in_input>
62 62 result : {
63 63 'modules': [<module name>,...]
64 64 'py_version': <python version>,
65 65 'platform': <platform type>,
66 66 'rhodecode_version': <rhodecode version>
67 67 }
68 68 error : null
69 69 """
70 70
71 71 if not has_superadmin_permission(apiuser):
72 72 raise JSONRPCForbidden()
73 73
74 74 server_info = ScmModel().get_server_info(request.environ)
75 75 # rhodecode-index requires those
76 76
77 77 server_info['index_storage'] = server_info['search']['value']['location']
78 78 server_info['storage'] = server_info['storage']['value']['path']
79 79
80 80 return server_info
81 81
82 82
83 83 @jsonrpc_method()
84 84 def get_repo_store(request, apiuser):
85 85 """
86 86 Returns the |RCE| repository storage information.
87 87
88 88 :param apiuser: This is filled automatically from the |authtoken|.
89 89 :type apiuser: AuthUser
90 90
91 91 Example output:
92 92
93 93 .. code-block:: bash
94 94
95 95 id : <id_given_in_input>
96 96 result : {
97 97 'modules': [<module name>,...]
98 98 'py_version': <python version>,
99 99 'platform': <platform type>,
100 100 'rhodecode_version': <rhodecode version>
101 101 }
102 102 error : null
103 103 """
104 104
105 105 if not has_superadmin_permission(apiuser):
106 106 raise JSONRPCForbidden()
107 107
108 108 path = VcsSettingsModel().get_repos_location()
109 109 return {"path": path}
110 110
111 111
112 112 @jsonrpc_method()
113 113 def get_ip(request, apiuser, userid=Optional(OAttr('apiuser'))):
114 114 """
115 115 Displays the IP Address as seen from the |RCE| server.
116 116
117 117 * This command displays the IP Address, as well as all the defined IP
118 118 addresses for the specified user. If the ``userid`` is not set, the
119 119 data returned is for the user calling the method.
120 120
121 121 This command can only be run using an |authtoken| with admin rights to
122 122 the specified repository.
123 123
124 124 This command takes the following options:
125 125
126 126 :param apiuser: This is filled automatically from |authtoken|.
127 127 :type apiuser: AuthUser
128 128 :param userid: Sets the userid for which associated IP Address data
129 129 is returned.
130 130 :type userid: Optional(str or int)
131 131
132 132 Example output:
133 133
134 134 .. code-block:: bash
135 135
136 136 id : <id_given_in_input>
137 137 result : {
 138 138 	 "server_ip_addr": "<ip_from_client>",
139 139 "user_ips": [
140 140 {
141 141 "ip_addr": "<ip_with_mask>",
142 142 "ip_range": ["<start_ip>", "<end_ip>"],
143 143 },
144 144 ...
145 145 ]
146 146 }
147 147
148 148 """
149 149 if not has_superadmin_permission(apiuser):
150 150 raise JSONRPCForbidden()
151 151
152 152 userid = Optional.extract(userid, evaluate_locals=locals())
153 153 userid = getattr(userid, 'user_id', userid)
154 154
155 155 user = get_user_or_error(userid)
156 156 ips = UserIpMap.query().filter(UserIpMap.user == user).all()
157 157 return {
158 158 'server_ip_addr': request.rpc_ip_addr,
159 159 'user_ips': ips
160 160 }
161 161
162 162
163 163 @jsonrpc_method()
164 164 def rescan_repos(request, apiuser, remove_obsolete=Optional(False)):
165 165 """
166 166 Triggers a rescan of the specified repositories.
167 167
168 168 * If the ``remove_obsolete`` option is set, it also deletes repositories
 169 169 	 that are found in the database but not on the file system, so-called
170 170 "clean zombies".
171 171
172 172 This command can only be run using an |authtoken| with admin rights to
173 173 the specified repository.
174 174
175 175 This command takes the following options:
176 176
177 177 :param apiuser: This is filled automatically from the |authtoken|.
178 178 :type apiuser: AuthUser
179 179 :param remove_obsolete: Deletes repositories from the database that
180 180 are not found on the filesystem.
181 181 :type remove_obsolete: Optional(``True`` | ``False``)
182 182
183 183 Example output:
184 184
185 185 .. code-block:: bash
186 186
187 187 id : <id_given_in_input>
188 188 result : {
189 189 'added': [<added repository name>,...]
190 190 'removed': [<removed repository name>,...]
191 191 }
192 192 error : null
193 193
194 194 Example error output:
195 195
196 196 .. code-block:: bash
197 197
198 198 id : <id_given_in_input>
199 199 result : null
200 200 error : {
201 201 'Error occurred during rescan repositories action'
202 202 }
203 203
204 204 """
205 205 if not has_superadmin_permission(apiuser):
206 206 raise JSONRPCForbidden()
207 207
208 208 try:
209 209 rm_obsolete = Optional.extract(remove_obsolete)
210 210 added, removed = repo2db_mapper(ScmModel().repo_scan(),
211 211 remove_obsolete=rm_obsolete)
212 212 return {'added': added, 'removed': removed}
213 213 except Exception:
 214 214 	 log.exception('Failed to run repo rescan')
215 215 raise JSONRPCError(
216 216 'Error occurred during rescan repositories action'
217 217 )
218 218
219 219
220 220 @jsonrpc_method()
221 221 def cleanup_sessions(request, apiuser, older_then=Optional(60)):
222 222 """
223 223 Triggers a session cleanup action.
224 224
 225 225 	 If the ``older_then`` option is set, only sessions that haven't been
 226 226 	 accessed in the given number of days will be removed.
227 227
228 228 This command can only be run using an |authtoken| with admin rights to
229 229 the specified repository.
230 230
231 231 This command takes the following options:
232 232
233 233 :param apiuser: This is filled automatically from the |authtoken|.
234 234 :type apiuser: AuthUser
 235 235 	 :param older_then: Deletes sessions that haven't been accessed
 236 236 	 in the given number of days.
237 237 :type older_then: Optional(int)
238 238
239 239 Example output:
240 240
241 241 .. code-block:: bash
242 242
243 243 id : <id_given_in_input>
244 244 result: {
245 245 "backend": "<type of backend>",
246 246 "sessions_removed": <number_of_removed_sessions>
247 247 }
248 248 error : null
249 249
250 250 Example error output:
251 251
252 252 .. code-block:: bash
253 253
254 254 id : <id_given_in_input>
255 255 result : null
256 256 error : {
257 257 'Error occurred during session cleanup'
258 258 }
259 259
260 260 """
261 261 if not has_superadmin_permission(apiuser):
262 262 raise JSONRPCForbidden()
263 263
264 264 older_then = safe_int(Optional.extract(older_then)) or 60
265 265 older_than_seconds = 60 * 60 * 24 * older_then
266 266
267 267 config = system_info.rhodecode_config().get_value()['value']['config']
268 268 session_model = user_sessions.get_session_handler(
269 269 config.get('beaker.session.type', 'memory'))(config)
270 270
271 271 backend = session_model.SESSION_TYPE
272 272 try:
273 273 cleaned = session_model.clean_sessions(
274 274 older_than_seconds=older_than_seconds)
275 275 return {'sessions_removed': cleaned, 'backend': backend}
276 276 except user_sessions.CleanupCommand as msg:
277 277 return {'cleanup_command': msg.message, 'backend': backend}
278 278 except Exception as e:
279 279 log.exception('Failed session cleanup')
280 280 raise JSONRPCError(
281 281 'Error occurred during session cleanup'
282 282 )
283 283
284 284
285 285 @jsonrpc_method()
286 286 def get_method(request, apiuser, pattern=Optional('*')):
287 287 """
 288 288 	 Returns a list of all available API methods. By default the match
 289 289 	 pattern is "*", but any other pattern can be specified, e.g. *comment*
 290 290 	 will return all methods with "comment" in their name. If just a single
 291 291 	 method is matched, the returned data will also include the method specification.
292 292
293 293 This command can only be run using an |authtoken| with admin rights to
294 294 the specified repository.
295 295
296 296 This command takes the following options:
297 297
298 298 :param apiuser: This is filled automatically from the |authtoken|.
299 299 :type apiuser: AuthUser
300 300 :param pattern: pattern to match method names against
301 301 :type pattern: Optional("*")
302 302
303 303 Example output:
304 304
305 305 .. code-block:: bash
306 306
307 307 id : <id_given_in_input>
308 308 "result": [
309 309 "changeset_comment",
310 310 "comment_pull_request",
311 311 "comment_commit"
312 312 ]
313 313 error : null
314 314
315 315 .. code-block:: bash
316 316
317 317 id : <id_given_in_input>
318 318 "result": [
319 319 "comment_commit",
320 320 {
321 321 "apiuser": "<RequiredType>",
322 322 "comment_type": "<Optional:u'note'>",
323 323 "commit_id": "<RequiredType>",
324 324 "message": "<RequiredType>",
325 325 "repoid": "<RequiredType>",
326 326 "request": "<RequiredType>",
327 327 "resolves_comment_id": "<Optional:None>",
328 328 "status": "<Optional:None>",
329 329 "userid": "<Optional:<OptionalAttr:apiuser>>"
330 330 }
331 331 ]
332 332 error : null
333 333 """
334 334 from rhodecode.config.patches import inspect_getargspec
335 335 inspect = inspect_getargspec()
336 336
337 337 if not has_superadmin_permission(apiuser):
338 338 raise JSONRPCForbidden()
339 339
340 340 pattern = Optional.extract(pattern)
341 341
342 342 matches = find_methods(request.registry.jsonrpc_methods, pattern)
343 343
344 344 args_desc = []
345 345 if len(matches) == 1:
346 346 func = matches[matches.keys()[0]]
347 347
348 348 argspec = inspect.getargspec(func)
349 349 arglist = argspec[0]
350 350 defaults = map(repr, argspec[3] or [])
351 351
352 352 default_empty = '<RequiredType>'
353 353
354 354 # kw arguments required by this method
355 func_kwargs = dict(itertools.izip_longest(
355 func_kwargs = dict(itertools.zip_longest(
356 356 reversed(arglist), reversed(defaults), fillvalue=default_empty))
357 357 args_desc.append(func_kwargs)
358 358
359 359 return matches.keys() + args_desc
360 360
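Besides the zip_longest rename, this function still uses py2-only dict-view idioms: matches.keys()[0] (views are not indexable), map(repr, ...) (an iterator cannot be reversed), and matches.keys() + args_desc (a view does not concatenate with a list). A sketch of py3-safe replacements for those three spots, assuming the surrounding logic stays as-is:

    method_names = list(matches)                   # instead of matches.keys()[0]
    func = matches[method_names[0]]
    defaults = list(map(repr, argspec[3] or []))   # materialize before reversed()
    return method_names + args_desc                # instead of dict_keys + list
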
361 361
362 362 @jsonrpc_method()
363 363 def store_exception(request, apiuser, exc_data_json, prefix=Optional('rhodecode')):
364 364 """
 365 365 	 Stores the sent exception inside the built-in exception tracker of the |RCE| server.
366 366
367 367 This command can only be run using an |authtoken| with admin rights to
368 368 the specified repository.
369 369
370 370 This command takes the following options:
371 371
372 372 :param apiuser: This is filled automatically from the |authtoken|.
373 373 :type apiuser: AuthUser
374 374
 375 375 	 :param exc_data_json: JSON data with the exception, e.g.
376 376 {"exc_traceback": "Value `1` is not allowed", "exc_type_name": "ValueError"}
377 377 :type exc_data_json: JSON data
378 378
 379 379 	 :param prefix: prefix for error type, e.g. 'rhodecode', 'vcsserver', 'rhodecode-tools'
380 380 :type prefix: Optional("rhodecode")
381 381
382 382 Example output:
383 383
384 384 .. code-block:: bash
385 385
386 386 id : <id_given_in_input>
387 387 "result": {
388 388 "exc_id": 139718459226384,
389 389 "exc_url": "http://localhost:8080/_admin/settings/exceptions/139718459226384"
390 390 }
391 391 error : null
392 392 """
393 393 if not has_superadmin_permission(apiuser):
394 394 raise JSONRPCForbidden()
395 395
396 396 prefix = Optional.extract(prefix)
397 397 exc_id = exc_tracking.generate_id()
398 398
399 399 try:
400 400 exc_data = json.loads(exc_data_json)
401 401 except Exception:
402 402 log.error('Failed to parse JSON: %r', exc_data_json)
403 403 raise JSONRPCError('Failed to parse JSON data from exc_data_json field. '
 404 404 	 'Please make sure it contains valid JSON.')
405 405
406 406 try:
407 407 exc_traceback = exc_data['exc_traceback']
408 408 exc_type_name = exc_data['exc_type_name']
409 409 except KeyError as err:
410 410 raise JSONRPCError('Missing exc_traceback, or exc_type_name '
411 411 'in exc_data_json field. Missing: {}'.format(err))
412 412
413 413 exc_tracking._store_exception(
414 414 exc_id=exc_id, exc_traceback=exc_traceback,
415 415 exc_type_name=exc_type_name, prefix=prefix)
416 416
417 417 exc_url = request.route_url(
418 418 'admin_settings_exception_tracker_show', exception_id=exc_id)
419 419 return {'exc_id': exc_id, 'exc_url': exc_url}
@@ -1,479 +1,479 b''
1 1 # -*- coding: utf-8 -*-
2 2
3 3 # Copyright (C) 2016-2020 RhodeCode GmbH
4 4 #
5 5 # This program is free software: you can redistribute it and/or modify
6 6 # it under the terms of the GNU Affero General Public License, version 3
7 7 # (only), as published by the Free Software Foundation.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU Affero General Public License
15 15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 16 #
17 17 # This program is dual-licensed. If you wish to learn more about the
18 18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20 20
21 21 import re
22 22 import logging
23 23 import formencode
24 24 import formencode.htmlfill
25 25 import datetime
26 26 from pyramid.interfaces import IRoutesMapper
27 27
28 28 from pyramid.httpexceptions import HTTPFound
29 29 from pyramid.renderers import render
30 30 from pyramid.response import Response
31 31
32 32 from rhodecode.apps._base import BaseAppView, DataGridAppView
33 33 from rhodecode.apps.ssh_support import SshKeyFileChangeEvent
34 34 from rhodecode import events
35 35
36 36 from rhodecode.lib import helpers as h
37 37 from rhodecode.lib.auth import (
38 38 LoginRequired, HasPermissionAllDecorator, CSRFRequired)
39 39 from rhodecode.lib.utils2 import aslist, safe_unicode
40 40 from rhodecode.model.db import (
41 41 or_, coalesce, User, UserIpMap, UserSshKeys)
42 42 from rhodecode.model.forms import (
43 43 ApplicationPermissionsForm, ObjectPermissionsForm, UserPermissionsForm)
44 44 from rhodecode.model.meta import Session
45 45 from rhodecode.model.permission import PermissionModel
46 46 from rhodecode.model.settings import SettingsModel
47 47
48 48
49 49 log = logging.getLogger(__name__)
50 50
51 51
52 52 class AdminPermissionsView(BaseAppView, DataGridAppView):
53 53 def load_default_context(self):
54 54 c = self._get_local_tmpl_context()
55 55 PermissionModel().set_global_permission_choices(
56 56 c, gettext_translator=self.request.translate)
57 57 return c
58 58
59 59 @LoginRequired()
60 60 @HasPermissionAllDecorator('hg.admin')
61 61 def permissions_application(self):
62 62 c = self.load_default_context()
63 63 c.active = 'application'
64 64
65 65 c.user = User.get_default_user(refresh=True)
66 66
67 67 app_settings = c.rc_config
68 68
69 69 defaults = {
70 70 'anonymous': c.user.active,
71 71 'default_register_message': app_settings.get(
72 72 'rhodecode_register_message')
73 73 }
74 74 defaults.update(c.user.get_default_perms())
75 75
76 76 data = render('rhodecode:templates/admin/permissions/permissions.mako',
77 77 self._get_template_context(c), self.request)
78 78 html = formencode.htmlfill.render(
79 79 data,
80 80 defaults=defaults,
81 81 encoding="UTF-8",
82 82 force_defaults=False
83 83 )
84 84 return Response(html)
85 85
86 86 @LoginRequired()
87 87 @HasPermissionAllDecorator('hg.admin')
88 88 @CSRFRequired()
89 89 def permissions_application_update(self):
90 90 _ = self.request.translate
91 91 c = self.load_default_context()
92 92 c.active = 'application'
93 93
94 94 _form = ApplicationPermissionsForm(
95 95 self.request.translate,
96 96 [x[0] for x in c.register_choices],
97 97 [x[0] for x in c.password_reset_choices],
98 98 [x[0] for x in c.extern_activate_choices])()
99 99
100 100 try:
101 101 form_result = _form.to_python(dict(self.request.POST))
102 102 form_result.update({'perm_user_name': User.DEFAULT_USER})
103 103 PermissionModel().update_application_permissions(form_result)
104 104
105 105 settings = [
106 106 ('register_message', 'default_register_message'),
107 107 ]
108 108 for setting, form_key in settings:
109 109 sett = SettingsModel().create_or_update_setting(
110 110 setting, form_result[form_key])
111 111 Session().add(sett)
112 112
113 113 Session().commit()
114 114 h.flash(_('Application permissions updated successfully'),
115 115 category='success')
116 116
117 117 except formencode.Invalid as errors:
118 118 defaults = errors.value
119 119
120 120 data = render(
121 121 'rhodecode:templates/admin/permissions/permissions.mako',
122 122 self._get_template_context(c), self.request)
123 123 html = formencode.htmlfill.render(
124 124 data,
125 125 defaults=defaults,
126 126 errors=errors.error_dict or {},
127 127 prefix_error=False,
128 128 encoding="UTF-8",
129 129 force_defaults=False
130 130 )
131 131 return Response(html)
132 132
133 133 except Exception:
134 134 log.exception("Exception during update of permissions")
135 135 h.flash(_('Error occurred during update of permissions'),
136 136 category='error')
137 137
138 138 affected_user_ids = [User.get_default_user_id()]
139 139 PermissionModel().trigger_permission_flush(affected_user_ids)
140 140
141 141 raise HTTPFound(h.route_path('admin_permissions_application'))
142 142
143 143 @LoginRequired()
144 144 @HasPermissionAllDecorator('hg.admin')
145 145 def permissions_objects(self):
146 146 c = self.load_default_context()
147 147 c.active = 'objects'
148 148
149 149 c.user = User.get_default_user(refresh=True)
150 150 defaults = {}
151 151 defaults.update(c.user.get_default_perms())
152 152
153 153 data = render(
154 154 'rhodecode:templates/admin/permissions/permissions.mako',
155 155 self._get_template_context(c), self.request)
156 156 html = formencode.htmlfill.render(
157 157 data,
158 158 defaults=defaults,
159 159 encoding="UTF-8",
160 160 force_defaults=False
161 161 )
162 162 return Response(html)
163 163
164 164 @LoginRequired()
165 165 @HasPermissionAllDecorator('hg.admin')
166 166 @CSRFRequired()
167 167 def permissions_objects_update(self):
168 168 _ = self.request.translate
169 169 c = self.load_default_context()
170 170 c.active = 'objects'
171 171
172 172 _form = ObjectPermissionsForm(
173 173 self.request.translate,
174 174 [x[0] for x in c.repo_perms_choices],
175 175 [x[0] for x in c.group_perms_choices],
176 176 [x[0] for x in c.user_group_perms_choices],
177 177 )()
178 178
179 179 try:
180 180 form_result = _form.to_python(dict(self.request.POST))
181 181 form_result.update({'perm_user_name': User.DEFAULT_USER})
182 182 PermissionModel().update_object_permissions(form_result)
183 183
184 184 Session().commit()
185 185 h.flash(_('Object permissions updated successfully'),
186 186 category='success')
187 187
188 188 except formencode.Invalid as errors:
189 189 defaults = errors.value
190 190
191 191 data = render(
192 192 'rhodecode:templates/admin/permissions/permissions.mako',
193 193 self._get_template_context(c), self.request)
194 194 html = formencode.htmlfill.render(
195 195 data,
196 196 defaults=defaults,
197 197 errors=errors.error_dict or {},
198 198 prefix_error=False,
199 199 encoding="UTF-8",
200 200 force_defaults=False
201 201 )
202 202 return Response(html)
203 203 except Exception:
204 204 log.exception("Exception during update of permissions")
205 205 h.flash(_('Error occurred during update of permissions'),
206 206 category='error')
207 207
208 208 affected_user_ids = [User.get_default_user_id()]
209 209 PermissionModel().trigger_permission_flush(affected_user_ids)
210 210
211 211 raise HTTPFound(h.route_path('admin_permissions_object'))
212 212
213 213 @LoginRequired()
214 214 @HasPermissionAllDecorator('hg.admin')
215 215 def permissions_branch(self):
216 216 c = self.load_default_context()
217 217 c.active = 'branch'
218 218
219 219 c.user = User.get_default_user(refresh=True)
220 220 defaults = {}
221 221 defaults.update(c.user.get_default_perms())
222 222
223 223 data = render(
224 224 'rhodecode:templates/admin/permissions/permissions.mako',
225 225 self._get_template_context(c), self.request)
226 226 html = formencode.htmlfill.render(
227 227 data,
228 228 defaults=defaults,
229 229 encoding="UTF-8",
230 230 force_defaults=False
231 231 )
232 232 return Response(html)
233 233
234 234 @LoginRequired()
235 235 @HasPermissionAllDecorator('hg.admin')
236 236 def permissions_global(self):
237 237 c = self.load_default_context()
238 238 c.active = 'global'
239 239
240 240 c.user = User.get_default_user(refresh=True)
241 241 defaults = {}
242 242 defaults.update(c.user.get_default_perms())
243 243
244 244 data = render(
245 245 'rhodecode:templates/admin/permissions/permissions.mako',
246 246 self._get_template_context(c), self.request)
247 247 html = formencode.htmlfill.render(
248 248 data,
249 249 defaults=defaults,
250 250 encoding="UTF-8",
251 251 force_defaults=False
252 252 )
253 253 return Response(html)
254 254
255 255 @LoginRequired()
256 256 @HasPermissionAllDecorator('hg.admin')
257 257 @CSRFRequired()
258 258 def permissions_global_update(self):
259 259 _ = self.request.translate
260 260 c = self.load_default_context()
261 261 c.active = 'global'
262 262
263 263 _form = UserPermissionsForm(
264 264 self.request.translate,
265 265 [x[0] for x in c.repo_create_choices],
266 266 [x[0] for x in c.repo_create_on_write_choices],
267 267 [x[0] for x in c.repo_group_create_choices],
268 268 [x[0] for x in c.user_group_create_choices],
269 269 [x[0] for x in c.fork_choices],
270 270 [x[0] for x in c.inherit_default_permission_choices])()
271 271
272 272 try:
273 273 form_result = _form.to_python(dict(self.request.POST))
274 274 form_result.update({'perm_user_name': User.DEFAULT_USER})
275 275 PermissionModel().update_user_permissions(form_result)
276 276
277 277 Session().commit()
278 278 h.flash(_('Global permissions updated successfully'),
279 279 category='success')
280 280
281 281 except formencode.Invalid as errors:
282 282 defaults = errors.value
283 283
284 284 data = render(
285 285 'rhodecode:templates/admin/permissions/permissions.mako',
286 286 self._get_template_context(c), self.request)
287 287 html = formencode.htmlfill.render(
288 288 data,
289 289 defaults=defaults,
290 290 errors=errors.error_dict or {},
291 291 prefix_error=False,
292 292 encoding="UTF-8",
293 293 force_defaults=False
294 294 )
295 295 return Response(html)
296 296 except Exception:
297 297 log.exception("Exception during update of permissions")
298 298 h.flash(_('Error occurred during update of permissions'),
299 299 category='error')
300 300
301 301 affected_user_ids = [User.get_default_user_id()]
302 302 PermissionModel().trigger_permission_flush(affected_user_ids)
303 303
304 304 raise HTTPFound(h.route_path('admin_permissions_global'))
305 305
306 306 @LoginRequired()
307 307 @HasPermissionAllDecorator('hg.admin')
308 308 def permissions_ips(self):
309 309 c = self.load_default_context()
310 310 c.active = 'ips'
311 311
312 312 c.user = User.get_default_user(refresh=True)
313 313 c.user_ip_map = (
314 314 UserIpMap.query().filter(UserIpMap.user == c.user).all())
315 315
316 316 return self._get_template_context(c)
317 317
318 318 @LoginRequired()
319 319 @HasPermissionAllDecorator('hg.admin')
320 320 def permissions_overview(self):
321 321 c = self.load_default_context()
322 322 c.active = 'perms'
323 323
324 324 c.user = User.get_default_user(refresh=True)
325 325 c.perm_user = c.user.AuthUser()
326 326 return self._get_template_context(c)
327 327
328 328 @LoginRequired()
329 329 @HasPermissionAllDecorator('hg.admin')
330 330 def auth_token_access(self):
331 331 from rhodecode import CONFIG
332 332
333 333 c = self.load_default_context()
334 334 c.active = 'auth_token_access'
335 335
336 336 c.user = User.get_default_user(refresh=True)
337 337 c.perm_user = c.user.AuthUser()
338 338
339 339 mapper = self.request.registry.queryUtility(IRoutesMapper)
340 340 c.view_data = []
341 341
342 _argument_prog = re.compile('\{(.*?)\}|:\((.*)\)')
342 _argument_prog = re.compile(r'\{(.*?)\}|:\((.*)\)')
343 343 introspector = self.request.registry.introspector
344 344
345 345 view_intr = {}
346 346 for view_data in introspector.get_category('views'):
347 347 intr = view_data['introspectable']
348 348
349 349 if 'route_name' in intr and intr['attr']:
350 350 view_intr[intr['route_name']] = '{}:{}'.format(
351 351 str(intr['derived_callable'].__name__), intr['attr']
352 352 )
353 353
354 354 c.whitelist_key = 'api_access_controllers_whitelist'
355 355 c.whitelist_file = CONFIG.get('__file__')
356 356 whitelist_views = aslist(
357 357 CONFIG.get(c.whitelist_key), sep=',')
358 358
359 359 for route_info in mapper.get_routes():
360 360 if not route_info.name.startswith('__'):
361 361 routepath = route_info.pattern
362 362
363 363 def replace(matchobj):
364 364 if matchobj.group(1):
365 365 return "{%s}" % matchobj.group(1).split(':')[0]
366 366 else:
367 367 return "{%s}" % matchobj.group(2)
368 368
369 369 routepath = _argument_prog.sub(replace, routepath)
370 370
371 371 if not routepath.startswith('/'):
372 372 routepath = '/' + routepath
373 373
374 374 view_fqn = view_intr.get(route_info.name, 'NOT AVAILABLE')
375 375 active = view_fqn in whitelist_views
376 376 c.view_data.append((route_info.name, view_fqn, routepath, active))
377 377
378 378 c.whitelist_views = whitelist_views
379 379 return self._get_template_context(c)
380 380
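The raw-string fix above (r'\{(.*?)\}|:\((.*)\)') avoids the invalid-escape warning Python 3 emits for '\{' in a plain string literal; the matching behavior is unchanged. What the substitution does to a pyramid route pattern:

    import re

    _argument_prog = re.compile(r'\{(.*?)\}|:\((.*)\)')

    def replace(matchobj):
        # keep only the placeholder name, dropping any inline regex part
        if matchobj.group(1):
            return "{%s}" % matchobj.group(1).split(':')[0]
        else:
            return "{%s}" % matchobj.group(2)

    assert _argument_prog.sub(replace, '/{repo_name:.*?[^/]}/raw') == '/{repo_name}/raw'
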
381 381 def ssh_enabled(self):
382 382 return self.request.registry.settings.get(
383 383 'ssh.generate_authorized_keyfile')
384 384
385 385 @LoginRequired()
386 386 @HasPermissionAllDecorator('hg.admin')
387 387 def ssh_keys(self):
388 388 c = self.load_default_context()
389 389 c.active = 'ssh_keys'
390 390 c.ssh_enabled = self.ssh_enabled()
391 391 return self._get_template_context(c)
392 392
393 393 @LoginRequired()
394 394 @HasPermissionAllDecorator('hg.admin')
395 395 def ssh_keys_data(self):
396 396 _ = self.request.translate
397 397 self.load_default_context()
398 398 column_map = {
399 399 'fingerprint': 'ssh_key_fingerprint',
400 400 'username': User.username
401 401 }
402 402 draw, start, limit = self._extract_chunk(self.request)
403 403 search_q, order_by, order_dir = self._extract_ordering(
404 404 self.request, column_map=column_map)
405 405
406 406 ssh_keys_data_total_count = UserSshKeys.query()\
407 407 .count()
408 408
409 409 # json generate
410 410 base_q = UserSshKeys.query().join(UserSshKeys.user)
411 411
412 412 if search_q:
413 413 like_expression = u'%{}%'.format(safe_unicode(search_q))
414 414 base_q = base_q.filter(or_(
415 415 User.username.ilike(like_expression),
416 416 UserSshKeys.ssh_key_fingerprint.ilike(like_expression),
417 417 ))
418 418
419 419 users_data_total_filtered_count = base_q.count()
420 420
421 421 sort_col = self._get_order_col(order_by, UserSshKeys)
422 422 if sort_col:
423 423 if order_dir == 'asc':
424 424 # handle null values properly to order by NULL last
425 425 if order_by in ['created_on']:
426 426 sort_col = coalesce(sort_col, datetime.date.max)
427 427 sort_col = sort_col.asc()
428 428 else:
429 429 # handle null values properly to order by NULL last
430 430 if order_by in ['created_on']:
431 431 sort_col = coalesce(sort_col, datetime.date.min)
432 432 sort_col = sort_col.desc()
433 433
434 434 base_q = base_q.order_by(sort_col)
435 435 base_q = base_q.offset(start).limit(limit)
436 436
437 437 ssh_keys = base_q.all()
438 438
439 439 ssh_keys_data = []
440 440 for ssh_key in ssh_keys:
441 441 ssh_keys_data.append({
442 442 "username": h.gravatar_with_user(self.request, ssh_key.user.username),
443 443 "fingerprint": ssh_key.ssh_key_fingerprint,
444 444 "description": ssh_key.description,
445 445 "created_on": h.format_date(ssh_key.created_on),
446 446 "accessed_on": h.format_date(ssh_key.accessed_on),
447 447 "action": h.link_to(
448 448 _('Edit'), h.route_path('edit_user_ssh_keys',
449 449 user_id=ssh_key.user.user_id))
450 450 })
451 451
452 452 data = ({
453 453 'draw': draw,
454 454 'data': ssh_keys_data,
455 455 'recordsTotal': ssh_keys_data_total_count,
456 456 'recordsFiltered': users_data_total_filtered_count,
457 457 })
458 458
459 459 return data
460 460
461 461 @LoginRequired()
462 462 @HasPermissionAllDecorator('hg.admin')
463 463 @CSRFRequired()
464 464 def ssh_keys_update(self):
465 465 _ = self.request.translate
466 466 self.load_default_context()
467 467
468 468 ssh_enabled = self.ssh_enabled()
469 469 key_file = self.request.registry.settings.get(
470 470 'ssh.authorized_keys_file_path')
471 471 if ssh_enabled:
472 472 events.trigger(SshKeyFileChangeEvent(), self.request.registry)
473 473 h.flash(_('Updated SSH keys file: {}').format(key_file),
474 474 category='success')
475 475 else:
476 476 h.flash(_('SSH key support is disabled in .ini file'),
477 477 category='warning')
478 478
479 479 raise HTTPFound(h.route_path('admin_permissions_ssh_keys'))
@@ -1,58 +1,57 b''
1 1 # -*- coding: utf-8 -*-
2 2
3 3 # Copyright (C) 2016-2020 RhodeCode GmbH
4 4 #
5 5 # This program is free software: you can redistribute it and/or modify
6 6 # it under the terms of the GNU Affero General Public License, version 3
7 7 # (only), as published by the Free Software Foundation.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU Affero General Public License
15 15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 16 #
17 17 # This program is dual-licensed. If you wish to learn more about the
18 18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20 20
21
21 import io
22 22 import uuid
23 from io import StringIO
24 23 import pathlib2
25 24
26 25
27 26 def get_file_storage(settings):
28 27 from rhodecode.apps.file_store.backends.local_store import LocalFileStorage
29 28 from rhodecode.apps.file_store import config_keys
30 29 store_path = settings.get(config_keys.store_path)
31 30 return LocalFileStorage(base_path=store_path)
32 31
33 32
34 33 def splitext(filename):
35 34 ext = ''.join(pathlib2.Path(filename).suffixes)
36 35 return filename, ext
37 36
38 37
39 38 def uid_filename(filename, randomized=True):
40 39 """
41 40 Generates a randomized or stable (uuid) filename,
42 41 preserving the original extension.
43 42
44 43 :param filename: the original filename
45 44 :param randomized: define if filename should be stable (sha1 based) or randomized
46 45 """
47 46
48 47 _, ext = splitext(filename)
49 48 if randomized:
50 49 uid = uuid.uuid4()
51 50 else:
52 51 hash_key = '{}.{}'.format(filename, 'store')
53 52 uid = uuid.uuid5(uuid.NAMESPACE_URL, hash_key)
54 53 return str(uid) + ext.lower()
55 54
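The stable branch derives a repeatable uuid5 from the filename, so the same input always maps to the same stored name; a small demonstration of that determinism:

    import uuid

    hash_key = '{}.{}'.format('report.PDF', 'store')
    uid = uuid.uuid5(uuid.NAMESPACE_URL, hash_key)   # same uid for the same filename
    stored = str(uid) + '.pdf'                       # extension is lower-cased
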
56 55
57 56 def bytes_to_file_obj(bytes_data):
58 return StringIO.StringIO(bytes_data)
57 return io.StringIO(bytes_data)
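Note that io.StringIO only accepts str in Python 3; if callers really pass bytes, as the name bytes_to_file_obj suggests, the matching type would be io.BytesIO. A sketch of that variant, offered as an assumption about the intended input:

    import io

    def bytes_to_file_obj(bytes_data):
        # BytesIO keeps the payload as bytes; io.StringIO(b'...') raises TypeError
        return io.BytesIO(bytes_data)
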
@@ -1,580 +1,580 b''
1 1 # -*- coding: utf-8 -*-
2 2
3 3 # Copyright (C) 2010-2020 RhodeCode GmbH
4 4 #
5 5 # This program is free software: you can redistribute it and/or modify
6 6 # it under the terms of the GNU Affero General Public License, version 3
7 7 # (only), as published by the Free Software Foundation.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU Affero General Public License
15 15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 16 #
17 17 # This program is dual-licensed. If you wish to learn more about the
18 18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20 20
21 21 import urllib.parse
22 22
23 23 import mock
24 24 import pytest
25 25
26 26 from rhodecode.tests import (
27 27 assert_session_flash, HG_REPO, TEST_USER_ADMIN_LOGIN,
28 28 no_newline_id_generator)
29 29 from rhodecode.tests.fixture import Fixture
30 30 from rhodecode.lib.auth import check_password
31 31 from rhodecode.lib import helpers as h
32 32 from rhodecode.model.auth_token import AuthTokenModel
33 33 from rhodecode.model.db import User, Notification, UserApiKeys
34 34 from rhodecode.model.meta import Session
35 35
36 36 fixture = Fixture()
37 37
38 38 whitelist_view = ['RepoCommitsView:repo_commit_raw']
39 39
40 40
41 41 def route_path(name, params=None, **kwargs):
42 42 import urllib.request, urllib.parse, urllib.error
43 43 from rhodecode.apps._base import ADMIN_PREFIX
44 44
45 45 base_url = {
46 46 'login': ADMIN_PREFIX + '/login',
47 47 'logout': ADMIN_PREFIX + '/logout',
48 48 'register': ADMIN_PREFIX + '/register',
49 49 'reset_password':
50 50 ADMIN_PREFIX + '/password_reset',
51 51 'reset_password_confirmation':
52 52 ADMIN_PREFIX + '/password_reset_confirmation',
53 53
54 54 'admin_permissions_application':
55 55 ADMIN_PREFIX + '/permissions/application',
56 56 'admin_permissions_application_update':
57 57 ADMIN_PREFIX + '/permissions/application/update',
58 58
59 59 'repo_commit_raw': '/{repo_name}/raw-changeset/{commit_id}'
60 60
61 61 }[name].format(**kwargs)
62 62
63 63 if params:
64 64 base_url = '{}?{}'.format(base_url, urllib.parse.urlencode(params))
65 65 return base_url
66 66
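Illustrative call with hypothetical values: the helper just looks up the named route and urlencodes any params:

    route_path('login', params={'came_from': '/_admin/users'})
    # -> '/_admin/login?came_from=%2F_admin%2Fusers'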
67 67
68 68 @pytest.mark.usefixtures('app')
69 69 class TestLoginController(object):
70 70 destroy_users = set()
71 71
72 72 @classmethod
73 73 def teardown_class(cls):
74 74 fixture.destroy_users(cls.destroy_users)
75 75
76 76 def teardown_method(self, method):
77 77 for n in Notification.query().all():
78 78 Session().delete(n)
79 79
80 80 Session().commit()
81 81 assert Notification.query().all() == []
82 82
83 83 def test_index(self):
84 84 response = self.app.get(route_path('login'))
85 85 assert response.status == '200 OK'
86 86 # Test response...
87 87
88 88 def test_login_admin_ok(self):
89 89 response = self.app.post(route_path('login'),
90 90 {'username': 'test_admin',
91 91 'password': 'test12'}, status=302)
92 92 response = response.follow()
93 93 session = response.get_session_from_response()
94 94 username = session['rhodecode_user'].get('username')
95 95 assert username == 'test_admin'
96 96 response.mustcontain('logout')
97 97
98 98 def test_login_regular_ok(self):
99 99 response = self.app.post(route_path('login'),
100 100 {'username': 'test_regular',
101 101 'password': 'test12'}, status=302)
102 102
103 103 response = response.follow()
104 104 session = response.get_session_from_response()
105 105 username = session['rhodecode_user'].get('username')
106 106 assert username == 'test_regular'
107 107 response.mustcontain('logout')
108 108
109 109 def test_login_regular_forbidden_when_super_admin_restriction(self):
110 110 from rhodecode.authentication.plugins.auth_rhodecode import RhodeCodeAuthPlugin
111 111 with fixture.auth_restriction(self.app._pyramid_registry,
112 112 RhodeCodeAuthPlugin.AUTH_RESTRICTION_SUPER_ADMIN):
113 113 response = self.app.post(route_path('login'),
114 114 {'username': 'test_regular',
115 115 'password': 'test12'})
116 116
117 117 response.mustcontain('invalid user name')
118 118 response.mustcontain('invalid password')
119 119
120 120 def test_login_regular_forbidden_when_scope_restriction(self):
121 121 from rhodecode.authentication.plugins.auth_rhodecode import RhodeCodeAuthPlugin
122 122 with fixture.scope_restriction(self.app._pyramid_registry,
123 123 RhodeCodeAuthPlugin.AUTH_RESTRICTION_SCOPE_VCS):
124 124 response = self.app.post(route_path('login'),
125 125 {'username': 'test_regular',
126 126 'password': 'test12'})
127 127
128 128 response.mustcontain('invalid user name')
129 129 response.mustcontain('invalid password')
130 130
131 131 def test_login_ok_came_from(self):
132 132 test_came_from = '/_admin/users?branch=stable'
133 133 _url = '{}?came_from={}'.format(route_path('login'), test_came_from)
134 134 response = self.app.post(
135 135 _url, {'username': 'test_admin', 'password': 'test12'}, status=302)
136 136
137 137 assert 'branch=stable' in response.location
138 138 response = response.follow()
139 139
140 140 assert response.status == '200 OK'
141 141 response.mustcontain('Users administration')
142 142
143 143 def test_redirect_to_login_with_get_args(self):
144 144 with fixture.anon_access(False):
145 145 kwargs = {'branch': 'stable'}
146 146 response = self.app.get(
147 147 h.route_path('repo_summary', repo_name=HG_REPO, _query=kwargs),
148 148 status=302)
149 149
150 response_query = urllib.parse.urlparse.parse_qsl(response.location)
150 response_query = urllib.parse.parse_qsl(response.location)
151 151 assert 'branch=stable' in response_query[0][1]
152 152
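For reference, a hedged sketch of what parse_qsl yields for a location like the one asserted above (everything before the first '=' becomes the key):

    import urllib.parse
    pairs = urllib.parse.parse_qsl('/login?came_from=/repo?branch=stable')
    # -> [('/login?came_from', '/repo?branch=stable')]
    assert 'branch=stable' in pairs[0][1]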
153 153 def test_login_form_with_get_args(self):
154 154 _url = '{}?came_from=/_admin/users,branch=stable'.format(route_path('login'))
155 155 response = self.app.get(_url)
156 156 assert 'branch%3Dstable' in response.form.action
157 157
158 158 @pytest.mark.parametrize("url_came_from", [
159 159 'data:text/html,<script>window.alert("xss")</script>',
160 160 'mailto:test@rhodecode.org',
161 161 'file:///etc/passwd',
162 162 'ftp://some.ftp.server',
163 163 'http://other.domain',
164 164 '/\r\nX-Forwarded-Host: http://example.org',
165 165 ], ids=no_newline_id_generator)
166 166 def test_login_bad_came_froms(self, url_came_from):
167 167 _url = '{}?came_from={}'.format(route_path('login'), url_came_from)
168 168 response = self.app.post(
169 169 _url,
170 170 {'username': 'test_admin', 'password': 'test12'})
171 171 assert response.status == '302 Found'
172 172 response = response.follow()
173 173 assert response.status == '200 OK'
174 174 assert response.request.path == '/'
175 175
176 176 def test_login_short_password(self):
177 177 response = self.app.post(route_path('login'),
178 178 {'username': 'test_admin',
179 179 'password': 'as'})
180 180 assert response.status == '200 OK'
181 181
182 182 response.mustcontain('Enter 3 characters or more')
183 183
184 184 def test_login_wrong_non_ascii_password(self, user_regular):
185 185 response = self.app.post(
186 186 route_path('login'),
187 187 {'username': user_regular.username,
188 188 'password': u'invalid-non-asci\xe4'.encode('utf8')})
189 189
190 190 response.mustcontain('invalid user name')
191 191 response.mustcontain('invalid password')
192 192
193 193 def test_login_with_non_ascii_password(self, user_util):
194 194 password = u'valid-non-ascii\xe4'
195 195 user = user_util.create_user(password=password)
196 196 response = self.app.post(
197 197 route_path('login'),
198 198 {'username': user.username,
199 199 'password': password})
200 200 assert response.status_code == 302
201 201
202 202 def test_login_wrong_username_password(self):
203 203 response = self.app.post(route_path('login'),
204 204 {'username': 'error',
205 205 'password': 'test12'})
206 206
207 207 response.mustcontain('invalid user name')
208 208 response.mustcontain('invalid password')
209 209
210 210 def test_login_admin_ok_password_migration(self, real_crypto_backend):
211 211 from rhodecode.lib import auth
212 212
213 213 # create new user, with sha256 password
214 214 temp_user = 'test_admin_sha256'
215 215 user = fixture.create_user(temp_user)
216 216 user.password = auth._RhodeCodeCryptoSha256().hash_create(
217 217 b'test123')
218 218 Session().add(user)
219 219 Session().commit()
220 220 self.destroy_users.add(temp_user)
221 221 response = self.app.post(route_path('login'),
222 222 {'username': temp_user,
223 223 'password': 'test123'}, status=302)
224 224
225 225 response = response.follow()
226 226 session = response.get_session_from_response()
227 227 username = session['rhodecode_user'].get('username')
228 228 assert username == temp_user
229 229 response.mustcontain('logout')
230 230
231 231         # after log-in and transfer, the password should be bcrypt-hashed
232 232 user = User.get_by_username(temp_user)
233 233 assert user.password.startswith('$')
234 234
235 235 # REGISTRATIONS
236 236 def test_register(self):
237 237 response = self.app.get(route_path('register'))
238 238 response.mustcontain('Create an Account')
239 239
240 240 def test_register_err_same_username(self):
241 241 uname = 'test_admin'
242 242 response = self.app.post(
243 243 route_path('register'),
244 244 {
245 245 'username': uname,
246 246 'password': 'test12',
247 247 'password_confirmation': 'test12',
248 248 'email': 'goodmail@domain.com',
249 249 'firstname': 'test',
250 250 'lastname': 'test'
251 251 }
252 252 )
253 253
254 254 assertr = response.assert_response()
255 255 msg = 'Username "%(username)s" already exists'
256 256 msg = msg % {'username': uname}
257 257 assertr.element_contains('#username+.error-message', msg)
258 258
259 259 def test_register_err_same_email(self):
260 260 response = self.app.post(
261 261 route_path('register'),
262 262 {
263 263 'username': 'test_admin_0',
264 264 'password': 'test12',
265 265 'password_confirmation': 'test12',
266 266 'email': 'test_admin@mail.com',
267 267 'firstname': 'test',
268 268 'lastname': 'test'
269 269 }
270 270 )
271 271
272 272 assertr = response.assert_response()
273 273 msg = u'This e-mail address is already taken'
274 274 assertr.element_contains('#email+.error-message', msg)
275 275
276 276 def test_register_err_same_email_case_sensitive(self):
277 277 response = self.app.post(
278 278 route_path('register'),
279 279 {
280 280 'username': 'test_admin_1',
281 281 'password': 'test12',
282 282 'password_confirmation': 'test12',
283 283 'email': 'TesT_Admin@mail.COM',
284 284 'firstname': 'test',
285 285 'lastname': 'test'
286 286 }
287 287 )
288 288 assertr = response.assert_response()
289 289 msg = u'This e-mail address is already taken'
290 290 assertr.element_contains('#email+.error-message', msg)
291 291
292 292 def test_register_err_wrong_data(self):
293 293 response = self.app.post(
294 294 route_path('register'),
295 295 {
296 296 'username': 'xs',
297 297 'password': 'test',
298 298 'password_confirmation': 'test',
299 299 'email': 'goodmailm',
300 300 'firstname': 'test',
301 301 'lastname': 'test'
302 302 }
303 303 )
304 304 assert response.status == '200 OK'
305 305 response.mustcontain('An email address must contain a single @')
306 306 response.mustcontain('Enter a value 6 characters long or more')
307 307
308 308 def test_register_err_username(self):
309 309 response = self.app.post(
310 310 route_path('register'),
311 311 {
312 312 'username': 'error user',
313 313 'password': 'test12',
314 314 'password_confirmation': 'test12',
315 315 'email': 'goodmailm',
316 316 'firstname': 'test',
317 317 'lastname': 'test'
318 318 }
319 319 )
320 320
321 321 response.mustcontain('An email address must contain a single @')
322 322 response.mustcontain(
323 323 'Username may only contain '
324 324 'alphanumeric characters underscores, '
325 325 'periods or dashes and must begin with '
326 326 'alphanumeric character')
327 327
328 328 def test_register_err_case_sensitive(self):
329 329 usr = 'Test_Admin'
330 330 response = self.app.post(
331 331 route_path('register'),
332 332 {
333 333 'username': usr,
334 334 'password': 'test12',
335 335 'password_confirmation': 'test12',
336 336 'email': 'goodmailm',
337 337 'firstname': 'test',
338 338 'lastname': 'test'
339 339 }
340 340 )
341 341
342 342 assertr = response.assert_response()
343 343 msg = u'Username "%(username)s" already exists'
344 344 msg = msg % {'username': usr}
345 345 assertr.element_contains('#username+.error-message', msg)
346 346
347 347 def test_register_special_chars(self):
348 348 response = self.app.post(
349 349 route_path('register'),
350 350 {
351 351 'username': 'xxxaxn',
352 352 'password': 'ąćźżąśśśś',
353 353 'password_confirmation': 'ąćźżąśśśś',
354 354 'email': 'goodmailm@test.plx',
355 355 'firstname': 'test',
356 356 'lastname': 'test'
357 357 }
358 358 )
359 359
360 360 msg = u'Invalid characters (non-ascii) in password'
361 361 response.mustcontain(msg)
362 362
363 363 def test_register_password_mismatch(self):
364 364 response = self.app.post(
365 365 route_path('register'),
366 366 {
367 367 'username': 'xs',
368 368 'password': '123qwe',
369 369 'password_confirmation': 'qwe123',
370 370 'email': 'goodmailm@test.plxa',
371 371 'firstname': 'test',
372 372 'lastname': 'test'
373 373 }
374 374 )
375 375 msg = u'Passwords do not match'
376 376 response.mustcontain(msg)
377 377
378 378 def test_register_ok(self):
379 379 username = 'test_regular4'
380 380 password = 'qweqwe'
381 381 email = 'marcin@test.com'
382 382 name = 'testname'
383 383 lastname = 'testlastname'
384 384
385 385 # this initializes a session
386 386 response = self.app.get(route_path('register'))
387 387 response.mustcontain('Create an Account')
388 388
389 389
390 390 response = self.app.post(
391 391 route_path('register'),
392 392 {
393 393 'username': username,
394 394 'password': password,
395 395 'password_confirmation': password,
396 396 'email': email,
397 397 'firstname': name,
398 398 'lastname': lastname,
399 399 'admin': True
400 400 },
401 401 status=302
402 402 ) # This should be overridden
403 403
404 404 assert_session_flash(
405 405 response, 'You have successfully registered with RhodeCode. You can log-in now.')
406 406
407 407 ret = Session().query(User).filter(
408 408 User.username == 'test_regular4').one()
409 409 assert ret.username == username
410 410 assert check_password(password, ret.password)
411 411 assert ret.email == email
412 412 assert ret.name == name
413 413 assert ret.lastname == lastname
414 414 assert ret.auth_tokens is not None
415 415 assert not ret.admin
416 416
417 417 def test_forgot_password_wrong_mail(self):
418 418 bad_email = 'marcin@wrongmail.org'
419 419 # this initializes a session
420 420 self.app.get(route_path('reset_password'))
421 421
422 422 response = self.app.post(
423 423 route_path('reset_password'), {'email': bad_email, }
424 424 )
425 425 assert_session_flash(response,
426 426 'If such email exists, a password reset link was sent to it.')
427 427
428 428 def test_forgot_password(self, user_util):
429 429 # this initializes a session
430 430 self.app.get(route_path('reset_password'))
431 431
432 432 user = user_util.create_user()
433 433 user_id = user.user_id
434 434 email = user.email
435 435
436 436 response = self.app.post(route_path('reset_password'), {'email': email, })
437 437
438 438 assert_session_flash(response,
439 439 'If such email exists, a password reset link was sent to it.')
440 440
441 441 # BAD KEY
442 442 confirm_url = '{}?key={}'.format(route_path('reset_password_confirmation'), 'badkey')
443 443 response = self.app.get(confirm_url, status=302)
444 444 assert response.location.endswith(route_path('reset_password'))
445 445 assert_session_flash(response, 'Given reset token is invalid')
446 446
447 447 response.follow() # cleanup flash
448 448
449 449 # GOOD KEY
450 450 key = UserApiKeys.query()\
451 451 .filter(UserApiKeys.user_id == user_id)\
452 452 .filter(UserApiKeys.role == UserApiKeys.ROLE_PASSWORD_RESET)\
453 453 .first()
454 454
455 455 assert key
456 456
457 457 confirm_url = '{}?key={}'.format(route_path('reset_password_confirmation'), key.api_key)
458 458 response = self.app.get(confirm_url)
459 459 assert response.status == '302 Found'
460 460 assert response.location.endswith(route_path('login'))
461 461
462 462 assert_session_flash(
463 463 response,
464 464 'Your password reset was successful, '
465 465 'a new password has been sent to your email')
466 466
467 467 response.follow()
468 468
469 469 def _get_api_whitelist(self, values=None):
470 470 config = {'api_access_controllers_whitelist': values or []}
471 471 return config
472 472
473 473 @pytest.mark.parametrize("test_name, auth_token", [
474 474 ('none', None),
475 475 ('empty_string', ''),
476 476 ('fake_number', '123456'),
477 477 ('proper_auth_token', None)
478 478 ])
479 479 def test_access_not_whitelisted_page_via_auth_token(
480 480 self, test_name, auth_token, user_admin):
481 481
482 482 whitelist = self._get_api_whitelist([])
483 483 with mock.patch.dict('rhodecode.CONFIG', whitelist):
484 484 assert [] == whitelist['api_access_controllers_whitelist']
485 485 if test_name == 'proper_auth_token':
486 486                     # fall back to the user's built-in token if auth_token is None
487 487 auth_token = user_admin.api_key
488 488
489 489 with fixture.anon_access(False):
490 490 self.app.get(
491 491 route_path('repo_commit_raw',
492 492 repo_name=HG_REPO, commit_id='tip',
493 493 params=dict(api_key=auth_token)),
494 494 status=302)
495 495
496 496 @pytest.mark.parametrize("test_name, auth_token, code", [
497 497 ('none', None, 302),
498 498 ('empty_string', '', 302),
499 499 ('fake_number', '123456', 302),
500 500 ('proper_auth_token', None, 200)
501 501 ])
502 502 def test_access_whitelisted_page_via_auth_token(
503 503 self, test_name, auth_token, code, user_admin):
504 504
505 505 whitelist = self._get_api_whitelist(whitelist_view)
506 506
507 507 with mock.patch.dict('rhodecode.CONFIG', whitelist):
508 508 assert whitelist_view == whitelist['api_access_controllers_whitelist']
509 509
510 510 if test_name == 'proper_auth_token':
511 511 auth_token = user_admin.api_key
512 512 assert auth_token
513 513
514 514 with fixture.anon_access(False):
515 515 self.app.get(
516 516 route_path('repo_commit_raw',
517 517 repo_name=HG_REPO, commit_id='tip',
518 518 params=dict(api_key=auth_token)),
519 519 status=code)
520 520
521 521 @pytest.mark.parametrize("test_name, auth_token, code", [
522 522 ('proper_auth_token', None, 200),
523 523 ('wrong_auth_token', '123456', 302),
524 524 ])
525 525 def test_access_whitelisted_page_via_auth_token_bound_to_token(
526 526 self, test_name, auth_token, code, user_admin):
527 527
528 528 expected_token = auth_token
529 529 if test_name == 'proper_auth_token':
530 530 auth_token = user_admin.api_key
531 531 expected_token = auth_token
532 532 assert auth_token
533 533
534 534 whitelist = self._get_api_whitelist([
535 535 'RepoCommitsView:repo_commit_raw@{}'.format(expected_token)])
536 536
537 537 with mock.patch.dict('rhodecode.CONFIG', whitelist):
538 538
539 539 with fixture.anon_access(False):
540 540 self.app.get(
541 541 route_path('repo_commit_raw',
542 542 repo_name=HG_REPO, commit_id='tip',
543 543 params=dict(api_key=auth_token)),
544 544 status=code)
545 545
546 546 def test_access_page_via_extra_auth_token(self):
547 547 whitelist = self._get_api_whitelist(whitelist_view)
548 548 with mock.patch.dict('rhodecode.CONFIG', whitelist):
549 549 assert whitelist_view == \
550 550 whitelist['api_access_controllers_whitelist']
551 551
552 552 new_auth_token = AuthTokenModel().create(
553 553 TEST_USER_ADMIN_LOGIN, 'test')
554 554 Session().commit()
555 555 with fixture.anon_access(False):
556 556 self.app.get(
557 557 route_path('repo_commit_raw',
558 558 repo_name=HG_REPO, commit_id='tip',
559 559 params=dict(api_key=new_auth_token.api_key)),
560 560 status=200)
561 561
562 562 def test_access_page_via_expired_auth_token(self):
563 563 whitelist = self._get_api_whitelist(whitelist_view)
564 564 with mock.patch.dict('rhodecode.CONFIG', whitelist):
565 565 assert whitelist_view == \
566 566 whitelist['api_access_controllers_whitelist']
567 567
568 568 new_auth_token = AuthTokenModel().create(
569 569 TEST_USER_ADMIN_LOGIN, 'test')
570 570 Session().commit()
571 571 # patch the api key and make it expired
572 572 new_auth_token.expires = 0
573 573 Session().add(new_auth_token)
574 574 Session().commit()
575 575 with fixture.anon_access(False):
576 576 self.app.get(
577 577 route_path('repo_commit_raw',
578 578 repo_name=HG_REPO, commit_id='tip',
579 579 params=dict(api_key=new_auth_token.api_key)),
580 580 status=302)
@@ -1,1877 +1,1877 b''
1 1 # -*- coding: utf-8 -*-
2 2
3 3 # Copyright (C) 2011-2020 RhodeCode GmbH
4 4 #
5 5 # This program is free software: you can redistribute it and/or modify
6 6 # it under the terms of the GNU Affero General Public License, version 3
7 7 # (only), as published by the Free Software Foundation.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU Affero General Public License
15 15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 16 #
17 17 # This program is dual-licensed. If you wish to learn more about the
18 18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20 20
21 21 import logging
22 22 import collections
23 23
24 24 import formencode
25 25 import formencode.htmlfill
26 26 import peppercorn
27 27 from pyramid.httpexceptions import (
28 28 HTTPFound, HTTPNotFound, HTTPForbidden, HTTPBadRequest, HTTPConflict)
29 29
30 30 from pyramid.renderers import render
31 31
32 32 from rhodecode.apps._base import RepoAppView, DataGridAppView
33 33
34 34 from rhodecode.lib import helpers as h, diffs, codeblocks, channelstream
35 35 from rhodecode.lib.base import vcs_operation_context
36 36 from rhodecode.lib.diffs import load_cached_diff, cache_diff, diff_cache_exist
37 37 from rhodecode.lib.exceptions import CommentVersionMismatch
38 38 from rhodecode.lib.ext_json import json
39 39 from rhodecode.lib.auth import (
40 40 LoginRequired, HasRepoPermissionAny, HasRepoPermissionAnyDecorator,
41 41 NotAnonymous, CSRFRequired)
42 42 from rhodecode.lib.utils2 import str2bool, safe_str, safe_unicode, safe_int, aslist, retry
43 43 from rhodecode.lib.vcs.backends.base import (
44 44 EmptyCommit, UpdateFailureReason, unicode_to_reference)
45 45 from rhodecode.lib.vcs.exceptions import (
46 46 CommitDoesNotExistError, RepositoryRequirementError, EmptyRepositoryError)
47 47 from rhodecode.model.changeset_status import ChangesetStatusModel
48 48 from rhodecode.model.comment import CommentsModel
49 49 from rhodecode.model.db import (
50 50 func, false, or_, PullRequest, ChangesetComment, ChangesetStatus, Repository,
51 51 PullRequestReviewers)
52 52 from rhodecode.model.forms import PullRequestForm
53 53 from rhodecode.model.meta import Session
54 54 from rhodecode.model.pull_request import PullRequestModel, MergeCheck
55 55 from rhodecode.model.scm import ScmModel
56 56
57 57 log = logging.getLogger(__name__)
58 58
59 59
60 60 class RepoPullRequestsView(RepoAppView, DataGridAppView):
61 61
62 62 def load_default_context(self):
63 63 c = self._get_local_tmpl_context(include_app_defaults=True)
64 64 c.REVIEW_STATUS_APPROVED = ChangesetStatus.STATUS_APPROVED
65 65 c.REVIEW_STATUS_REJECTED = ChangesetStatus.STATUS_REJECTED
66 66         # backward compat.: for OLD PRs we use a plain renderer
67 67 c.renderer = 'plain'
68 68 return c
69 69
70 70 def _get_pull_requests_list(
71 71 self, repo_name, source, filter_type, opened_by, statuses):
72 72
73 73 draw, start, limit = self._extract_chunk(self.request)
74 74 search_q, order_by, order_dir = self._extract_ordering(self.request)
75 75 _render = self.request.get_partial_renderer(
76 76 'rhodecode:templates/data_table/_dt_elements.mako')
77 77
78 78 # pagination
79 79
80 80 if filter_type == 'awaiting_review':
81 81 pull_requests = PullRequestModel().get_awaiting_review(
82 82 repo_name,
83 83 search_q=search_q, statuses=statuses,
84 84 offset=start, length=limit, order_by=order_by, order_dir=order_dir)
85 85 pull_requests_total_count = PullRequestModel().count_awaiting_review(
86 86 repo_name,
87 87 search_q=search_q, statuses=statuses)
88 88 elif filter_type == 'awaiting_my_review':
89 89 pull_requests = PullRequestModel().get_awaiting_my_review(
90 90 repo_name, self._rhodecode_user.user_id,
91 91 search_q=search_q, statuses=statuses,
92 92 offset=start, length=limit, order_by=order_by, order_dir=order_dir)
93 93 pull_requests_total_count = PullRequestModel().count_awaiting_my_review(
94 94 repo_name, self._rhodecode_user.user_id,
95 95 search_q=search_q, statuses=statuses)
96 96 else:
97 97 pull_requests = PullRequestModel().get_all(
98 98 repo_name, search_q=search_q, source=source, opened_by=opened_by,
99 99 statuses=statuses, offset=start, length=limit,
100 100 order_by=order_by, order_dir=order_dir)
101 101 pull_requests_total_count = PullRequestModel().count_all(
102 102 repo_name, search_q=search_q, source=source, statuses=statuses,
103 103 opened_by=opened_by)
104 104
105 105 data = []
106 106 comments_model = CommentsModel()
107 107 for pr in pull_requests:
108 108 comments_count = comments_model.get_all_comments(
109 109 self.db_repo.repo_id, pull_request=pr,
110 110 include_drafts=False, count_only=True)
111 111
112 112 review_statuses = pr.reviewers_statuses(user=self._rhodecode_db_user)
113 113 my_review_status = ChangesetStatus.STATUS_NOT_REVIEWED
114 114 if review_statuses and review_statuses[4]:
115 115 _review_obj, _user, _reasons, _mandatory, statuses = review_statuses
116 116 my_review_status = statuses[0][1].status
117 117
118 118 data.append({
119 119 'name': _render('pullrequest_name',
120 120 pr.pull_request_id, pr.pull_request_state,
121 121 pr.work_in_progress, pr.target_repo.repo_name,
122 122 short=True),
123 123 'name_raw': pr.pull_request_id,
124 124 'status': _render('pullrequest_status',
125 125 pr.calculated_review_status()),
126 126 'my_status': _render('pullrequest_status',
127 127 my_review_status),
128 128 'title': _render('pullrequest_title', pr.title, pr.description),
129 129 'description': h.escape(pr.description),
130 130 'updated_on': _render('pullrequest_updated_on',
131 131 h.datetime_to_time(pr.updated_on),
132 132 pr.versions_count),
133 133 'updated_on_raw': h.datetime_to_time(pr.updated_on),
134 134 'created_on': _render('pullrequest_updated_on',
135 135 h.datetime_to_time(pr.created_on)),
136 136 'created_on_raw': h.datetime_to_time(pr.created_on),
137 137 'state': pr.pull_request_state,
138 138 'author': _render('pullrequest_author',
139 139 pr.author.full_contact, ),
140 140 'author_raw': pr.author.full_name,
141 141 'comments': _render('pullrequest_comments', comments_count),
142 142 'comments_raw': comments_count,
143 143 'closed': pr.is_closed(),
144 144 })
145 145
146 146 data = ({
147 147 'draw': draw,
148 148 'data': data,
149 149 'recordsTotal': pull_requests_total_count,
150 150 'recordsFiltered': pull_requests_total_count,
151 151 })
152 152 return data
153 153
154 154 @LoginRequired()
155 155 @HasRepoPermissionAnyDecorator(
156 156 'repository.read', 'repository.write', 'repository.admin')
157 157 def pull_request_list(self):
158 158 c = self.load_default_context()
159 159
160 160 req_get = self.request.GET
161 161 c.source = str2bool(req_get.get('source'))
162 162 c.closed = str2bool(req_get.get('closed'))
163 163 c.my = str2bool(req_get.get('my'))
164 164 c.awaiting_review = str2bool(req_get.get('awaiting_review'))
165 165 c.awaiting_my_review = str2bool(req_get.get('awaiting_my_review'))
166 166
167 167 c.active = 'open'
168 168 if c.my:
169 169 c.active = 'my'
170 170 if c.closed:
171 171 c.active = 'closed'
172 172 if c.awaiting_review and not c.source:
173 173 c.active = 'awaiting'
174 174 if c.source and not c.awaiting_review:
175 175 c.active = 'source'
176 176 if c.awaiting_my_review:
177 177 c.active = 'awaiting_my'
178 178
179 179 return self._get_template_context(c)
180 180
181 181 @LoginRequired()
182 182 @HasRepoPermissionAnyDecorator(
183 183 'repository.read', 'repository.write', 'repository.admin')
184 184 def pull_request_list_data(self):
185 185 self.load_default_context()
186 186
187 187 # additional filters
188 188 req_get = self.request.GET
189 189 source = str2bool(req_get.get('source'))
190 190 closed = str2bool(req_get.get('closed'))
191 191 my = str2bool(req_get.get('my'))
192 192 awaiting_review = str2bool(req_get.get('awaiting_review'))
193 193 awaiting_my_review = str2bool(req_get.get('awaiting_my_review'))
194 194
195 195 filter_type = 'awaiting_review' if awaiting_review \
196 196 else 'awaiting_my_review' if awaiting_my_review \
197 197 else None
198 198
199 199 opened_by = None
200 200 if my:
201 201 opened_by = [self._rhodecode_user.user_id]
202 202
203 203 statuses = [PullRequest.STATUS_NEW, PullRequest.STATUS_OPEN]
204 204 if closed:
205 205 statuses = [PullRequest.STATUS_CLOSED]
206 206
207 207 data = self._get_pull_requests_list(
208 208 repo_name=self.db_repo_name, source=source,
209 209 filter_type=filter_type, opened_by=opened_by, statuses=statuses)
210 210
211 211 return data
212 212
213 213 def _is_diff_cache_enabled(self, target_repo):
214 214 caching_enabled = self._get_general_setting(
215 215 target_repo, 'rhodecode_diff_cache')
216 216 log.debug('Diff caching enabled: %s', caching_enabled)
217 217 return caching_enabled
218 218
219 219 def _get_diffset(self, source_repo_name, source_repo,
220 220 ancestor_commit,
221 221 source_ref_id, target_ref_id,
222 222 target_commit, source_commit, diff_limit, file_limit,
223 223 fulldiff, hide_whitespace_changes, diff_context, use_ancestor=True):
224 224
225 225 target_commit_final = target_commit
226 226 source_commit_final = source_commit
227 227
228 228 if use_ancestor:
229 229             # for version views we might want to skip the ancestor
230 230 target_ref_id = ancestor_commit.raw_id
231 231 target_commit_final = ancestor_commit
232 232
233 233 vcs_diff = PullRequestModel().get_diff(
234 234 source_repo, source_ref_id, target_ref_id,
235 235 hide_whitespace_changes, diff_context)
236 236
237 237 diff_processor = diffs.DiffProcessor(
238 238 vcs_diff, format='newdiff', diff_limit=diff_limit,
239 239 file_limit=file_limit, show_full_diff=fulldiff)
240 240
241 241 _parsed = diff_processor.prepare()
242 242
243 243 diffset = codeblocks.DiffSet(
244 244 repo_name=self.db_repo_name,
245 245 source_repo_name=source_repo_name,
246 246 source_node_getter=codeblocks.diffset_node_getter(target_commit_final),
247 247 target_node_getter=codeblocks.diffset_node_getter(source_commit_final),
248 248 )
249 249 diffset = self.path_filter.render_patchset_filtered(
250 250 diffset, _parsed, target_ref_id, source_ref_id)
251 251
252 252 return diffset
253 253
254 254 def _get_range_diffset(self, source_scm, source_repo,
255 255 commit1, commit2, diff_limit, file_limit,
256 256 fulldiff, hide_whitespace_changes, diff_context):
257 257 vcs_diff = source_scm.get_diff(
258 258 commit1, commit2,
259 259 ignore_whitespace=hide_whitespace_changes,
260 260 context=diff_context)
261 261
262 262 diff_processor = diffs.DiffProcessor(
263 263 vcs_diff, format='newdiff', diff_limit=diff_limit,
264 264 file_limit=file_limit, show_full_diff=fulldiff)
265 265
266 266 _parsed = diff_processor.prepare()
267 267
268 268 diffset = codeblocks.DiffSet(
269 269 repo_name=source_repo.repo_name,
270 270 source_node_getter=codeblocks.diffset_node_getter(commit1),
271 271 target_node_getter=codeblocks.diffset_node_getter(commit2))
272 272
273 273 diffset = self.path_filter.render_patchset_filtered(
274 274 diffset, _parsed, commit1.raw_id, commit2.raw_id)
275 275
276 276 return diffset
277 277
278 278 def register_comments_vars(self, c, pull_request, versions, include_drafts=True):
279 279 comments_model = CommentsModel()
280 280
281 281 # GENERAL COMMENTS with versions #
282 282 q = comments_model._all_general_comments_of_pull_request(pull_request)
283 283 q = q.order_by(ChangesetComment.comment_id.asc())
284 284 if not include_drafts:
285 285 q = q.filter(ChangesetComment.draft == false())
286 286 general_comments = q
287 287
288 288 # pick comments we want to render at current version
289 289 c.comment_versions = comments_model.aggregate_comments(
290 290 general_comments, versions, c.at_version_num)
291 291
292 292 # INLINE COMMENTS with versions #
293 293 q = comments_model._all_inline_comments_of_pull_request(pull_request)
294 294 q = q.order_by(ChangesetComment.comment_id.asc())
295 295 if not include_drafts:
296 296 q = q.filter(ChangesetComment.draft == false())
297 297 inline_comments = q
298 298
299 299 c.inline_versions = comments_model.aggregate_comments(
300 300 inline_comments, versions, c.at_version_num, inline=True)
301 301
302 302 # Comments inline+general
303 303 if c.at_version:
304 304 c.inline_comments_flat = c.inline_versions[c.at_version_num]['display']
305 305 c.comments = c.comment_versions[c.at_version_num]['display']
306 306 else:
307 307 c.inline_comments_flat = c.inline_versions[c.at_version_num]['until']
308 308 c.comments = c.comment_versions[c.at_version_num]['until']
309 309
310 310 return general_comments, inline_comments
311 311
312 312 @LoginRequired()
313 313 @HasRepoPermissionAnyDecorator(
314 314 'repository.read', 'repository.write', 'repository.admin')
315 315 def pull_request_show(self):
316 316 _ = self.request.translate
317 317 c = self.load_default_context()
318 318
319 319 pull_request = PullRequest.get_or_404(
320 320 self.request.matchdict['pull_request_id'])
321 321 pull_request_id = pull_request.pull_request_id
322 322
323 323 c.state_progressing = pull_request.is_state_changing()
324 324 c.pr_broadcast_channel = channelstream.pr_channel(pull_request)
325 325
326 326 _new_state = {
327 327 'created': PullRequest.STATE_CREATED,
328 328 }.get(self.request.GET.get('force_state'))
329 329 can_force_state = c.is_super_admin or HasRepoPermissionAny('repository.admin')(c.repo_name)
330 330
331 331 if can_force_state and _new_state:
332 332 with pull_request.set_state(PullRequest.STATE_UPDATING, final_state=_new_state):
333 333 h.flash(
334 334 _('Pull Request state was force changed to `{}`').format(_new_state),
335 335 category='success')
336 336 Session().commit()
337 337
338 338 raise HTTPFound(h.route_path(
339 339 'pullrequest_show', repo_name=self.db_repo_name,
340 340 pull_request_id=pull_request_id))
341 341
342 342 version = self.request.GET.get('version')
343 343 from_version = self.request.GET.get('from_version') or version
344 344 merge_checks = self.request.GET.get('merge_checks')
345 345 c.fulldiff = str2bool(self.request.GET.get('fulldiff'))
346 346 force_refresh = str2bool(self.request.GET.get('force_refresh'))
347 347 c.range_diff_on = self.request.GET.get('range-diff') == "1"
348 348
349 349 # fetch global flags of ignore ws or context lines
350 350 diff_context = diffs.get_diff_context(self.request)
351 351 hide_whitespace_changes = diffs.get_diff_whitespace_flag(self.request)
352 352
353 353 (pull_request_latest,
354 354 pull_request_at_ver,
355 355 pull_request_display_obj,
356 356 at_version) = PullRequestModel().get_pr_version(
357 357 pull_request_id, version=version)
358 358
359 359 pr_closed = pull_request_latest.is_closed()
360 360
361 361 if pr_closed and (version or from_version):
362 362             # do not allow browsing versions of a closed PR
363 363 raise HTTPFound(h.route_path(
364 364 'pullrequest_show', repo_name=self.db_repo_name,
365 365 pull_request_id=pull_request_id))
366 366
367 367 versions = pull_request_display_obj.versions()
368 368
369 369 c.commit_versions = PullRequestModel().pr_commits_versions(versions)
370 370
371 371 # used to store per-commit range diffs
372 372 c.changes = collections.OrderedDict()
373 373
374 374 c.at_version = at_version
375 375 c.at_version_num = (at_version
376 376 if at_version and at_version != PullRequest.LATEST_VER
377 377 else None)
378 378
379 379 c.at_version_index = ChangesetComment.get_index_from_version(
380 380 c.at_version_num, versions)
381 381
382 382 (prev_pull_request_latest,
383 383 prev_pull_request_at_ver,
384 384 prev_pull_request_display_obj,
385 385 prev_at_version) = PullRequestModel().get_pr_version(
386 386 pull_request_id, version=from_version)
387 387
388 388 c.from_version = prev_at_version
389 389 c.from_version_num = (prev_at_version
390 390 if prev_at_version and prev_at_version != PullRequest.LATEST_VER
391 391 else None)
392 392 c.from_version_index = ChangesetComment.get_index_from_version(
393 393 c.from_version_num, versions)
394 394
395 395 # define if we're in COMPARE mode or VIEW at version mode
396 396 compare = at_version != prev_at_version
397 397
398 398         # the repo_name this pull request was opened against,
399 399         # i.e. target_repo must match
400 400 if self.db_repo_name != pull_request_at_ver.target_repo.repo_name:
401 401 log.warning('Mismatch between the current repo: %s, and target %s',
402 402 self.db_repo_name, pull_request_at_ver.target_repo.repo_name)
403 403 raise HTTPNotFound()
404 404
405 405 c.shadow_clone_url = PullRequestModel().get_shadow_clone_url(pull_request_at_ver)
406 406
407 407 c.pull_request = pull_request_display_obj
408 408 c.renderer = pull_request_at_ver.description_renderer or c.renderer
409 409 c.pull_request_latest = pull_request_latest
410 410
411 411 # inject latest version
412 412 latest_ver = PullRequest.get_pr_display_object(pull_request_latest, pull_request_latest)
413 413 c.versions = versions + [latest_ver]
414 414
415 415 if compare or (at_version and not at_version == PullRequest.LATEST_VER):
416 416 c.allowed_to_change_status = False
417 417 c.allowed_to_update = False
418 418 c.allowed_to_merge = False
419 419 c.allowed_to_delete = False
420 420 c.allowed_to_comment = False
421 421 c.allowed_to_close = False
422 422 else:
423 423 can_change_status = PullRequestModel().check_user_change_status(
424 424 pull_request_at_ver, self._rhodecode_user)
425 425 c.allowed_to_change_status = can_change_status and not pr_closed
426 426
427 427 c.allowed_to_update = PullRequestModel().check_user_update(
428 428 pull_request_latest, self._rhodecode_user) and not pr_closed
429 429 c.allowed_to_merge = PullRequestModel().check_user_merge(
430 430 pull_request_latest, self._rhodecode_user) and not pr_closed
431 431 c.allowed_to_delete = PullRequestModel().check_user_delete(
432 432 pull_request_latest, self._rhodecode_user) and not pr_closed
433 433 c.allowed_to_comment = not pr_closed
434 434 c.allowed_to_close = c.allowed_to_merge and not pr_closed
435 435
436 436 c.forbid_adding_reviewers = False
437 437
438 438 if pull_request_latest.reviewer_data and \
439 439 'rules' in pull_request_latest.reviewer_data:
440 440 rules = pull_request_latest.reviewer_data['rules'] or {}
441 441 try:
442 442 c.forbid_adding_reviewers = rules.get('forbid_adding_reviewers')
443 443 except Exception:
444 444 pass
445 445
446 446 # check merge capabilities
447 447 _merge_check = MergeCheck.validate(
448 448 pull_request_latest, auth_user=self._rhodecode_user,
449 449 translator=self.request.translate,
450 450 force_shadow_repo_refresh=force_refresh)
451 451
452 452 c.pr_merge_errors = _merge_check.error_details
453 453 c.pr_merge_possible = not _merge_check.failed
454 454 c.pr_merge_message = _merge_check.merge_msg
455 455 c.pr_merge_source_commit = _merge_check.source_commit
456 456 c.pr_merge_target_commit = _merge_check.target_commit
457 457
458 458 c.pr_merge_info = MergeCheck.get_merge_conditions(
459 459 pull_request_latest, translator=self.request.translate)
460 460
461 461 c.pull_request_review_status = _merge_check.review_status
462 462 if merge_checks:
463 463 self.request.override_renderer = \
464 464 'rhodecode:templates/pullrequests/pullrequest_merge_checks.mako'
465 465 return self._get_template_context(c)
466 466
467 467 c.reviewers_count = pull_request.reviewers_count
468 468 c.observers_count = pull_request.observers_count
469 469
470 470 # reviewers and statuses
471 471 c.pull_request_default_reviewers_data_json = json.dumps(pull_request.reviewer_data)
472 472 c.pull_request_set_reviewers_data_json = collections.OrderedDict({'reviewers': []})
473 473 c.pull_request_set_observers_data_json = collections.OrderedDict({'observers': []})
474 474
475 475 for review_obj, member, reasons, mandatory, status in pull_request_at_ver.reviewers_statuses():
476 476 member_reviewer = h.reviewer_as_json(
477 477 member, reasons=reasons, mandatory=mandatory,
478 478 role=review_obj.role,
479 479 user_group=review_obj.rule_user_group_data()
480 480 )
481 481
482 482 current_review_status = status[0][1].status if status else ChangesetStatus.STATUS_NOT_REVIEWED
483 483 member_reviewer['review_status'] = current_review_status
484 484 member_reviewer['review_status_label'] = h.commit_status_lbl(current_review_status)
485 485 member_reviewer['allowed_to_update'] = c.allowed_to_update
486 486 c.pull_request_set_reviewers_data_json['reviewers'].append(member_reviewer)
487 487
488 488 c.pull_request_set_reviewers_data_json = json.dumps(c.pull_request_set_reviewers_data_json)
489 489
490 490 for observer_obj, member in pull_request_at_ver.observers():
491 491 member_observer = h.reviewer_as_json(
492 492 member, reasons=[], mandatory=False,
493 493 role=observer_obj.role,
494 494 user_group=observer_obj.rule_user_group_data()
495 495 )
496 496 member_observer['allowed_to_update'] = c.allowed_to_update
497 497 c.pull_request_set_observers_data_json['observers'].append(member_observer)
498 498
499 499 c.pull_request_set_observers_data_json = json.dumps(c.pull_request_set_observers_data_json)
500 500
501 501 general_comments, inline_comments = \
502 502 self.register_comments_vars(c, pull_request_latest, versions)
503 503
504 504 # TODOs
505 505 c.unresolved_comments = CommentsModel() \
506 506 .get_pull_request_unresolved_todos(pull_request_latest)
507 507 c.resolved_comments = CommentsModel() \
508 508 .get_pull_request_resolved_todos(pull_request_latest)
509 509
510 510 # Drafts
511 511 c.draft_comments = CommentsModel().get_pull_request_drafts(
512 512 self._rhodecode_db_user.user_id,
513 513 pull_request_latest)
514 514
515 515         # if a specific version is used, do not show comments
516 516         # made after that version
517 517 display_inline_comments = collections.defaultdict(
518 518 lambda: collections.defaultdict(list))
519 519 for co in inline_comments:
520 520 if c.at_version_num:
521 521                 # pick comments made up to the given version, so we
522 522                 # don't render comments for higher versions
523 523 should_render = co.pull_request_version_id and \
524 524 co.pull_request_version_id <= c.at_version_num
525 525 else:
526 526 # showing all, for 'latest'
527 527 should_render = True
528 528
529 529 if should_render:
530 530 display_inline_comments[co.f_path][co.line_no].append(co)
531 531
532 532 # load diff data into template context, if we use compare mode then
533 533 # diff is calculated based on changes between versions of PR
534 534
535 535 source_repo = pull_request_at_ver.source_repo
536 536 source_ref_id = pull_request_at_ver.source_ref_parts.commit_id
537 537
538 538 target_repo = pull_request_at_ver.target_repo
539 539 target_ref_id = pull_request_at_ver.target_ref_parts.commit_id
540 540
541 541 if compare:
542 542 # in compare switch the diff base to latest commit from prev version
543 543 target_ref_id = prev_pull_request_display_obj.revisions[0]
544 544
545 545         # even when a PR is opened for bookmarks/branches/tags, we always
546 546         # convert this to a rev to prevent changes after the bookmark or branch moves
547 547 c.source_ref_type = 'rev'
548 548 c.source_ref = source_ref_id
549 549
550 550 c.target_ref_type = 'rev'
551 551 c.target_ref = target_ref_id
552 552
553 553 c.source_repo = source_repo
554 554 c.target_repo = target_repo
555 555
556 556 c.commit_ranges = []
557 557 source_commit = EmptyCommit()
558 558 target_commit = EmptyCommit()
559 559 c.missing_requirements = False
560 560
561 561 source_scm = source_repo.scm_instance()
562 562 target_scm = target_repo.scm_instance()
563 563
564 564 shadow_scm = None
565 565 try:
566 566 shadow_scm = pull_request_latest.get_shadow_repo()
567 567 except Exception:
568 568 log.debug('Failed to get shadow repo', exc_info=True)
569 569         # try the existing source_repo first, and then the shadow
570 570         # repo if we can obtain one
571 571 commits_source_repo = source_scm
572 572 if shadow_scm:
573 573 commits_source_repo = shadow_scm
574 574
575 575 c.commits_source_repo = commits_source_repo
576 576 c.ancestor = None # set it to None, to hide it from PR view
577 577
578 578 # empty version means latest, so we keep this to prevent
579 579 # double caching
580 580 version_normalized = version or PullRequest.LATEST_VER
581 581 from_version_normalized = from_version or PullRequest.LATEST_VER
582 582
583 583 cache_path = self.rhodecode_vcs_repo.get_create_shadow_cache_pr_path(target_repo)
584 584 cache_file_path = diff_cache_exist(
585 585 cache_path, 'pull_request', pull_request_id, version_normalized,
586 586 from_version_normalized, source_ref_id, target_ref_id,
587 587 hide_whitespace_changes, diff_context, c.fulldiff)
588 588
589 589 caching_enabled = self._is_diff_cache_enabled(c.target_repo)
590 590 force_recache = self.get_recache_flag()
591 591
592 592 cached_diff = None
593 593 if caching_enabled:
594 594 cached_diff = load_cached_diff(cache_file_path)
595 595
596 596 has_proper_commit_cache = (
597 597 cached_diff and cached_diff.get('commits')
598 598 and len(cached_diff.get('commits', [])) == 5
599 599 and cached_diff.get('commits')[0]
600 600 and cached_diff.get('commits')[3])
601 601
602 602 if not force_recache and not c.range_diff_on and has_proper_commit_cache:
603 603 diff_commit_cache = \
604 604 (ancestor_commit, commit_cache, missing_requirements,
605 605 source_commit, target_commit) = cached_diff['commits']
606 606 else:
607 607             # NOTE(marcink): we may need to reach potentially unreachable commits when
608 608             # a PR has merge errors, which can leave hidden commits in the shadow repo.
609 609 maybe_unreachable = _merge_check.MERGE_CHECK in _merge_check.error_details \
610 610 and _merge_check.merge_response
611 611 maybe_unreachable = maybe_unreachable \
612 612 and _merge_check.merge_response.metadata.get('unresolved_files')
613 613 log.debug("Using unreachable commits due to MERGE_CHECK in merge simulation")
614 614 diff_commit_cache = \
615 615 (ancestor_commit, commit_cache, missing_requirements,
616 616 source_commit, target_commit) = self.get_commits(
617 617 commits_source_repo,
618 618 pull_request_at_ver,
619 619 source_commit,
620 620 source_ref_id,
621 621 source_scm,
622 622 target_commit,
623 623 target_ref_id,
624 624 target_scm,
625 625 maybe_unreachable=maybe_unreachable)
626 626
627 627 # register our commit range
628 628 for comm in commit_cache.values():
629 629 c.commit_ranges.append(comm)
630 630
631 631 c.missing_requirements = missing_requirements
632 632 c.ancestor_commit = ancestor_commit
633 633 c.statuses = source_repo.statuses(
634 634 [x.raw_id for x in c.commit_ranges])
635 635
636 636 # auto collapse if we have more than limit
637 637 collapse_limit = diffs.DiffProcessor._collapse_commits_over
638 638 c.collapse_all_commits = len(c.commit_ranges) > collapse_limit
639 639 c.compare_mode = compare
640 640
641 641         # diff_limit is the old behavior: when exceeded it cuts off the whole
642 642         # diff, while file_limit only hides the
643 643         # big files from the front-end
644 644 diff_limit = c.visual.cut_off_limit_diff
645 645 file_limit = c.visual.cut_off_limit_file
646 646
647 647 c.missing_commits = False
648 648 if (c.missing_requirements
649 649 or isinstance(source_commit, EmptyCommit)
650 650 or source_commit == target_commit):
651 651
652 652 c.missing_commits = True
653 653 else:
654 654 c.inline_comments = display_inline_comments
655 655
656 656 use_ancestor = True
657 657 if from_version_normalized != version_normalized:
658 658 use_ancestor = False
659 659
660 660 has_proper_diff_cache = cached_diff and cached_diff.get('commits')
661 661 if not force_recache and has_proper_diff_cache:
662 662 c.diffset = cached_diff['diff']
663 663 else:
664 664 try:
665 665 c.diffset = self._get_diffset(
666 666 c.source_repo.repo_name, commits_source_repo,
667 667 c.ancestor_commit,
668 668 source_ref_id, target_ref_id,
669 669 target_commit, source_commit,
670 670 diff_limit, file_limit, c.fulldiff,
671 671 hide_whitespace_changes, diff_context,
672 672 use_ancestor=use_ancestor
673 673 )
674 674
675 675 # save cached diff
676 676 if caching_enabled:
677 677 cache_diff(cache_file_path, c.diffset, diff_commit_cache)
678 678 except CommitDoesNotExistError:
679 679 log.exception('Failed to generate diffset')
680 680 c.missing_commits = True
681 681
682 682 if not c.missing_commits:
683 683
684 684 c.limited_diff = c.diffset.limited_diff
685 685
686 686 # calculate removed files that are bound to comments
687 687 comment_deleted_files = [
688 688 fname for fname in display_inline_comments
689 689 if fname not in c.diffset.file_stats]
690 690
691 691 c.deleted_files_comments = collections.defaultdict(dict)
692 692 for fname, per_line_comments in display_inline_comments.items():
693 693 if fname in comment_deleted_files:
694 694 c.deleted_files_comments[fname]['stats'] = 0
695 695 c.deleted_files_comments[fname]['comments'] = list()
696 696 for lno, comments in per_line_comments.items():
697 697 c.deleted_files_comments[fname]['comments'].extend(comments)
698 698
699 699 # maybe calculate the range diff
700 700 if c.range_diff_on:
701 701 # TODO(marcink): set whitespace/context
702 702 context_lcl = 3
703 703 ign_whitespace_lcl = False
704 704
705 705 for commit in c.commit_ranges:
706 706 commit2 = commit
707 707 commit1 = commit.first_parent
708 708
709 709 range_diff_cache_file_path = diff_cache_exist(
710 710 cache_path, 'diff', commit.raw_id,
711 711 ign_whitespace_lcl, context_lcl, c.fulldiff)
712 712
713 713 cached_diff = None
714 714 if caching_enabled:
715 715 cached_diff = load_cached_diff(range_diff_cache_file_path)
716 716
717 717 has_proper_diff_cache = cached_diff and cached_diff.get('diff')
718 718 if not force_recache and has_proper_diff_cache:
719 719 diffset = cached_diff['diff']
720 720 else:
721 721 diffset = self._get_range_diffset(
722 722 commits_source_repo, source_repo,
723 723 commit1, commit2, diff_limit, file_limit,
724 724 c.fulldiff, ign_whitespace_lcl, context_lcl
725 725 )
726 726
727 727 # save cached diff
728 728 if caching_enabled:
729 729 cache_diff(range_diff_cache_file_path, diffset, None)
730 730
731 731 c.changes[commit.raw_id] = diffset
732 732
733 733         # this is a hack to properly display links: when creating a PR, the
734 734         # compare view and others use a different notation, and
735 735         # compare_commits.mako renders links based on the target_repo.
736 736         # We need to swap that here to generate them properly on the HTML side
737 737 c.target_repo = c.source_repo
738 738
739 739 c.commit_statuses = ChangesetStatus.STATUSES
740 740
741 741 c.show_version_changes = not pr_closed
742 742 if c.show_version_changes:
743 743 cur_obj = pull_request_at_ver
744 744 prev_obj = prev_pull_request_at_ver
745 745
746 746 old_commit_ids = prev_obj.revisions
747 747 new_commit_ids = cur_obj.revisions
748 748 commit_changes = PullRequestModel()._calculate_commit_id_changes(
749 749 old_commit_ids, new_commit_ids)
750 750 c.commit_changes_summary = commit_changes
751 751
752 752 # calculate the diff for commits between versions
753 753 c.commit_changes = []
754 754
755 755 def mark(cs, fw):
756 return list(h.itertools.izip_longest([], cs, fillvalue=fw))
756 return list(h.itertools.zip_longest([], cs, fillvalue=fw))
757 757
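        # illustrative: with an empty first iterable, zip_longest pairs every
        # commit id with the fill value, e.g.
        #   mark(['abc1', 'def2'], 'a') -> [('a', 'abc1'), ('a', 'def2')]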
758 758 for c_type, raw_id in mark(commit_changes.added, 'a') \
759 759 + mark(commit_changes.removed, 'r') \
760 760 + mark(commit_changes.common, 'c'):
761 761
762 762 if raw_id in commit_cache:
763 763 commit = commit_cache[raw_id]
764 764 else:
765 765 try:
766 766 commit = commits_source_repo.get_commit(raw_id)
767 767 except CommitDoesNotExistError:
768 768                         # if extraction fails, still use a "dummy" commit
769 769                         # for display in the commit diff
770 770 commit = h.AttributeDict(
771 771 {'raw_id': raw_id,
772 772 'message': 'EMPTY or MISSING COMMIT'})
773 773 c.commit_changes.append([c_type, commit])
774 774
775 775 # current user review statuses for each version
776 776 c.review_versions = {}
777 777 is_reviewer = PullRequestModel().is_user_reviewer(
778 778 pull_request, self._rhodecode_user)
779 779 if is_reviewer:
780 780 for co in general_comments:
781 781 if co.author.user_id == self._rhodecode_user.user_id:
782 782 status = co.status_change
783 783 if status:
784 784 _ver_pr = status[0].comment.pull_request_version_id
785 785 c.review_versions[_ver_pr] = status[0]
786 786
787 787 return self._get_template_context(c)
788 788
789 789 def get_commits(
790 790 self, commits_source_repo, pull_request_at_ver, source_commit,
791 791 source_ref_id, source_scm, target_commit, target_ref_id, target_scm,
792 792 maybe_unreachable=False):
793 793
794 794 commit_cache = collections.OrderedDict()
795 795 missing_requirements = False
796 796
797 797 try:
798 798 pre_load = ["author", "date", "message", "branch", "parents"]
799 799
800 800 pull_request_commits = pull_request_at_ver.revisions
801 801 log.debug('Loading %s commits from %s',
802 802 len(pull_request_commits), commits_source_repo)
803 803
804 804 for rev in pull_request_commits:
805 805 comm = commits_source_repo.get_commit(commit_id=rev, pre_load=pre_load,
806 806 maybe_unreachable=maybe_unreachable)
807 807 commit_cache[comm.raw_id] = comm
808 808
809 809 # Order here matters, we first need to get target, and then
810 810 # the source
811 811 target_commit = commits_source_repo.get_commit(
812 812 commit_id=safe_str(target_ref_id))
813 813
814 814 source_commit = commits_source_repo.get_commit(
815 815 commit_id=safe_str(source_ref_id), maybe_unreachable=True)
816 816 except CommitDoesNotExistError:
817 817 log.warning('Failed to get commit from `{}` repo'.format(
818 818 commits_source_repo), exc_info=True)
819 819 except RepositoryRequirementError:
820 820 log.warning('Failed to get all required data from repo', exc_info=True)
821 821 missing_requirements = True
822 822
823 823 pr_ancestor_id = pull_request_at_ver.common_ancestor_id
824 824
825 825 try:
826 826 ancestor_commit = source_scm.get_commit(pr_ancestor_id)
827 827 except Exception:
828 828 ancestor_commit = None
829 829
830 830 return ancestor_commit, commit_cache, missing_requirements, source_commit, target_commit
831 831
832 832 def assure_not_empty_repo(self):
833 833 _ = self.request.translate
834 834
835 835 try:
836 836 self.db_repo.scm_instance().get_commit()
837 837 except EmptyRepositoryError:
838 838 h.flash(h.literal(_('There are no commits yet')),
839 839 category='warning')
840 840 raise HTTPFound(
841 841 h.route_path('repo_summary', repo_name=self.db_repo.repo_name))
842 842
843 843 @LoginRequired()
844 844 @NotAnonymous()
845 845 @HasRepoPermissionAnyDecorator(
846 846 'repository.read', 'repository.write', 'repository.admin')
847 847 def pull_request_new(self):
848 848 _ = self.request.translate
849 849 c = self.load_default_context()
850 850
851 851 self.assure_not_empty_repo()
852 852 source_repo = self.db_repo
853 853
854 854 commit_id = self.request.GET.get('commit')
855 855 branch_ref = self.request.GET.get('branch')
856 856 bookmark_ref = self.request.GET.get('bookmark')
857 857
858 858 try:
859 859 source_repo_data = PullRequestModel().generate_repo_data(
860 860 source_repo, commit_id=commit_id,
861 861 branch=branch_ref, bookmark=bookmark_ref,
862 862 translator=self.request.translate)
863 863 except CommitDoesNotExistError as e:
864 864 log.exception(e)
865 865 h.flash(_('Commit does not exist'), 'error')
866 866 raise HTTPFound(
867 867 h.route_path('pullrequest_new', repo_name=source_repo.repo_name))
868 868
869 869 default_target_repo = source_repo
870 870
871 871 if source_repo.parent and c.has_origin_repo_read_perm:
872 872 parent_vcs_obj = source_repo.parent.scm_instance()
873 873 if parent_vcs_obj and not parent_vcs_obj.is_empty():
874 874 # change default if we have a parent repo
875 875 default_target_repo = source_repo.parent
876 876
877 877 target_repo_data = PullRequestModel().generate_repo_data(
878 878 default_target_repo, translator=self.request.translate)
879 879
880 880 selected_source_ref = source_repo_data['refs']['selected_ref']
881 881 title_source_ref = ''
882 882 if selected_source_ref:
883 883 title_source_ref = selected_source_ref.split(':', 2)[1]
884 884 c.default_title = PullRequestModel().generate_pullrequest_title(
885 885 source=source_repo.repo_name,
886 886 source_ref=title_source_ref,
887 887 target=default_target_repo.repo_name
888 888 )
889 889
890 890 c.default_repo_data = {
891 891 'source_repo_name': source_repo.repo_name,
892 892 'source_refs_json': json.dumps(source_repo_data),
893 893 'target_repo_name': default_target_repo.repo_name,
894 894 'target_refs_json': json.dumps(target_repo_data),
895 895 }
896 896 c.default_source_ref = selected_source_ref
897 897
898 898 return self._get_template_context(c)
899 899
900 900 @LoginRequired()
901 901 @NotAnonymous()
902 902 @HasRepoPermissionAnyDecorator(
903 903 'repository.read', 'repository.write', 'repository.admin')
904 904 def pull_request_repo_refs(self):
905 905 self.load_default_context()
906 906 target_repo_name = self.request.matchdict['target_repo_name']
907 907 repo = Repository.get_by_repo_name(target_repo_name)
908 908 if not repo:
909 909 raise HTTPNotFound()
910 910
911 911 target_perm = HasRepoPermissionAny(
912 912 'repository.read', 'repository.write', 'repository.admin')(
913 913 target_repo_name)
914 914 if not target_perm:
915 915 raise HTTPNotFound()
916 916
917 917 return PullRequestModel().generate_repo_data(
918 918 repo, translator=self.request.translate)
919 919
920 920 @LoginRequired()
921 921 @NotAnonymous()
922 922 @HasRepoPermissionAnyDecorator(
923 923 'repository.read', 'repository.write', 'repository.admin')
924 924 def pullrequest_repo_targets(self):
925 925 _ = self.request.translate
926 926 filter_query = self.request.GET.get('query')
927 927
928 928 # get the parents
929 929 parent_target_repos = []
930 930 if self.db_repo.parent:
931 931 parents_query = Repository.query() \
932 932 .order_by(func.length(Repository.repo_name)) \
933 933 .filter(Repository.fork_id == self.db_repo.parent.repo_id)
934 934
935 935 if filter_query:
936 936 ilike_expression = u'%{}%'.format(safe_unicode(filter_query))
937 937 parents_query = parents_query.filter(
938 938 Repository.repo_name.ilike(ilike_expression))
939 939 parents = parents_query.limit(20).all()
940 940
941 941 for parent in parents:
942 942 parent_vcs_obj = parent.scm_instance()
943 943 if parent_vcs_obj and not parent_vcs_obj.is_empty():
944 944 parent_target_repos.append(parent)
945 945
946 946 # get other forks, and repo itself
947 947 query = Repository.query() \
948 948 .order_by(func.length(Repository.repo_name)) \
949 949 .filter(
950 950 or_(Repository.repo_id == self.db_repo.repo_id, # repo itself
951 951 Repository.fork_id == self.db_repo.repo_id) # forks of this repo
952 952 ) \
953 953 .filter(~Repository.repo_id.in_([x.repo_id for x in parent_target_repos]))
954 954
955 955 if filter_query:
956 956 ilike_expression = u'%{}%'.format(safe_unicode(filter_query))
957 957 query = query.filter(Repository.repo_name.ilike(ilike_expression))
958 958
959 959 limit = max(20 - len(parent_target_repos), 5) # not less than 5
960 960 target_repos = query.limit(limit).all()
961 961
962 962 all_target_repos = target_repos + parent_target_repos
963 963
964 964 repos = []
965 965 # This checks permissions on the repositories
966 966 for obj in ScmModel().get_repos(all_target_repos):
967 967 repos.append({
968 968 'id': obj['name'],
969 969 'text': obj['name'],
970 970 'type': 'repo',
971 971 'repo_id': obj['dbrepo']['repo_id'],
972 972 'repo_type': obj['dbrepo']['repo_type'],
973 973 'private': obj['dbrepo']['private'],
974 974
975 975 })
976 976
977 977 data = {
978 978 'more': False,
979 979 'results': [{
980 980 'text': _('Repositories'),
981 981 'children': repos
982 982 }] if repos else []
983 983 }
984 984 return data
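# For illustration, with hypothetical repository names and ids, the payload
# built above has this select2-compatible shape:
#
# {'more': False,
#  'results': [{'text': 'Repositories',
#               'children': [
#                   {'id': 'acme/backend', 'text': 'acme/backend',
#                    'type': 'repo', 'repo_id': 42,
#                    'repo_type': 'git', 'private': False}]}]}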
985 985
986 986 @classmethod
987 987 def get_comment_ids(cls, post_data):
988 988 return filter(lambda e: e > 0, map(safe_int, aslist(post_data.get('comments'), ',')))
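# A defensive sketch of the parser above, assuming safe_int() may return
# None for malformed values (comparing None > 0 raises TypeError on
# Python 3); building a list also avoids returning a one-shot iterator:
#
# @classmethod
# def get_comment_ids(cls, post_data):
#     ids = (safe_int(e) for e in aslist(post_data.get('comments'), ','))
#     return [c_id for c_id in ids if c_id and c_id > 0]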
989 989
990 990 @LoginRequired()
991 991 @NotAnonymous()
992 992 @HasRepoPermissionAnyDecorator(
993 993 'repository.read', 'repository.write', 'repository.admin')
994 994 def pullrequest_comments(self):
995 995 self.load_default_context()
996 996
997 997 pull_request = PullRequest.get_or_404(
998 998 self.request.matchdict['pull_request_id'])
999 999 pull_request_id = pull_request.pull_request_id
1000 1000 version = self.request.GET.get('version')
1001 1001
1002 1002 _render = self.request.get_partial_renderer(
1003 1003 'rhodecode:templates/base/sidebar.mako')
1004 1004 c = _render.get_call_context()
1005 1005
1006 1006 (pull_request_latest,
1007 1007 pull_request_at_ver,
1008 1008 pull_request_display_obj,
1009 1009 at_version) = PullRequestModel().get_pr_version(
1010 1010 pull_request_id, version=version)
1011 1011 versions = pull_request_display_obj.versions()
1012 1012 latest_ver = PullRequest.get_pr_display_object(pull_request_latest, pull_request_latest)
1013 1013 c.versions = versions + [latest_ver]
1014 1014
1015 1015 c.at_version = at_version
1016 1016 c.at_version_num = (at_version
1017 1017 if at_version and at_version != PullRequest.LATEST_VER
1018 1018 else None)
1019 1019
1020 1020 self.register_comments_vars(c, pull_request_latest, versions, include_drafts=False)
1021 1021 all_comments = c.inline_comments_flat + c.comments
1022 1022
1023 1023 existing_ids = self.get_comment_ids(self.request.POST)
1024 1024 return _render('comments_table', all_comments, len(all_comments),
1025 1025 existing_ids=existing_ids)
1026 1026
1027 1027 @LoginRequired()
1028 1028 @NotAnonymous()
1029 1029 @HasRepoPermissionAnyDecorator(
1030 1030 'repository.read', 'repository.write', 'repository.admin')
1031 1031 def pullrequest_todos(self):
1032 1032 self.load_default_context()
1033 1033
1034 1034 pull_request = PullRequest.get_or_404(
1035 1035 self.request.matchdict['pull_request_id'])
1036 1036 pull_request_id = pull_request.pull_request_id
1037 1037 version = self.request.GET.get('version')
1038 1038
1039 1039 _render = self.request.get_partial_renderer(
1040 1040 'rhodecode:templates/base/sidebar.mako')
1041 1041 c = _render.get_call_context()
1042 1042 (pull_request_latest,
1043 1043 pull_request_at_ver,
1044 1044 pull_request_display_obj,
1045 1045 at_version) = PullRequestModel().get_pr_version(
1046 1046 pull_request_id, version=version)
1047 1047 versions = pull_request_display_obj.versions()
1048 1048 latest_ver = PullRequest.get_pr_display_object(pull_request_latest, pull_request_latest)
1049 1049 c.versions = versions + [latest_ver]
1050 1050
1051 1051 c.at_version = at_version
1052 1052 c.at_version_num = (at_version
1053 1053 if at_version and at_version != PullRequest.LATEST_VER
1054 1054 else None)
1055 1055
1056 1056 c.unresolved_comments = CommentsModel() \
1057 1057 .get_pull_request_unresolved_todos(pull_request, include_drafts=False)
1058 1058 c.resolved_comments = CommentsModel() \
1059 1059 .get_pull_request_resolved_todos(pull_request, include_drafts=False)
1060 1060
1061 1061 all_comments = c.unresolved_comments + c.resolved_comments
1062 1062 existing_ids = self.get_comment_ids(self.request.POST)
1063 1063 return _render('comments_table', all_comments, len(c.unresolved_comments),
1064 1064 todo_comments=True, existing_ids=existing_ids)
1065 1065
1066 1066 @LoginRequired()
1067 1067 @NotAnonymous()
1068 1068 @HasRepoPermissionAnyDecorator(
1069 1069 'repository.read', 'repository.write', 'repository.admin')
1070 1070 def pullrequest_drafts(self):
1071 1071 self.load_default_context()
1072 1072
1073 1073 pull_request = PullRequest.get_or_404(
1074 1074 self.request.matchdict['pull_request_id'])
1075 1075 pull_request_id = pull_request.pull_request_id
1076 1076 version = self.request.GET.get('version')
1077 1077
1078 1078 _render = self.request.get_partial_renderer(
1079 1079 'rhodecode:templates/base/sidebar.mako')
1080 1080 c = _render.get_call_context()
1081 1081
1082 1082 (pull_request_latest,
1083 1083 pull_request_at_ver,
1084 1084 pull_request_display_obj,
1085 1085 at_version) = PullRequestModel().get_pr_version(
1086 1086 pull_request_id, version=version)
1087 1087 versions = pull_request_display_obj.versions()
1088 1088 latest_ver = PullRequest.get_pr_display_object(pull_request_latest, pull_request_latest)
1089 1089 c.versions = versions + [latest_ver]
1090 1090
1091 1091 c.at_version = at_version
1092 1092 c.at_version_num = (at_version
1093 1093 if at_version and at_version != PullRequest.LATEST_VER
1094 1094 else None)
1095 1095
1096 1096 c.draft_comments = CommentsModel() \
1097 1097 .get_pull_request_drafts(self._rhodecode_db_user.user_id, pull_request)
1098 1098
1099 1099 all_comments = c.draft_comments
1100 1100
1101 1101 existing_ids = self.get_comment_ids(self.request.POST)
1102 1102 return _render('comments_table', all_comments, len(all_comments),
1103 1103 existing_ids=existing_ids, draft_comments=True)
1104 1104
1105 1105 @LoginRequired()
1106 1106 @NotAnonymous()
1107 1107 @HasRepoPermissionAnyDecorator(
1108 1108 'repository.read', 'repository.write', 'repository.admin')
1109 1109 @CSRFRequired()
1110 1110 def pull_request_create(self):
1111 1111 _ = self.request.translate
1112 1112 self.assure_not_empty_repo()
1113 1113 self.load_default_context()
1114 1114
1115 1115 controls = peppercorn.parse(self.request.POST.items())
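# peppercorn folds flat form fields into nested structures; an
# illustrative (hypothetical) POST payload:
#
# [('__start__', 'review_members:sequence'),
#  ('member', '5'),
#  ('__end__', 'review_members:sequence')]
#
# parses to {'review_members': ['5']}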
1116 1116
1117 1117 try:
1118 1118 form = PullRequestForm(
1119 1119 self.request.translate, self.db_repo.repo_id)()
1120 1120 _form = form.to_python(controls)
1121 1121 except formencode.Invalid as errors:
1122 1122 if errors.error_dict.get('revisions'):
1123 1123 msg = 'Revisions: %s' % errors.error_dict['revisions']
1124 1124 elif errors.error_dict.get('pullrequest_title'):
1125 1125 msg = errors.error_dict.get('pullrequest_title')
1126 1126 else:
1127 1127 msg = _('Error creating pull request: {}').format(errors)
1128 1128 log.exception(msg)
1129 1129 h.flash(msg, 'error')
1130 1130
1131 1131 # would rather just go back to form ...
1132 1132 raise HTTPFound(
1133 1133 h.route_path('pullrequest_new', repo_name=self.db_repo_name))
1134 1134
1135 1135 source_repo = _form['source_repo']
1136 1136 source_ref = _form['source_ref']
1137 1137 target_repo = _form['target_repo']
1138 1138 target_ref = _form['target_ref']
1139 1139 commit_ids = _form['revisions'][::-1]
1140 1140 common_ancestor_id = _form['common_ancestor']
1141 1141
1142 1142 # find the ancestor for this pr
1143 1143 source_db_repo = Repository.get_by_repo_name(_form['source_repo'])
1144 1144 target_db_repo = Repository.get_by_repo_name(_form['target_repo'])
1145 1145
1146 1146 if not (source_db_repo and target_db_repo):
1147 1147 h.flash(_('Source or target repository not found'), category='error')
1148 1148 raise HTTPFound(
1149 1149 h.route_path('pullrequest_new', repo_name=self.db_repo_name))
1150 1150
1151 1151 # re-check permissions here
1152 1152 # for the source repo we must have read permissions
1153 1153
1154 1154 source_perm = HasRepoPermissionAny(
1155 1155 'repository.read', 'repository.write', 'repository.admin')(
1156 1156 source_db_repo.repo_name)
1157 1157 if not source_perm:
1158 1158 msg = _('Not enough permissions to source repo `{}`.').format(
1159 1159 source_db_repo.repo_name)
1160 1160 h.flash(msg, category='error')
1161 1161 # copy the args back to redirect
1162 1162 org_query = self.request.GET.mixed()
1163 1163 raise HTTPFound(
1164 1164 h.route_path('pullrequest_new', repo_name=self.db_repo_name,
1165 1165 _query=org_query))
1166 1166
1167 1167 # for the target repo we must have read permissions; later on we
1168 1168 # also want to check branch permissions here
1169 1169 target_perm = HasRepoPermissionAny(
1170 1170 'repository.read', 'repository.write', 'repository.admin')(
1171 1171 target_db_repo.repo_name)
1172 1172 if not target_perm:
1173 1173 msg = _('Not enough permissions to target repo `{}`.').format(
1174 1174 target_db_repo.repo_name)
1175 1175 h.flash(msg, category='error')
1176 1176 # copy the args back to redirect
1177 1177 org_query = self.request.GET.mixed()
1178 1178 raise HTTPFound(
1179 1179 h.route_path('pullrequest_new', repo_name=self.db_repo_name,
1180 1180 _query=org_query))
1181 1181
1182 1182 source_scm = source_db_repo.scm_instance()
1183 1183 target_scm = target_db_repo.scm_instance()
1184 1184
1185 1185 source_ref_obj = unicode_to_reference(source_ref)
1186 1186 target_ref_obj = unicode_to_reference(target_ref)
1187 1187
1188 1188 source_commit = source_scm.get_commit(source_ref_obj.commit_id)
1189 1189 target_commit = target_scm.get_commit(target_ref_obj.commit_id)
1190 1190
1191 1191 ancestor = source_scm.get_common_ancestor(
1192 1192 source_commit.raw_id, target_commit.raw_id, target_scm)
1193 1193
1194 1194 # recalculate target ref based on ancestor
1195 1195 target_ref = ':'.join((target_ref_obj.type, target_ref_obj.name, ancestor))
1196 1196
1197 1197 get_default_reviewers_data, validate_default_reviewers, validate_observers = \
1198 1198 PullRequestModel().get_reviewer_functions()
1199 1199
1200 1200 # recalculate the reviewers, to make sure we can validate them
1201 1201 reviewer_rules = get_default_reviewers_data(
1202 1202 self._rhodecode_db_user,
1203 1203 source_db_repo,
1204 1204 source_ref_obj,
1205 1205 target_db_repo,
1206 1206 target_ref_obj,
1207 1207 include_diff_info=False)
1208 1208
1209 1209 reviewers = validate_default_reviewers(_form['review_members'], reviewer_rules)
1210 1210 observers = validate_observers(_form['observer_members'], reviewer_rules)
1211 1211
1212 1212 pullrequest_title = _form['pullrequest_title']
1213 1213 title_source_ref = source_ref_obj.name
1214 1214 if not pullrequest_title:
1215 1215 pullrequest_title = PullRequestModel().generate_pullrequest_title(
1216 1216 source=source_repo,
1217 1217 source_ref=title_source_ref,
1218 1218 target=target_repo
1219 1219 )
1220 1220
1221 1221 description = _form['pullrequest_desc']
1222 1222 description_renderer = _form['description_renderer']
1223 1223
1224 1224 try:
1225 1225 pull_request = PullRequestModel().create(
1226 1226 created_by=self._rhodecode_user.user_id,
1227 1227 source_repo=source_repo,
1228 1228 source_ref=source_ref,
1229 1229 target_repo=target_repo,
1230 1230 target_ref=target_ref,
1231 1231 revisions=commit_ids,
1232 1232 common_ancestor_id=common_ancestor_id,
1233 1233 reviewers=reviewers,
1234 1234 observers=observers,
1235 1235 title=pullrequest_title,
1236 1236 description=description,
1237 1237 description_renderer=description_renderer,
1238 1238 reviewer_data=reviewer_rules,
1239 1239 auth_user=self._rhodecode_user
1240 1240 )
1241 1241 Session().commit()
1242 1242
1243 1243 h.flash(_('Successfully opened new pull request'),
1244 1244 category='success')
1245 1245 except Exception:
1246 1246 msg = _('Error occurred during creation of this pull request.')
1247 1247 log.exception(msg)
1248 1248 h.flash(msg, category='error')
1249 1249
1250 1250 # copy the args back to redirect
1251 1251 org_query = self.request.GET.mixed()
1252 1252 raise HTTPFound(
1253 1253 h.route_path('pullrequest_new', repo_name=self.db_repo_name,
1254 1254 _query=org_query))
1255 1255
1256 1256 raise HTTPFound(
1257 1257 h.route_path('pullrequest_show', repo_name=target_repo,
1258 1258 pull_request_id=pull_request.pull_request_id))
1259 1259
1260 1260 @LoginRequired()
1261 1261 @NotAnonymous()
1262 1262 @HasRepoPermissionAnyDecorator(
1263 1263 'repository.read', 'repository.write', 'repository.admin')
1264 1264 @CSRFRequired()
1265 1265 def pull_request_update(self):
1266 1266 pull_request = PullRequest.get_or_404(
1267 1267 self.request.matchdict['pull_request_id'])
1268 1268 _ = self.request.translate
1269 1269
1270 1270 c = self.load_default_context()
1271 1271 redirect_url = None
1272 1272 # we run this check first, because we want to know as early as possible
1273 1273 # whether the pr is currently updating
1274 1274 is_state_changing = pull_request.is_state_changing()
1275 1275
1276 1276 if pull_request.is_closed():
1277 1277 log.debug('update: forbidden because pull request is closed')
1278 1278 msg = _(u'Cannot update closed pull requests.')
1279 1279 h.flash(msg, category='error')
1280 1280 return {'response': True,
1281 1281 'redirect_url': redirect_url}
1282 1282
1283 1283 c.pr_broadcast_channel = channelstream.pr_channel(pull_request)
1284 1284
1285 1285 # only owner or admin can update it
1286 1286 allowed_to_update = PullRequestModel().check_user_update(
1287 1287 pull_request, self._rhodecode_user)
1288 1288
1289 1289 if allowed_to_update:
1290 1290 controls = peppercorn.parse(self.request.POST.items())
1291 1291 force_refresh = str2bool(self.request.POST.get('force_refresh', 'false'))
1292 1292 do_update_commits = str2bool(self.request.POST.get('update_commits', 'false'))
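# str2bool is assumed to treat values like 'true', '1', 'on', 'y', 'yes'
# (case-insensitive) as True and anything else as False, e.g.
# str2bool('false') -> False, str2bool('1') -> True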
1293 1293
1294 1294 if 'review_members' in controls:
1295 1295 self._update_reviewers(
1296 1296 c,
1297 1297 pull_request, controls['review_members'],
1298 1298 pull_request.reviewer_data,
1299 1299 PullRequestReviewers.ROLE_REVIEWER)
1300 1300 elif 'observer_members' in controls:
1301 1301 self._update_reviewers(
1302 1302 c,
1303 1303 pull_request, controls['observer_members'],
1304 1304 pull_request.reviewer_data,
1305 1305 PullRequestReviewers.ROLE_OBSERVER)
1306 1306 elif do_update_commits:
1307 1307 if is_state_changing:
1308 1308 log.debug('commits update: forbidden because pull request is in state %s',
1309 1309 pull_request.pull_request_state)
1310 1310 msg = _(u'Cannot update pull request commits in state other than `{}`. '
1311 1311 u'Current state is: `{}`').format(
1312 1312 PullRequest.STATE_CREATED, pull_request.pull_request_state)
1313 1313 h.flash(msg, category='error')
1314 1314 return {'response': True,
1315 1315 'redirect_url': redirect_url}
1316 1316
1317 1317 self._update_commits(c, pull_request)
1318 1318 if force_refresh:
1319 1319 redirect_url = h.route_path(
1320 1320 'pullrequest_show', repo_name=self.db_repo_name,
1321 1321 pull_request_id=pull_request.pull_request_id,
1322 1322 _query={"force_refresh": 1})
1323 1323 elif str2bool(self.request.POST.get('edit_pull_request', 'false')):
1324 1324 self._edit_pull_request(pull_request)
1325 1325 else:
1326 1326 log.error('Unhandled update data.')
1327 1327 raise HTTPBadRequest()
1328 1328
1329 1329 return {'response': True,
1330 1330 'redirect_url': redirect_url}
1331 1331 raise HTTPForbidden()
1332 1332
1333 1333 def _edit_pull_request(self, pull_request):
1334 1334 """
1335 1335 Edit title and description
1336 1336 """
1337 1337 _ = self.request.translate
1338 1338
1339 1339 try:
1340 1340 PullRequestModel().edit(
1341 1341 pull_request,
1342 1342 self.request.POST.get('title'),
1343 1343 self.request.POST.get('description'),
1344 1344 self.request.POST.get('description_renderer'),
1345 1345 self._rhodecode_user)
1346 1346 except ValueError:
1347 1347 msg = _(u'Cannot update closed pull requests.')
1348 1348 h.flash(msg, category='error')
1349 1349 return
1350 1350 else:
1351 1351 Session().commit()
1352 1352
1353 1353 msg = _(u'Pull request title & description updated.')
1354 1354 h.flash(msg, category='success')
1355 1355 return
1356 1356
1357 1357 def _update_commits(self, c, pull_request):
1358 1358 _ = self.request.translate
1359 1359 log.debug('pull-request: running update commits actions')
1360 1360
1361 1361 @retry(exception=Exception, n_tries=3, delay=2)
1362 1362 def commits_update():
1363 1363 return PullRequestModel().update_commits(
1364 1364 pull_request, self._rhodecode_db_user)
1365 1365
1366 1366 with pull_request.set_state(PullRequest.STATE_UPDATING):
1367 1367 resp = commits_update() # retry x3
1368 1368
1369 1369 if resp.executed:
1370 1370
1371 1371 if resp.target_changed and resp.source_changed:
1372 1372 changed = 'target and source repositories'
1373 1373 elif resp.target_changed and not resp.source_changed:
1374 1374 changed = 'target repository'
1375 1375 elif not resp.target_changed and resp.source_changed:
1376 1376 changed = 'source repository'
1377 1377 else:
1378 1378 changed = 'nothing'
1379 1379
1380 1380 msg = _(u'Pull request updated to "{source_commit_id}" with '
1381 1381 u'{count_added} added, {count_removed} removed commits. '
1382 1382 u'Source of changes: {change_source}.')
1383 1383 msg = msg.format(
1384 1384 source_commit_id=pull_request.source_ref_parts.commit_id,
1385 1385 count_added=len(resp.changes.added),
1386 1386 count_removed=len(resp.changes.removed),
1387 1387 change_source=changed)
1388 1388 h.flash(msg, category='success')
1389 1389 channelstream.pr_update_channelstream_push(
1390 1390 self.request, c.pr_broadcast_channel, self._rhodecode_user, msg)
1391 1391 else:
1392 1392 msg = PullRequestModel.UPDATE_STATUS_MESSAGES[resp.reason]
1393 1393 warning_reasons = [
1394 1394 UpdateFailureReason.NO_CHANGE,
1395 1395 UpdateFailureReason.WRONG_REF_TYPE,
1396 1396 ]
1397 1397 category = 'warning' if resp.reason in warning_reasons else 'error'
1398 1398 h.flash(msg, category=category)
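# A minimal sketch of what the @retry helper used above is assumed to do
# (the real implementation lives in rhodecode.lib; only the call
# signature is taken from this file):
#
# import functools
# import time
#
# def retry(exception=Exception, n_tries=3, delay=0):
#     def decorator(func):
#         @functools.wraps(func)
#         def wrapper(*args, **kwargs):
#             for attempt in range(1, n_tries + 1):
#                 try:
#                     return func(*args, **kwargs)
#                 except exception:
#                     if attempt == n_tries:
#                         raise
#                     time.sleep(delay)
#         return wrapper
#     return decorator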
1399 1399
1400 1400 def _update_reviewers(self, c, pull_request, review_members, reviewer_rules, role):
1401 1401 _ = self.request.translate
1402 1402
1403 1403 get_default_reviewers_data, validate_default_reviewers, validate_observers = \
1404 1404 PullRequestModel().get_reviewer_functions()
1405 1405
1406 1406 if role == PullRequestReviewers.ROLE_REVIEWER:
1407 1407 try:
1408 1408 reviewers = validate_default_reviewers(review_members, reviewer_rules)
1409 1409 except ValueError as e:
1410 1410 log.error('Reviewers Validation: {}'.format(e))
1411 1411 h.flash(e, category='error')
1412 1412 return
1413 1413
1414 1414 old_calculated_status = pull_request.calculated_review_status()
1415 1415 PullRequestModel().update_reviewers(
1416 1416 pull_request, reviewers, self._rhodecode_db_user)
1417 1417
1418 1418 Session().commit()
1419 1419
1420 1420 msg = _('Pull request reviewers updated.')
1421 1421 h.flash(msg, category='success')
1422 1422 channelstream.pr_update_channelstream_push(
1423 1423 self.request, c.pr_broadcast_channel, self._rhodecode_user, msg)
1424 1424
1425 1425 # trigger a status change event if the reviewer update changes the status
1426 1426 calculated_status = pull_request.calculated_review_status()
1427 1427 if old_calculated_status != calculated_status:
1428 1428 PullRequestModel().trigger_pull_request_hook(
1429 1429 pull_request, self._rhodecode_user, 'review_status_change',
1430 1430 data={'status': calculated_status})
1431 1431
1432 1432 elif role == PullRequestReviewers.ROLE_OBSERVER:
1433 1433 try:
1434 1434 observers = validate_observers(review_members, reviewer_rules)
1435 1435 except ValueError as e:
1436 1436 log.error('Observers Validation: {}'.format(e))
1437 1437 h.flash(e, category='error')
1438 1438 return
1439 1439
1440 1440 PullRequestModel().update_observers(
1441 1441 pull_request, observers, self._rhodecode_db_user)
1442 1442
1443 1443 Session().commit()
1444 1444 msg = _('Pull request observers updated.')
1445 1445 h.flash(msg, category='success')
1446 1446 channelstream.pr_update_channelstream_push(
1447 1447 self.request, c.pr_broadcast_channel, self._rhodecode_user, msg)
1448 1448
1449 1449 @LoginRequired()
1450 1450 @NotAnonymous()
1451 1451 @HasRepoPermissionAnyDecorator(
1452 1452 'repository.read', 'repository.write', 'repository.admin')
1453 1453 @CSRFRequired()
1454 1454 def pull_request_merge(self):
1455 1455 """
1456 1456 Merge will perform a server-side merge of the specified
1457 1457 pull request, if the pull request is approved and mergeable.
1458 1458 After successful merging, the pull request is automatically
1459 1459 closed, with a relevant comment.
1460 1460 """
1461 1461 pull_request = PullRequest.get_or_404(
1462 1462 self.request.matchdict['pull_request_id'])
1463 1463 _ = self.request.translate
1464 1464
1465 1465 if pull_request.is_state_changing():
1466 1466 log.debug('show: forbidden because pull request is in state %s',
1467 1467 pull_request.pull_request_state)
1468 1468 msg = _(u'Cannot merge pull requests in state other than `{}`. '
1469 1469 u'Current state is: `{}`').format(PullRequest.STATE_CREATED,
1470 1470 pull_request.pull_request_state)
1471 1471 h.flash(msg, category='error')
1472 1472 raise HTTPFound(
1473 1473 h.route_path('pullrequest_show',
1474 1474 repo_name=pull_request.target_repo.repo_name,
1475 1475 pull_request_id=pull_request.pull_request_id))
1476 1476
1477 1477 self.load_default_context()
1478 1478
1479 1479 with pull_request.set_state(PullRequest.STATE_UPDATING):
1480 1480 check = MergeCheck.validate(
1481 1481 pull_request, auth_user=self._rhodecode_user,
1482 1482 translator=self.request.translate)
1483 1483 merge_possible = not check.failed
1484 1484
1485 1485 for err_type, error_msg in check.errors:
1486 1486 h.flash(error_msg, category=err_type)
1487 1487
1488 1488 if merge_possible:
1489 1489 log.debug("Pre-conditions checked, trying to merge.")
1490 1490 extras = vcs_operation_context(
1491 1491 self.request.environ, repo_name=pull_request.target_repo.repo_name,
1492 1492 username=self._rhodecode_db_user.username, action='push',
1493 1493 scm=pull_request.target_repo.repo_type)
1494 1494 with pull_request.set_state(PullRequest.STATE_UPDATING):
1495 1495 self._merge_pull_request(
1496 1496 pull_request, self._rhodecode_db_user, extras)
1497 1497 else:
1498 1498 log.debug("Pre-conditions failed, NOT merging.")
1499 1499
1500 1500 raise HTTPFound(
1501 1501 h.route_path('pullrequest_show',
1502 1502 repo_name=pull_request.target_repo.repo_name,
1503 1503 pull_request_id=pull_request.pull_request_id))
1504 1504
1505 1505 def _merge_pull_request(self, pull_request, user, extras):
1506 1506 _ = self.request.translate
1507 1507 merge_resp = PullRequestModel().merge_repo(pull_request, user, extras=extras)
1508 1508
1509 1509 if merge_resp.executed:
1510 1510 log.debug("The merge was successful, closing the pull request.")
1511 1511 PullRequestModel().close_pull_request(
1512 1512 pull_request.pull_request_id, user)
1513 1513 Session().commit()
1514 1514 msg = _('Pull request was successfully merged and closed.')
1515 1515 h.flash(msg, category='success')
1516 1516 else:
1517 1517 log.debug(
1518 1518 "The merge was not successful. Merge response: %s", merge_resp)
1519 1519 msg = merge_resp.merge_status_message
1520 1520 h.flash(msg, category='error')
1521 1521
1522 1522 @LoginRequired()
1523 1523 @NotAnonymous()
1524 1524 @HasRepoPermissionAnyDecorator(
1525 1525 'repository.read', 'repository.write', 'repository.admin')
1526 1526 @CSRFRequired()
1527 1527 def pull_request_delete(self):
1528 1528 _ = self.request.translate
1529 1529
1530 1530 pull_request = PullRequest.get_or_404(
1531 1531 self.request.matchdict['pull_request_id'])
1532 1532 self.load_default_context()
1533 1533
1534 1534 pr_closed = pull_request.is_closed()
1535 1535 allowed_to_delete = PullRequestModel().check_user_delete(
1536 1536 pull_request, self._rhodecode_user) and not pr_closed
1537 1537
1538 1538 # only the owner can delete it!
1539 1539 if allowed_to_delete:
1540 1540 PullRequestModel().delete(pull_request, self._rhodecode_user)
1541 1541 Session().commit()
1542 1542 h.flash(_('Successfully deleted pull request'),
1543 1543 category='success')
1544 1544 raise HTTPFound(h.route_path('pullrequest_show_all',
1545 1545 repo_name=self.db_repo_name))
1546 1546
1547 1547 log.warning('user %s tried to delete pull request without access',
1548 1548 self._rhodecode_user)
1549 1549 raise HTTPNotFound()
1550 1550
1551 1551 def _pull_request_comments_create(self, pull_request, comments):
1552 1552 _ = self.request.translate
1553 1553 data = {}
1554 1554 if not comments:
1555 1555 return
1556 1556 pull_request_id = pull_request.pull_request_id
1557 1557
1558 1558 all_drafts = all(str2bool(x['is_draft']) for x in comments)
1559 1559
1560 1560 for entry in comments:
1561 1561 c = self.load_default_context()
1562 1562 comment_type = entry['comment_type']
1563 1563 text = entry['text']
1564 1564 status = entry['status']
1565 1565 is_draft = str2bool(entry['is_draft'])
1566 1566 resolves_comment_id = entry['resolves_comment_id']
1567 1567 close_pull_request = entry['close_pull_request']
1568 1568 f_path = entry['f_path']
1569 1569 line_no = entry['line']
1570 1570 target_elem_id = 'file-{}'.format(h.safeid(h.safe_unicode(f_path)))
1571 1571
1572 1572 # the logic here works as follows: if we submit a close-pr
1573 1573 # comment, use the `close_pull_request_with_comment` function,
1574 1574 # else handle the regular comment logic
1575 1575
1576 1576 if close_pull_request:
1577 1577 # only owner or admin or person with write permissions
1578 1578 allowed_to_close = PullRequestModel().check_user_update(
1579 1579 pull_request, self._rhodecode_user)
1580 1580 if not allowed_to_close:
1581 1581 log.debug('comment: forbidden because not allowed to close '
1582 1582 'pull request %s', pull_request_id)
1583 1583 raise HTTPForbidden()
1584 1584
1585 1585 # This also triggers `review_status_change`
1586 1586 comment, status = PullRequestModel().close_pull_request_with_comment(
1587 1587 pull_request, self._rhodecode_user, self.db_repo, message=text,
1588 1588 auth_user=self._rhodecode_user)
1589 1589 Session().flush()
1590 1590 is_inline = comment.is_inline
1591 1591
1592 1592 PullRequestModel().trigger_pull_request_hook(
1593 1593 pull_request, self._rhodecode_user, 'comment',
1594 1594 data={'comment': comment})
1595 1595
1596 1596 else:
1597 1597 # regular comment case, could be inline, or one with status.
1598 1598 # for that one we check also permissions
1599 1599 # Additionally, ensure that if a draft is somehow sent we are unable to change the status
1600 1600 allowed_to_change_status = PullRequestModel().check_user_change_status(
1601 1601 pull_request, self._rhodecode_user) and not is_draft
1602 1602
1603 1603 if status and allowed_to_change_status:
1604 1604 message = (_('Status change %(transition_icon)s %(status)s')
1605 1605 % {'transition_icon': '>',
1606 1606 'status': ChangesetStatus.get_status_lbl(status)})
1607 1607 text = text or message
1608 1608
1609 1609 comment = CommentsModel().create(
1610 1610 text=text,
1611 1611 repo=self.db_repo.repo_id,
1612 1612 user=self._rhodecode_user.user_id,
1613 1613 pull_request=pull_request,
1614 1614 f_path=f_path,
1615 1615 line_no=line_no,
1616 1616 status_change=(ChangesetStatus.get_status_lbl(status)
1617 1617 if status and allowed_to_change_status else None),
1618 1618 status_change_type=(status
1619 1619 if status and allowed_to_change_status else None),
1620 1620 comment_type=comment_type,
1621 1621 is_draft=is_draft,
1622 1622 resolves_comment_id=resolves_comment_id,
1623 1623 auth_user=self._rhodecode_user,
1624 1624 send_email=not is_draft, # skip notification for draft comments
1625 1625 )
1626 1626 is_inline = comment.is_inline
1627 1627
1628 1628 if allowed_to_change_status:
1629 1629 # calculate old status before we change it
1630 1630 old_calculated_status = pull_request.calculated_review_status()
1631 1631
1632 1632 # get status if set!
1633 1633 if status:
1634 1634 ChangesetStatusModel().set_status(
1635 1635 self.db_repo.repo_id,
1636 1636 status,
1637 1637 self._rhodecode_user.user_id,
1638 1638 comment,
1639 1639 pull_request=pull_request
1640 1640 )
1641 1641
1642 1642 Session().flush()
1643 1643 # the refresh is required to get access to relationships
1644 1644 # loaded on the comment
1645 1645 Session().refresh(comment)
1646 1646
1647 1647 # skip notifications for drafts
1648 1648 if not is_draft:
1649 1649 PullRequestModel().trigger_pull_request_hook(
1650 1650 pull_request, self._rhodecode_user, 'comment',
1651 1651 data={'comment': comment})
1652 1652
1653 1653 # we now calculate the status of the pull request, and based on that
1654 1654 # calculation we set the commit statuses
1655 1655 calculated_status = pull_request.calculated_review_status()
1656 1656 if old_calculated_status != calculated_status:
1657 1657 PullRequestModel().trigger_pull_request_hook(
1658 1658 pull_request, self._rhodecode_user, 'review_status_change',
1659 1659 data={'status': calculated_status})
1660 1660
1661 1661 comment_id = comment.comment_id
1662 1662 data[comment_id] = {
1663 1663 'target_id': target_elem_id
1664 1664 }
1665 1665 Session().flush()
1666 1666
1667 1667 c.co = comment
1668 1668 c.at_version_num = None
1669 1669 c.is_new = True
1670 1670 rendered_comment = render(
1671 1671 'rhodecode:templates/changeset/changeset_comment_block.mako',
1672 1672 self._get_template_context(c), self.request)
1673 1673
1674 1674 data[comment_id].update(comment.get_dict())
1675 1675 data[comment_id].update({'rendered_text': rendered_comment})
1676 1676
1677 1677 Session().commit()
1678 1678
1679 1679 # skip channelstream for draft comments
1680 1680 if not all_drafts:
1681 1681 comment_broadcast_channel = channelstream.comment_channel(
1682 1682 self.db_repo_name, pull_request_obj=pull_request)
1683 1683
1684 1684 comment_data = data
1685 1685 posted_comment_type = 'inline' if is_inline else 'general'
1686 1686 if len(data) == 1:
1687 1687 msg = _('posted {} new {} comment').format(len(data), posted_comment_type)
1688 1688 else:
1689 1689 msg = _('posted {} new {} comments').format(len(data), posted_comment_type)
1690 1690
1691 1691 channelstream.comment_channelstream_push(
1692 1692 self.request, comment_broadcast_channel, self._rhodecode_user, msg,
1693 1693 comment_data=comment_data)
1694 1694
1695 1695 return data
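# For illustration, the returned mapping is keyed by comment_id and merges
# the comment dict with render data, roughly (values are hypothetical):
#
# {11: {'target_id': 'file-<safeid of f_path>',
#       ...comment.get_dict() fields...,
#       'rendered_text': '<rendered comment block html>'}}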
1696 1696
1697 1697 @LoginRequired()
1698 1698 @NotAnonymous()
1699 1699 @HasRepoPermissionAnyDecorator(
1700 1700 'repository.read', 'repository.write', 'repository.admin')
1701 1701 @CSRFRequired()
1702 1702 def pull_request_comment_create(self):
1703 1703 _ = self.request.translate
1704 1704
1705 1705 pull_request = PullRequest.get_or_404(self.request.matchdict['pull_request_id'])
1706 1706
1707 1707 if pull_request.is_closed():
1708 1708 log.debug('comment: forbidden because pull request is closed')
1709 1709 raise HTTPForbidden()
1710 1710
1711 1711 allowed_to_comment = PullRequestModel().check_user_comment(
1712 1712 pull_request, self._rhodecode_user)
1713 1713 if not allowed_to_comment:
1714 1714 log.debug('comment: forbidden because pull request is from forbidden repo')
1715 1715 raise HTTPForbidden()
1716 1716
1717 1717 comment_data = {
1718 1718 'comment_type': self.request.POST.get('comment_type'),
1719 1719 'text': self.request.POST.get('text'),
1720 1720 'status': self.request.POST.get('changeset_status', None),
1721 1721 'is_draft': self.request.POST.get('draft'),
1722 1722 'resolves_comment_id': self.request.POST.get('resolves_comment_id', None),
1723 1723 'close_pull_request': self.request.POST.get('close_pull_request'),
1724 1724 'f_path': self.request.POST.get('f_path'),
1725 1725 'line': self.request.POST.get('line'),
1726 1726 }
1727 1727 data = self._pull_request_comments_create(pull_request, [comment_data])
1728 1728
1729 1729 return data
1730 1730
1731 1731 @LoginRequired()
1732 1732 @NotAnonymous()
1733 1733 @HasRepoPermissionAnyDecorator(
1734 1734 'repository.read', 'repository.write', 'repository.admin')
1735 1735 @CSRFRequired()
1736 1736 def pull_request_comment_delete(self):
1737 1737 pull_request = PullRequest.get_or_404(
1738 1738 self.request.matchdict['pull_request_id'])
1739 1739
1740 1740 comment = ChangesetComment.get_or_404(
1741 1741 self.request.matchdict['comment_id'])
1742 1742 comment_id = comment.comment_id
1743 1743
1744 1744 if comment.immutable:
1745 1745 # don't allow deleting comments that are immutable
1746 1746 raise HTTPForbidden()
1747 1747
1748 1748 if pull_request.is_closed():
1749 1749 log.debug('comment: forbidden because pull request is closed')
1750 1750 raise HTTPForbidden()
1751 1751
1752 1752 if not comment:
1753 1753 log.debug('Comment with id:%s not found, skipping', comment_id)
1754 1754 # the comment was probably already deleted in another call
1755 1755 return True
1756 1756
1757 1757 if comment.pull_request.is_closed():
1758 1758 # don't allow deleting comments on closed pull request
1759 1759 raise HTTPForbidden()
1760 1760
1761 1761 is_repo_admin = h.HasRepoPermissionAny('repository.admin')(self.db_repo_name)
1762 1762 super_admin = h.HasPermissionAny('hg.admin')()
1763 1763 comment_owner = comment.author.user_id == self._rhodecode_user.user_id
1764 1764 is_repo_comment = comment.repo.repo_name == self.db_repo_name
1765 1765 comment_repo_admin = is_repo_admin and is_repo_comment
1766 1766
1767 1767 if comment.draft and not comment_owner:
1768 1768 # we never allow anyone other than the owner to delete draft comments
1769 1769 raise HTTPNotFound()
1770 1770
1771 1771 if super_admin or comment_owner or comment_repo_admin:
1772 1772 old_calculated_status = comment.pull_request.calculated_review_status()
1773 1773 CommentsModel().delete(comment=comment, auth_user=self._rhodecode_user)
1774 1774 Session().commit()
1775 1775 calculated_status = comment.pull_request.calculated_review_status()
1776 1776 if old_calculated_status != calculated_status:
1777 1777 PullRequestModel().trigger_pull_request_hook(
1778 1778 comment.pull_request, self._rhodecode_user, 'review_status_change',
1779 1779 data={'status': calculated_status})
1780 1780 return True
1781 1781 else:
1782 1782 log.warning('No permissions for user %s to delete comment_id: %s',
1783 1783 self._rhodecode_db_user, comment_id)
1784 1784 raise HTTPNotFound()
1785 1785
1786 1786 @LoginRequired()
1787 1787 @NotAnonymous()
1788 1788 @HasRepoPermissionAnyDecorator(
1789 1789 'repository.read', 'repository.write', 'repository.admin')
1790 1790 @CSRFRequired()
1791 1791 def pull_request_comment_edit(self):
1792 1792 self.load_default_context()
1793 1793
1794 1794 pull_request = PullRequest.get_or_404(
1795 1795 self.request.matchdict['pull_request_id']
1796 1796 )
1797 1797 comment = ChangesetComment.get_or_404(
1798 1798 self.request.matchdict['comment_id']
1799 1799 )
1800 1800 comment_id = comment.comment_id
1801 1801
1802 1802 if comment.immutable:
1803 1803 # don't allow editing comments that are immutable
1804 1804 raise HTTPForbidden()
1805 1805
1806 1806 if pull_request.is_closed():
1807 1807 log.debug('comment: forbidden because pull request is closed')
1808 1808 raise HTTPForbidden()
1809 1809
1810 1810 if comment.pull_request.is_closed():
1811 1811 # don't allow editing comments on a closed pull request
1812 1812 raise HTTPForbidden()
1813 1813
1814 1814 is_repo_admin = h.HasRepoPermissionAny('repository.admin')(self.db_repo_name)
1815 1815 super_admin = h.HasPermissionAny('hg.admin')()
1816 1816 comment_owner = comment.author.user_id == self._rhodecode_user.user_id
1817 1817 is_repo_comment = comment.repo.repo_name == self.db_repo_name
1818 1818 comment_repo_admin = is_repo_admin and is_repo_comment
1819 1819
1820 1820 if super_admin or comment_owner or comment_repo_admin:
1821 1821 text = self.request.POST.get('text')
1822 1822 version = self.request.POST.get('version') or ''
1823 1823 if text == comment.text:
1824 1824 log.warning(
1825 1825 'Comment(PR): '
1826 1826 'Trying to create new version '
1827 1827 'with the same comment body for comment {}'.format(
1828 1828 comment_id,
1829 1829 )
1830 1830 )
1831 1831 raise HTTPNotFound()
1832 1832
1833 1833 if version.isdigit():
1834 1834 version = int(version)
1835 1835 else:
1836 1836 log.warning(
1837 1837 'Comment(PR): Wrong version type {} {} '
1838 1838 'for comment {}'.format(
1839 1839 version,
1840 1840 type(version),
1841 1841 comment_id,
1842 1842 )
1843 1843 )
1844 1844 raise HTTPNotFound()
1845 1845
1846 1846 try:
1847 1847 comment_history = CommentsModel().edit(
1848 1848 comment_id=comment_id,
1849 1849 text=text,
1850 1850 auth_user=self._rhodecode_user,
1851 1851 version=version,
1852 1852 )
1853 1853 except CommentVersionMismatch:
1854 1854 raise HTTPConflict()
1855 1855
1856 1856 if not comment_history:
1857 1857 raise HTTPNotFound()
1858 1858
1859 1859 Session().commit()
1860 1860 if not comment.draft:
1861 1861 PullRequestModel().trigger_pull_request_hook(
1862 1862 pull_request, self._rhodecode_user, 'comment_edit',
1863 1863 data={'comment': comment})
1864 1864
1865 1865 return {
1866 1866 'comment_history_id': comment_history.comment_history_id,
1867 1867 'comment_id': comment.comment_id,
1868 1868 'comment_version': comment_history.version,
1869 1869 'comment_author_username': comment_history.author.username,
1870 1870 'comment_author_gravatar': h.gravatar_url(comment_history.author.email, 16),
1871 1871 'comment_created_on': h.age_component(comment_history.created_on,
1872 1872 time_is_local=True),
1873 1873 }
1874 1874 else:
1875 1875 log.warning('No permissions for user %s to edit comment_id: %s',
1876 1876 self._rhodecode_db_user, comment_id)
1877 1877 raise HTTPNotFound()
@@ -1,127 +1,125 b''
1 1 # -*- coding: utf-8 -*-
2 2
3 3 # Copyright (C) 2017-2020 RhodeCode GmbH
4 4 #
5 5 # This program is free software: you can redistribute it and/or modify
6 6 # it under the terms of the GNU Affero General Public License, version 3
7 7 # (only), as published by the Free Software Foundation.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU Affero General Public License
15 15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 16 #
17 17 # This program is dual-licensed. If you wish to learn more about the
18 18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20 20
21 21 import logging
22 22
23 23 from pyramid.httpexceptions import HTTPFound, HTTPNotFound
24 24
25 25 import formencode
26 26
27 27 from rhodecode.apps._base import RepoAppView
28 28 from rhodecode.lib import audit_logger
29 29 from rhodecode.lib import helpers as h
30 30 from rhodecode.lib.auth import (
31 31 LoginRequired, HasRepoPermissionAnyDecorator, CSRFRequired)
32 32 from rhodecode.model.forms import IssueTrackerPatternsForm
33 33 from rhodecode.model.meta import Session
34 34 from rhodecode.model.settings import SettingsModel
35 35
36 36 log = logging.getLogger(__name__)
37 37
38 38
39 39 class RepoSettingsIssueTrackersView(RepoAppView):
40 40 def load_default_context(self):
41 41 c = self._get_local_tmpl_context()
42
43
44 42 return c
45 43
46 44 @LoginRequired()
47 45 @HasRepoPermissionAnyDecorator('repository.admin')
48 46 def repo_issuetracker(self):
49 47 c = self.load_default_context()
50 48 c.active = 'issuetracker'
51 49 c.data = 'data'
52 50
53 51 c.settings_model = self.db_repo_patterns
54 52 c.global_patterns = c.settings_model.get_global_settings()
55 53 c.repo_patterns = c.settings_model.get_repo_settings()
56 54
57 55 return self._get_template_context(c)
58 56
59 57 @LoginRequired()
60 58 @HasRepoPermissionAnyDecorator('repository.admin')
61 59 @CSRFRequired()
62 60 def repo_issuetracker_test(self):
63 61 return h.urlify_commit_message(
64 62 self.request.POST.get('test_text', ''),
65 63 self.db_repo_name)
66 64
67 65 @LoginRequired()
68 66 @HasRepoPermissionAnyDecorator('repository.admin')
69 67 @CSRFRequired()
70 68 def repo_issuetracker_delete(self):
71 69 _ = self.request.translate
72 70 uid = self.request.POST.get('uid')
73 71 repo_settings = self.db_repo_patterns
74 72 try:
75 73 repo_settings.delete_entries(uid)
76 74 except Exception:
77 75 h.flash(_('Error occurred while deleting the issue tracker entry'),
78 76 category='error')
79 77 raise HTTPNotFound()
80 78
81 79 SettingsModel().invalidate_settings_cache()
82 80 h.flash(_('Removed issue tracker entry.'), category='success')
83 81
84 82 return {'deleted': uid}
85 83
86 84 def _update_patterns(self, form, repo_settings):
87 85 for uid in form['delete_patterns']:
88 86 repo_settings.delete_entries(uid)
89 87
90 88 for pattern_data in form['patterns']:
91 89 for setting_key, pattern, type_ in pattern_data:
92 90 sett = repo_settings.create_or_update_setting(
93 91 setting_key, pattern.strip(), type_)
94 92 Session().add(sett)
95 93
96 94 Session().commit()
97 95
98 96 @LoginRequired()
99 97 @HasRepoPermissionAnyDecorator('repository.admin')
100 98 @CSRFRequired()
101 99 def repo_issuetracker_update(self):
102 100 _ = self.request.translate
103 101 # Save inheritance
104 102 repo_settings = self.db_repo_patterns
105 103 inherited = (
106 104 self.request.POST.get('inherit_global_issuetracker') == "inherited")
107 105 repo_settings.inherit_global_settings = inherited
108 106 Session().commit()
109 107
110 108 try:
111 109 form = IssueTrackerPatternsForm(self.request.translate)().to_python(self.request.POST)
112 110 except formencode.Invalid as errors:
113 111 log.exception('Failed to add new pattern')
114 112 error = errors
115 113 h.flash(_('Invalid issue tracker pattern: {}').format(error),
116 114 category='error')
117 115 raise HTTPFound(
118 116 h.route_path('edit_repo_issuetracker',
119 117 repo_name=self.db_repo_name))
120 118
121 119 if form:
122 120 self._update_patterns(form, repo_settings)
123 121
124 122 h.flash(_('Updated issue tracker entries'), category='success')
125 123 raise HTTPFound(
126 124 h.route_path('edit_repo_issuetracker', repo_name=self.db_repo_name))
127 125
@@ -1,290 +1,290 b''
1 1 # -*- coding: utf-8 -*-
2 2
3 3 # Copyright (C) 2011-2020 RhodeCode GmbH
4 4 #
5 5 # This program is free software: you can redistribute it and/or modify
6 6 # it under the terms of the GNU Affero General Public License, version 3
7 7 # (only), as published by the Free Software Foundation.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU Affero General Public License
15 15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 16 #
17 17 # This program is dual-licensed. If you wish to learn more about the
18 18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20 20
21 21 import logging
22 22 import string
23 23 import time
24 24
25 25 import rhodecode
26 26
27 27
28 28
29 29 from rhodecode.lib.view_utils import get_format_ref_id
30 30 from rhodecode.apps._base import RepoAppView
31 31 from rhodecode.config.conf import (LANGUAGES_EXTENSIONS_MAP)
32 32 from rhodecode.lib import helpers as h, rc_cache
33 33 from rhodecode.lib.utils2 import safe_str, safe_int
34 34 from rhodecode.lib.auth import LoginRequired, HasRepoPermissionAnyDecorator
35 35 from rhodecode.lib.ext_json import json
36 36 from rhodecode.lib.vcs.backends.base import EmptyCommit
37 37 from rhodecode.lib.vcs.exceptions import (
38 38 CommitError, EmptyRepositoryError, CommitDoesNotExistError)
39 39 from rhodecode.model.db import Statistics, CacheKey, User
40 40 from rhodecode.model.meta import Session
41 41 from rhodecode.model.scm import ScmModel
42 42
43 43 log = logging.getLogger(__name__)
44 44
45 45
46 46 class RepoSummaryView(RepoAppView):
47 47
48 48 def load_default_context(self):
49 49 c = self._get_local_tmpl_context(include_app_defaults=True)
50 50 c.rhodecode_repo = None
51 51 if not c.repository_requirements_missing:
52 52 c.rhodecode_repo = self.rhodecode_vcs_repo
53 53 return c
54 54
55 55 def _load_commits_context(self, c):
56 56 p = safe_int(self.request.GET.get('page'), 1)
57 57 size = safe_int(self.request.GET.get('size'), 10)
58 58
59 59 def url_generator(page_num):
60 60 query_params = {
61 61 'page': page_num,
62 62 'size': size
63 63 }
64 64 return h.route_path(
65 65 'repo_summary_commits',
66 66 repo_name=c.rhodecode_db_repo.repo_name, _query=query_params)
67 67
68 68 pre_load = self.get_commit_preload_attrs()
69 69
70 70 try:
71 71 collection = self.rhodecode_vcs_repo.get_commits(
72 72 pre_load=pre_load, translate_tags=False)
73 73 except EmptyRepositoryError:
74 74 collection = self.rhodecode_vcs_repo
75 75
76 76 c.repo_commits = h.RepoPage(
77 77 collection, page=p, items_per_page=size, url_maker=url_generator)
78 78 page_ids = [x.raw_id for x in c.repo_commits]
79 79 c.comments = self.db_repo.get_comments(page_ids)
80 80 c.statuses = self.db_repo.statuses(page_ids)
81 81
82 82 def _prepare_and_set_clone_url(self, c):
83 83 username = ''
84 84 if self._rhodecode_user.username != User.DEFAULT_USER:
85 85 username = safe_str(self._rhodecode_user.username)
86 86
87 87 _def_clone_uri = c.clone_uri_tmpl
88 88 _def_clone_uri_id = c.clone_uri_id_tmpl
89 89 _def_clone_uri_ssh = c.clone_uri_ssh_tmpl
90 90
91 91 c.clone_repo_url = self.db_repo.clone_url(
92 92 user=username, uri_tmpl=_def_clone_uri)
93 93 c.clone_repo_url_id = self.db_repo.clone_url(
94 94 user=username, uri_tmpl=_def_clone_uri_id)
95 95 c.clone_repo_url_ssh = self.db_repo.clone_url(
96 96 uri_tmpl=_def_clone_uri_ssh, ssh=True)
97 97
98 98 @LoginRequired()
99 99 @HasRepoPermissionAnyDecorator(
100 100 'repository.read', 'repository.write', 'repository.admin')
101 101 def summary_commits(self):
102 102 c = self.load_default_context()
103 103 self._prepare_and_set_clone_url(c)
104 104 self._load_commits_context(c)
105 105 return self._get_template_context(c)
106 106
107 107 @LoginRequired()
108 108 @HasRepoPermissionAnyDecorator(
109 109 'repository.read', 'repository.write', 'repository.admin')
110 110 def summary(self):
111 111 c = self.load_default_context()
112 112
113 113 # Prepare the clone URL
114 114 self._prepare_and_set_clone_url(c)
115 115
116 116 # If enabled, get statistics data
117 117 c.show_stats = bool(self.db_repo.enable_statistics)
118 118
119 119 stats = Session().query(Statistics) \
120 120 .filter(Statistics.repository == self.db_repo) \
121 121 .scalar()
122 122
123 123 c.stats_percentage = 0
124 124
125 125 if stats and stats.languages:
126 126 c.no_data = not self.db_repo.enable_statistics
127 127 lang_stats_d = json.loads(stats.languages)
128 128
129 129 # Sort first by decreasing count and second by the file extension,
130 130 # so we have a consistent output.
131 131 lang_stats_items = sorted(lang_stats_d.items(),
132 132 key=lambda k: (-k[1], k[0]))[:10]
133 133 lang_stats = [(x, {"count": y,
134 134 "desc": LANGUAGES_EXTENSIONS_MAP.get(x)})
135 135 for x, y in lang_stats_items]
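# Worked example of the sort key above: counts descending, then
# extension ascending for ties (extensions are hypothetical):
#
# >>> sorted({'py': 5, 'js': 9, 'css': 5}.items(),
# ...        key=lambda k: (-k[1], k[0]))
# [('js', 9), ('css', 5), ('py', 5)]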
136 136
137 137 c.trending_languages = json.dumps(lang_stats)
138 138 else:
139 139 c.no_data = True
140 140 c.trending_languages = json.dumps({})
141 141
142 142 scm_model = ScmModel()
143 143 c.enable_downloads = self.db_repo.enable_downloads
144 144 c.repository_followers = scm_model.get_followers(self.db_repo)
145 145 c.repository_forks = scm_model.get_forks(self.db_repo)
146 146
147 147 # first interaction with the VCS instance after here...
148 148 if c.repository_requirements_missing:
149 149 self.request.override_renderer = \
150 150 'rhodecode:templates/summary/missing_requirements.mako'
151 151 return self._get_template_context(c)
152 152
153 153 c.readme_data, c.readme_file = \
154 154 self._get_readme_data(self.db_repo, c.visual.default_renderer)
155 155
156 156 # loads the summary commits template context
157 157 self._load_commits_context(c)
158 158
159 159 return self._get_template_context(c)
160 160
161 161 @LoginRequired()
162 162 @HasRepoPermissionAnyDecorator(
163 163 'repository.read', 'repository.write', 'repository.admin')
164 164 def repo_stats(self):
165 165 show_stats = bool(self.db_repo.enable_statistics)
166 166 repo_id = self.db_repo.repo_id
167 167
168 168 landing_commit = self.db_repo.get_landing_commit()
169 169 if isinstance(landing_commit, EmptyCommit):
170 170 return {'size': 0, 'code_stats': {}}
171 171
172 172 cache_seconds = safe_int(rhodecode.CONFIG.get('rc_cache.cache_repo.expiration_time'))
173 173 cache_on = cache_seconds > 0
174 174
175 175 log.debug(
176 176 'Computing REPO STATS for repo_id %s commit_id `%s` '
177 177 'with caching: %s[TTL: %ss]',
178 178 repo_id, landing_commit, cache_on, cache_seconds or 0)
179 179
180 180 cache_namespace_uid = 'cache_repo.{}'.format(repo_id)
181 181 region = rc_cache.get_or_create_region('cache_repo', cache_namespace_uid)
182 182
183 183 @region.conditional_cache_on_arguments(namespace=cache_namespace_uid,
184 184 condition=cache_on)
185 185 def compute_stats(repo_id, commit_id, _show_stats):
186 186 code_stats = {}
187 187 size = 0
188 188 try:
189 189 commit = self.db_repo.get_commit(commit_id)
190 190
191 191 for node in commit.get_filenodes_generator():
192 192 size += node.size
193 193 if not _show_stats:
194 194 continue
195 ext = string.lower(node.extension)
195 ext = node.extension.lower()
196 196 ext_info = LANGUAGES_EXTENSIONS_MAP.get(ext)
197 197 if ext_info:
198 198 if ext in code_stats:
199 199 code_stats[ext]['count'] += 1
200 200 else:
201 201 code_stats[ext] = {"count": 1, "desc": ext_info}
202 202 except (EmptyRepositoryError, CommitDoesNotExistError):
203 203 pass
204 204 return {'size': h.format_byte_size_binary(size),
205 205 'code_stats': code_stats}
206 206
207 207 stats = compute_stats(self.db_repo.repo_id, landing_commit.raw_id, show_stats)
208 208 return stats
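# A sketch of the conditional caching semantics assumed above: with
# condition=False the decorator degrades to a direct call, otherwise
# results are memoized per-argument in the dogpile region until the TTL
# expires. Roughly:
#
# def conditional_cache_on_arguments(namespace=None, condition=True):
#     def decorator(func):
#         if not condition:
#             return func
#         return region.cache_on_arguments(namespace=namespace)(func)
#     return decorator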
209 209
210 210 @LoginRequired()
211 211 @HasRepoPermissionAnyDecorator(
212 212 'repository.read', 'repository.write', 'repository.admin')
213 213 def repo_refs_data(self):
214 214 _ = self.request.translate
215 215 self.load_default_context()
216 216
217 217 repo = self.rhodecode_vcs_repo
218 218 refs_to_create = [
219 219 (_("Branch"), repo.branches, 'branch'),
220 220 (_("Tag"), repo.tags, 'tag'),
221 221 (_("Bookmark"), repo.bookmarks, 'book'),
222 222 ]
223 223 res = self._create_reference_data(repo, self.db_repo_name, refs_to_create)
224 224 data = {
225 225 'more': False,
226 226 'results': res
227 227 }
228 228 return data
229 229
230 230 @LoginRequired()
231 231 @HasRepoPermissionAnyDecorator(
232 232 'repository.read', 'repository.write', 'repository.admin')
233 233 def repo_refs_changelog_data(self):
234 234 _ = self.request.translate
235 235 self.load_default_context()
236 236
237 237 repo = self.rhodecode_vcs_repo
238 238
239 239 refs_to_create = [
240 240 (_("Branches"), repo.branches, 'branch'),
241 241 (_("Closed branches"), repo.branches_closed, 'branch_closed'),
242 242 # TODO: enable when vcs can handle bookmarks filters
243 243 # (_("Bookmarks"), repo.bookmarks, "book"),
244 244 ]
245 245 res = self._create_reference_data(
246 246 repo, self.db_repo_name, refs_to_create)
247 247 data = {
248 248 'more': False,
249 249 'results': res
250 250 }
251 251 return data
252 252
253 253 def _create_reference_data(self, repo, full_repo_name, refs_to_create):
254 254 format_ref_id = get_format_ref_id(repo)
255 255
256 256 result = []
257 257 for title, refs, ref_type in refs_to_create:
258 258 if refs:
259 259 result.append({
260 260 'text': title,
261 261 'children': self._create_reference_items(
262 262 repo, full_repo_name, refs, ref_type,
263 263 format_ref_id),
264 264 })
265 265 return result
266 266
267 267 def _create_reference_items(self, repo, full_repo_name, refs, ref_type, format_ref_id):
268 268 result = []
269 269 is_svn = h.is_svn(repo)
270 270 for ref_name, raw_id in refs.items():
271 271 files_url = self._create_files_url(
272 272 repo, full_repo_name, ref_name, raw_id, is_svn)
273 273 result.append({
274 274 'text': ref_name,
275 275 'id': format_ref_id(ref_name, raw_id),
276 276 'raw_id': raw_id,
277 277 'type': ref_type,
278 278 'files_url': files_url,
279 279 'idx': 0,
280 280 })
281 281 return result
282 282
283 283 def _create_files_url(self, repo, full_repo_name, ref_name, raw_id, is_svn):
284 284 use_commit_id = '/' in ref_name or is_svn
285 285 return h.route_path(
286 286 'repo_files',
287 287 repo_name=full_repo_name,
288 288 f_path=ref_name if is_svn else '',
289 289 commit_id=raw_id if use_commit_id else ref_name,
290 290 _query=dict(at=ref_name))
@@ -1,172 +1,172 b''
1 1 # -*- coding: utf-8 -*-
2 2
3 3 # Copyright (C) 2012-2020 RhodeCode GmbH
4 4 #
5 5 # This program is free software: you can redistribute it and/or modify
6 6 # it under the terms of the GNU Affero General Public License, version 3
7 7 # (only), as published by the Free Software Foundation.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU Affero General Public License
15 15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 16 #
17 17 # This program is dual-licensed. If you wish to learn more about the
18 18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20 20
21 21 """
22 22 RhodeCode authentication library for PAM
23 23 """
24 24
25 25 import colander
26 26 import grp
27 27 import logging
28 28 import pam
29 29 import pwd
30 30 import re
31 31 import socket
32 32
33 33 from rhodecode.translation import _
34 34 from rhodecode.authentication.base import (
35 35 RhodeCodeExternalAuthPlugin, hybrid_property)
36 36 from rhodecode.authentication.schema import AuthnPluginSettingsSchemaBase
37 37 from rhodecode.authentication.routes import AuthnPluginResourceBase
38 38 from rhodecode.lib.colander_utils import strip_whitespace
39 39
40 40 log = logging.getLogger(__name__)
41 41
42 42
43 43 def plugin_factory(plugin_id, *args, **kwargs):
44 44 """
45 45 Factory function that is called during plugin discovery.
46 46 It returns the plugin instance.
47 47 """
48 48 plugin = RhodeCodeAuthPlugin(plugin_id)
49 49 return plugin
50 50
51 51
52 52 class PamAuthnResource(AuthnPluginResourceBase):
53 53 pass
54 54
55 55
56 56 class PamSettingsSchema(AuthnPluginSettingsSchemaBase):
57 57 service = colander.SchemaNode(
58 58 colander.String(),
59 59 default='login',
60 60 description=_('PAM service name to use for authentication.'),
61 61 preparer=strip_whitespace,
62 62 title=_('PAM service name'),
63 63 widget='string')
64 64 gecos = colander.SchemaNode(
65 65 colander.String(),
66 default='(?P<last_name>.+),\s*(?P<first_name>\w+)',
66 default=r'(?P<last_name>.+),\s*(?P<first_name>\w+)',
67 67 description=_('Regular expression for extracting user name/email etc. '
68 68 'from Unix userinfo.'),
69 69 preparer=strip_whitespace,
70 70 title=_('Gecos Regex'),
71 71 widget='string')
72 72
73 73
74 74 class RhodeCodeAuthPlugin(RhodeCodeExternalAuthPlugin):
75 75 uid = 'pam'
76 76 # PAM authentication can be slow. Repository operations involve a lot of
77 77 # auth calls. A little caching helps speed up push/pull operations significantly
78 78 AUTH_CACHE_TTL = 4
79 79
80 80 def includeme(self, config):
81 81 config.add_authn_plugin(self)
82 82 config.add_authn_resource(self.get_id(), PamAuthnResource(self))
83 83 config.add_view(
84 84 'rhodecode.authentication.views.AuthnPluginViewBase',
85 85 attr='settings_get',
86 86 renderer='rhodecode:templates/admin/auth/plugin_settings.mako',
87 87 request_method='GET',
88 88 route_name='auth_home',
89 89 context=PamAuthnResource)
90 90 config.add_view(
91 91 'rhodecode.authentication.views.AuthnPluginViewBase',
92 92 attr='settings_post',
93 93 renderer='rhodecode:templates/admin/auth/plugin_settings.mako',
94 94 request_method='POST',
95 95 route_name='auth_home',
96 96 context=PamAuthnResource)
97 97
98 98 def get_display_name(self, load_from_settings=False):
99 99 return _('PAM')
100 100
101 101 @classmethod
102 102 def docs(cls):
103 103 return "https://docs.rhodecode.com/RhodeCode-Enterprise/auth/auth-pam.html"
104 104
105 105 @hybrid_property
106 106 def name(self):
107 107 return u"pam"
108 108
109 109 def get_settings_schema(self):
110 110 return PamSettingsSchema()
111 111
112 112 def use_fake_password(self):
113 113 return True
114 114
115 115 def auth(self, userobj, username, password, settings, **kwargs):
116 116 if not username or not password:
117 117 log.debug('Empty username or password skipping...')
118 118 return None
119 119 _pam = pam.pam()
120 120 auth_result = _pam.authenticate(username, password, settings["service"])
121 121
122 122 if not auth_result:
123 123 log.error("PAM was unable to authenticate user: %s", username)
124 124 return None
125 125
126 126 log.debug('Got PAM response %s', auth_result)
127 127
128 128 # old attrs fetched from RhodeCode database
129 129 default_email = "%s@%s" % (username, socket.gethostname())
130 130 admin = getattr(userobj, 'admin', False)
131 131 active = getattr(userobj, 'active', True)
132 132 email = getattr(userobj, 'email', '') or default_email
133 133 username = getattr(userobj, 'username', username)
134 134 firstname = getattr(userobj, 'firstname', '')
135 135 lastname = getattr(userobj, 'lastname', '')
136 136 extern_type = getattr(userobj, 'extern_type', '')
137 137
138 138 user_attrs = {
139 139 'username': username,
140 140 'firstname': firstname,
141 141 'lastname': lastname,
142 142 'groups': [g.gr_name for g in grp.getgrall()
143 143 if username in g.gr_mem],
144 144 'user_group_sync': True,
145 145 'email': email,
146 146 'admin': admin,
147 147 'active': active,
148 148 'active_from_extern': None,
149 149 'extern_name': username,
150 150 'extern_type': extern_type,
151 151 }
152 152
153 153 try:
154 154 user_data = pwd.getpwnam(username)
155 155 regex = settings["gecos"]
156 156 match = re.search(regex, user_data.pw_gecos)
157 157 if match:
158 158 user_attrs["firstname"] = match.group('first_name')
159 159 user_attrs["lastname"] = match.group('last_name')
160 160 except Exception:
161 161 log.warning("Cannot extract additional info for PAM user")
162 162 pass
163 163
164 164 log.debug("pamuser: %s", user_attrs)
165 165 log.info('user `%s` authenticated correctly', user_attrs['username'],
166 166 extra={"action": "user_auth_ok", "auth_module": "auth_pam", "username": user_attrs["username"]})
167 167 return user_attrs
168 168
169 169
170 170 def includeme(config):
171 171 plugin_id = 'egg:rhodecode-enterprise-ce#{}'.format(RhodeCodeAuthPlugin.uid)
172 172 plugin_factory(plugin_id).includeme(config)
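The gecos regex above is the only non-obvious step in the PAM flow: it recovers first/last names from the Unix passwd gecos field. A minimal, self-contained sketch of that extraction, assuming the default pattern from PamSettingsSchema and a hard-coded gecos string in place of a live pwd lookup:

import re

GECOS_RE = r'(?P<last_name>.+),\s*(?P<first_name>\w+)'  # default from PamSettingsSchema

def parse_gecos(gecos_field, regex=GECOS_RE):
    # mirrors the re.search() call in RhodeCodeAuthPlugin.auth(); returns the
    # extracted names, or an empty dict when the field does not match
    match = re.search(regex, gecos_field)
    if not match:
        return {}
    return {'firstname': match.group('first_name'),
            'lastname': match.group('last_name')}

assert parse_gecos('Doe, John') == {'firstname': 'John', 'lastname': 'Doe'}
assert parse_gecos('no comma here') == {}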
@@ -1,89 +1,90 b''
1 1 # -*- coding: utf-8 -*-
2 2
3 3 # Copyright (C) 2010-2020 RhodeCode GmbH
4 4 #
5 5 # This program is free software: you can redistribute it and/or modify
6 6 # it under the terms of the GNU Affero General Public License, version 3
7 7 # (only), as published by the Free Software Foundation.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU Affero General Public License
15 15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 16 #
17 17 # This program is dual-licensed. If you wish to learn more about the
18 18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20 20
21 21 import os
22 22 import logging
23 23 import rhodecode
24 import collections
24 25
25 26 from rhodecode.config import utils
26 27
27 28 from rhodecode.lib.utils import load_rcextensions
28 29 from rhodecode.lib.utils2 import str2bool
29 30 from rhodecode.lib.vcs import connect_vcs
30 31
31 32 log = logging.getLogger(__name__)
32 33
33 34
34 35 def load_pyramid_environment(global_config, settings):
35 36 # Some parts of the code expect a merge of global and app settings.
36 37 settings_merged = global_config.copy()
37 38 settings_merged.update(settings)
38 39
39 40 # TODO(marcink): probably not required anymore
40 41 # configure channelstream,
41 42 settings_merged['channelstream_config'] = {
42 43 'enabled': str2bool(settings_merged.get('channelstream.enabled', False)),
43 44 'server': settings_merged.get('channelstream.server'),
44 45 'secret': settings_merged.get('channelstream.secret')
45 46 }
46 47
47 48 # If this is a test run we prepare the test environment like
48 49 # creating a test database, test search index and test repositories.
49 50 # This has to be done before the database connection is initialized.
50 51 if settings['is_test']:
51 52 rhodecode.is_test = True
52 53 rhodecode.disable_error_handler = True
53 54 from rhodecode import authentication
54 55 authentication.plugin_default_auth_ttl = 0
55 56
56 57 utils.initialize_test_environment(settings_merged)
57 58
58 59 # Initialize the database connection.
59 60 utils.initialize_database(settings_merged)
60 61
61 62 load_rcextensions(root_path=settings_merged['here'])
62 63
63 64 # Limit backends to `vcs.backends` from configuration, and preserve the order
64 65 for alias in list(rhodecode.BACKENDS.keys()):
65 66 if alias not in settings['vcs.backends']:
66 67 del rhodecode.BACKENDS[alias]
67 68
68 69 _sorted_backend = sorted(rhodecode.BACKENDS.items(),
69 70 key=lambda item: settings['vcs.backends'].index(item[0]))
70 rhodecode.BACKENDS = rhodecode.OrderedDict(_sorted_backend)
71 rhodecode.BACKENDS = collections.OrderedDict(_sorted_backend)
71 72
72 73 log.info('Enabled VCS backends: %s', rhodecode.BACKENDS.keys())
73 74
74 75 # initialize vcs client and optionally run the server if enabled
75 76 vcs_server_uri = settings['vcs.server']
76 77 vcs_server_enabled = settings['vcs.server.enable']
77 78
78 79 utils.configure_vcs(settings)
79 80
80 81 # Store the settings to make them available to other modules.
81 82
82 83 rhodecode.PYRAMID_SETTINGS = settings_merged
83 84 rhodecode.CONFIG = settings_merged
84 85 rhodecode.CONFIG['default_user_id'] = utils.get_default_user_id()
85 86
86 87 if vcs_server_enabled:
87 88 connect_vcs(vcs_server_uri, utils.get_vcs_server_protocol(settings))
88 89 else:
89 90 log.warning('vcs-server not enabled, vcs connection unavailable')
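The prune-then-sort idiom above (filter rhodecode.BACKENDS down to the configured set, then order it by position in `vcs.backends`) is easy to get wrong, so here is a standalone sketch of the same logic, with hypothetical backend names standing in for the real registry:

import collections

backends = {'hg': 'MercurialRepository', 'git': 'GitRepository', 'svn': 'SubversionRepository'}
enabled = ['git', 'hg']  # stand-in for settings['vcs.backends']; its order wins

# prune backends that are not enabled (iterate over a copy, as above)
for alias in list(backends.keys()):
    if alias not in enabled:
        del backends[alias]

# re-order by the configured position rather than dict insertion order
backends = collections.OrderedDict(
    sorted(backends.items(), key=lambda item: enabled.index(item[0])))

assert list(backends) == ['git', 'hg']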
@@ -1,282 +1,280 b''
1 1 # -*- coding: utf-8 -*-
2 2 """
3 3 Adapters
4 4 --------
5 5
6 6 .. contents::
7 7 :backlinks: none
8 8
9 9 The :func:`authomatic.login` function needs access to functionality like
10 10 getting the **URL** of the handler where it is being called, getting the
11 11 **request params** and **cookies** and **writing the body**, **headers**
12 12 and **status** to the response.
13 13
14 14 Since implementation of these features varies across Python web frameworks,
15 15 the Authomatic library uses **adapters** to unify these differences into a
16 16 single interface.
17 17
18 18 Available Adapters
19 19 ^^^^^^^^^^^^^^^^^^
20 20
21 21 If you are missing an adapter for the framework of your choice, please
22 22 open an `enhancement issue <https://github.com/authomatic/authomatic/issues>`_
23 23 or consider a contribution to this module by
24 24 :ref:`implementing <implement_adapters>` one yourself.
25 25 It's very easy and shouldn't take you more than a few minutes.
26 26
27 27 .. autoclass:: DjangoAdapter
28 28 :members:
29 29
30 30 .. autoclass:: Webapp2Adapter
31 31 :members:
32 32
33 33 .. autoclass:: WebObAdapter
34 34 :members:
35 35
36 36 .. autoclass:: WerkzeugAdapter
37 37 :members:
38 38
39 39 .. _implement_adapters:
40 40
41 41 Implementing an Adapter
42 42 ^^^^^^^^^^^^^^^^^^^^^^^
43 43
44 44 Implementing an adapter for a Python web framework is pretty easy.
45 45
46 46 Do it by subclassing the :class:`.BaseAdapter` abstract class.
47 47 There are only **six** members that you need to implement.
48 48
49 49 Moreover if your framework is based on the |webob|_ or |werkzeug|_ package
50 50 you can subclass the :class:`.WebObAdapter` or :class:`.WerkzeugAdapter`
51 51 respectively.
52 52
53 53 .. autoclass:: BaseAdapter
54 54 :members:
55 55
56 56 """
57 57
58 58 import abc
59 59 from authomatic.core import Response
60 60
61 61
62 class BaseAdapter(object):
62 class BaseAdapter(object, metaclass=abc.ABCMeta):
63 63 """
64 64 Base class for platform adapters.
65 65
66 66 Defines a common interface for WSGI framework-specific functionality.
67 67
68 68 """
69 69
70 __metaclass__ = abc.ABCMeta
71
72 70 @abc.abstractproperty
73 71 def params(self):
74 72 """
75 73 Must return a :class:`dict` of all request parameters of any HTTP
76 74 method.
77 75
78 76 :returns:
79 77 :class:`dict`
80 78
81 79 """
82 80
83 81 @abc.abstractproperty
84 82 def url(self):
85 83 """
86 84 Must return the url of the actual request including path but without
87 85 query and fragment.
88 86
89 87 :returns:
90 88 :class:`str`
91 89
92 90 """
93 91
94 92 @abc.abstractproperty
95 93 def cookies(self):
96 94 """
97 95 Must return cookies as a :class:`dict`.
98 96
99 97 :returns:
100 98 :class:`dict`
101 99
102 100 """
103 101
104 102 @abc.abstractmethod
105 103 def write(self, value):
106 104 """
107 105 Must write specified value to response.
108 106
109 107 :param str value:
110 108 String to be written to response.
111 109
112 110 """
113 111
114 112 @abc.abstractmethod
115 113 def set_header(self, key, value):
116 114 """
117 115 Must set response headers to ``Key: value``.
118 116
119 117 :param str key:
120 118 Header name.
121 119
122 120 :param str value:
123 121 Header value.
124 122
125 123 """
126 124
127 125 @abc.abstractmethod
128 126 def set_status(self, status):
129 127 """
130 128 Must set the response status e.g. ``'302 Found'``.
131 129
132 130 :param str status:
133 131 The HTTP response status.
134 132
135 133 """
136 134
137 135
138 136 class DjangoAdapter(BaseAdapter):
139 137 """
140 138 Adapter for the |django|_ framework.
141 139 """
142 140
143 141 def __init__(self, request, response):
144 142 """
145 143 :param request:
146 144 An instance of the :class:`django.http.HttpRequest` class.
147 145
148 146 :param response:
149 147 An instance of the :class:`django.http.HttpResponse` class.
150 148 """
151 149 self.request = request
152 150 self.response = response
153 151
154 152 @property
155 153 def params(self):
156 154 params = {}
157 155 params.update(self.request.GET.dict())
158 156 params.update(self.request.POST.dict())
159 157 return params
160 158
161 159 @property
162 160 def url(self):
163 161 return self.request.build_absolute_uri(self.request.path)
164 162
165 163 @property
166 164 def cookies(self):
167 165 return dict(self.request.COOKIES)
168 166
169 167 def write(self, value):
170 168 self.response.write(value)
171 169
172 170 def set_header(self, key, value):
173 171 self.response[key] = value
174 172
175 173 def set_status(self, status):
176 174 status_code, reason = status.split(' ', 1)
177 175 self.response.status_code = int(status_code)
178 176
179 177
180 178 class WebObAdapter(BaseAdapter):
181 179 """
182 180 Adapter for the |webob|_ package.
183 181 """
184 182
185 183 def __init__(self, request, response):
186 184 """
187 185 :param request:
188 186 A |webob|_ :class:`Request` instance.
189 187
190 188 :param response:
191 189 A |webob|_ :class:`Response` instance.
192 190 """
193 191 self.request = request
194 192 self.response = response
195 193
196 194 # =========================================================================
197 195 # Request
198 196 # =========================================================================
199 197
200 198 @property
201 199 def url(self):
202 200 return self.request.path_url
203 201
204 202 @property
205 203 def params(self):
206 204 return dict(self.request.params)
207 205
208 206 @property
209 207 def cookies(self):
210 208 return dict(self.request.cookies)
211 209
212 210 # =========================================================================
213 211 # Response
214 212 # =========================================================================
215 213
216 214 def write(self, value):
217 215 self.response.write(value)
218 216
219 217 def set_header(self, key, value):
220 218 self.response.headers[key] = str(value)
221 219
222 220 def set_status(self, status):
223 221 self.response.status = status
224 222
225 223
226 224 class Webapp2Adapter(WebObAdapter):
227 225 """
228 226 Adapter for the |webapp2|_ framework.
229 227
230 228 Inherits from the :class:`.WebObAdapter`.
231 229
232 230 """
233 231
234 232 def __init__(self, handler):
235 233 """
236 234 :param handler:
237 235 A :class:`webapp2.RequestHandler` instance.
238 236 """
239 237 self.request = handler.request
240 238 self.response = handler.response
241 239
242 240
243 241 class WerkzeugAdapter(BaseAdapter):
244 242 """
245 243 Adapter for |flask|_ and other |werkzeug|_ based frameworks.
246 244
247 245 Thanks to `Mark Steve Samson <http://marksteve.com>`_.
248 246
249 247 """
250 248
251 249 @property
252 250 def params(self):
253 251 return self.request.args
254 252
255 253 @property
256 254 def url(self):
257 255 return self.request.base_url
258 256
259 257 @property
260 258 def cookies(self):
261 259 return self.request.cookies
262 260
263 261 def __init__(self, request, response):
264 262 """
265 263 :param request:
266 264 Instance of the :class:`werkzeug.wrappers.Request` class.
267 265
268 266 :param response:
269 267 Instance of the :class:`werkzeug.wrappers.Response` class.
270 268 """
271 269
272 270 self.request = request
273 271 self.response = response
274 272
275 273 def write(self, value):
276 274 self.response.data = self.response.data + value
277 275
278 276 def set_header(self, key, value):
279 277 self.response.headers[key] = value
280 278
281 279 def set_status(self, status):
282 280 self.response.status = status
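Since BaseAdapter above only requires six members, a new framework adapter is mostly boilerplate. A toy sketch backed by plain dicts (the dict layout is purely illustrative, not part of authomatic), showing the full contract a subclass must satisfy:

class DictAdapter(BaseAdapter):
    # request/response are plain dicts here; real adapters wrap framework objects
    def __init__(self, request, response):
        self.request = request    # e.g. {'url': ..., 'params': {...}, 'cookies': {...}}
        self.response = response  # e.g. {'body': '', 'headers': {}, 'status': ''}

    @property
    def params(self):
        return dict(self.request.get('params', {}))

    @property
    def url(self):
        return self.request['url']

    @property
    def cookies(self):
        return dict(self.request.get('cookies', {}))

    def write(self, value):
        self.response['body'] += value

    def set_header(self, key, value):
        self.response['headers'][key] = value

    def set_status(self, status):
        self.response['status'] = status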
@@ -1,305 +1,305 b''
1 1 # -*- coding: utf-8 -*-
2 2
3 3 # Copyright (C) 2017-2020 RhodeCode GmbH
4 4 #
5 5 # This program is free software: you can redistribute it and/or modify
6 6 # it under the terms of the GNU Affero General Public License, version 3
7 7 # (only), as published by the Free Software Foundation.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU Affero General Public License
15 15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 16 #
17 17 # This program is dual-licensed. If you wish to learn more about the
18 18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20 20
21 21 import logging
22 22 import datetime
23 23
24 24 from rhodecode.lib.jsonalchemy import JsonRaw
25 25 from rhodecode.model import meta
26 26 from rhodecode.model.db import User, UserLog, Repository
27 27
28 28
29 29 log = logging.getLogger(__name__)
30 30
31 31 # action as key, and expected action_data as value
32 32 ACTIONS_V1 = {
33 33 'user.login.success': {'user_agent': ''},
34 34 'user.login.failure': {'user_agent': ''},
35 35 'user.logout': {'user_agent': ''},
36 36 'user.register': {},
37 37 'user.password.reset_request': {},
38 38 'user.push': {'user_agent': '', 'commit_ids': []},
39 39 'user.pull': {'user_agent': ''},
40 40
41 41 'user.create': {'data': {}},
42 42 'user.delete': {'old_data': {}},
43 43 'user.edit': {'old_data': {}},
44 44 'user.edit.permissions': {},
45 45 'user.edit.ip.add': {'ip': {}, 'user': {}},
46 46 'user.edit.ip.delete': {'ip': {}, 'user': {}},
47 47 'user.edit.token.add': {'token': {}, 'user': {}},
48 48 'user.edit.token.delete': {'token': {}, 'user': {}},
49 49 'user.edit.email.add': {'email': ''},
50 50 'user.edit.email.delete': {'email': ''},
51 51 'user.edit.ssh_key.add': {'token': {}, 'user': {}},
52 52 'user.edit.ssh_key.delete': {'token': {}, 'user': {}},
53 53 'user.edit.password_reset.enabled': {},
54 54 'user.edit.password_reset.disabled': {},
55 55
56 56 'user_group.create': {'data': {}},
57 57 'user_group.delete': {'old_data': {}},
58 58 'user_group.edit': {'old_data': {}},
59 59 'user_group.edit.permissions': {},
60 60 'user_group.edit.member.add': {'user': {}},
61 61 'user_group.edit.member.delete': {'user': {}},
62 62
63 63 'repo.create': {'data': {}},
64 64 'repo.fork': {'data': {}},
65 65 'repo.edit': {'old_data': {}},
66 66 'repo.edit.permissions': {},
67 67 'repo.edit.permissions.branch': {},
68 68 'repo.archive': {'old_data': {}},
69 69 'repo.delete': {'old_data': {}},
70 70
71 71 'repo.archive.download': {'user_agent': '', 'archive_name': '',
72 72 'archive_spec': '', 'archive_cached': ''},
73 73
74 74 'repo.permissions.branch_rule.create': {},
75 75 'repo.permissions.branch_rule.edit': {},
76 76 'repo.permissions.branch_rule.delete': {},
77 77
78 78 'repo.pull_request.create': '',
79 79 'repo.pull_request.edit': '',
80 80 'repo.pull_request.delete': '',
81 81 'repo.pull_request.close': '',
82 82 'repo.pull_request.merge': '',
83 83 'repo.pull_request.vote': '',
84 84 'repo.pull_request.comment.create': '',
85 85 'repo.pull_request.comment.edit': '',
86 86 'repo.pull_request.comment.delete': '',
87 87
88 88 'repo.pull_request.reviewer.add': '',
89 89 'repo.pull_request.reviewer.delete': '',
90 90
91 91 'repo.pull_request.observer.add': '',
92 92 'repo.pull_request.observer.delete': '',
93 93
94 94 'repo.commit.strip': {'commit_id': ''},
95 95 'repo.commit.comment.create': {'data': {}},
96 96 'repo.commit.comment.delete': {'data': {}},
97 97 'repo.commit.comment.edit': {'data': {}},
98 98 'repo.commit.vote': '',
99 99
100 100 'repo.artifact.add': '',
101 101 'repo.artifact.delete': '',
102 102
103 103 'repo_group.create': {'data': {}},
104 104 'repo_group.edit': {'old_data': {}},
105 105 'repo_group.edit.permissions': {},
106 106 'repo_group.delete': {'old_data': {}},
107 107 }
108 108
109 109 ACTIONS = ACTIONS_V1
110 110
111 111 SOURCE_WEB = 'source_web'
112 112 SOURCE_API = 'source_api'
113 113
114 114
115 115 class UserWrap(object):
116 116 """
117 117 Fake object used to imitate AuthUser
118 118 """
119 119
120 120 def __init__(self, user_id=None, username=None, ip_addr=None):
121 121 self.user_id = user_id
122 122 self.username = username
123 123 self.ip_addr = ip_addr
124 124
125 125
126 126 class RepoWrap(object):
127 127 """
128 128 Fake object used to imitate RepoObject that audit logger requires
129 129 """
130 130
131 131 def __init__(self, repo_id=None, repo_name=None):
132 132 self.repo_id = repo_id
133 133 self.repo_name = repo_name
134 134
135 135
136 136 def _store_log(action_name, action_data, user_id, username, user_data,
137 137 ip_address, repository_id, repository_name):
138 138 user_log = UserLog()
139 139 user_log.version = UserLog.VERSION_2
140 140
141 141 user_log.action = action_name
142 user_log.action_data = action_data or JsonRaw(u'{}')
142 user_log.action_data = action_data or JsonRaw('{}')
143 143
144 144 user_log.user_ip = ip_address
145 145
146 146 user_log.user_id = user_id
147 147 user_log.username = username
148 user_log.user_data = user_data or JsonRaw(u'{}')
148 user_log.user_data = user_data or JsonRaw('{}')
149 149
150 150 user_log.repository_id = repository_id
151 151 user_log.repository_name = repository_name
152 152
153 153 user_log.action_date = datetime.datetime.now()
154 154
155 155 return user_log
156 156
157 157
158 158 def store_web(*args, **kwargs):
159 159 action_data = {}
160 160 org_action_data = kwargs.pop('action_data', {})
161 161 action_data.update(org_action_data)
162 162 action_data['source'] = SOURCE_WEB
163 163 kwargs['action_data'] = action_data
164 164
165 165 return store(*args, **kwargs)
166 166
167 167
168 168 def store_api(*args, **kwargs):
169 169 action_data = {}
170 170 org_action_data = kwargs.pop('action_data', {})
171 171 action_data.update(org_action_data)
172 172 action_data['source'] = SOURCE_API
173 173 kwargs['action_data'] = action_data
174 174
175 175 return store(*args, **kwargs)
176 176
177 177
178 178 def store(action, user, action_data=None, user_data=None, ip_addr=None,
179 179 repo=None, sa_session=None, commit=False):
180 180 """
181 181 Audit logger for various actions made by users. Typically this
182 182 results in a call such as::
183 183
184 184 from rhodecode.lib import audit_logger
185 185
186 186 audit_logger.store(
187 187 'repo.edit', user=self._rhodecode_user)
188 188 audit_logger.store(
189 189 'repo.delete', action_data={'data': repo_data},
190 190 user=audit_logger.UserWrap(username='itried-login', ip_addr='8.8.8.8'))
191 191
192 192 # repo action
193 193 audit_logger.store(
194 194 'repo.delete',
195 195 user=audit_logger.UserWrap(username='itried-login', ip_addr='8.8.8.8'),
196 196 repo=audit_logger.RepoWrap(repo_name='some-repo'))
197 197
198 198 # repo action, when we know and have the repository object already
199 199 audit_logger.store(
200 200 'repo.delete', action_data={'source': audit_logger.SOURCE_WEB, },
201 201 user=self._rhodecode_user,
202 202 repo=repo_object)
203 203
204 204 # alternative wrapper to the above
205 205 audit_logger.store_web(
206 206 'repo.delete', action_data={},
207 207 user=self._rhodecode_user,
208 208 repo=repo_object)
209 209
210 210 # without a user?
211 211 audit_logger.store(
212 212 'user.login.failure',
213 213 user=audit_logger.UserWrap(
214 214 username=self.request.params.get('username'),
215 215 ip_addr=self.request.remote_addr))
216 216
217 217 """
218 218 from rhodecode.lib.utils2 import safe_unicode
219 219 from rhodecode.lib.auth import AuthUser
220 220
221 221 action_spec = ACTIONS.get(action, None)
222 222 if action_spec is None:
223 223 raise ValueError('Action `{}` is not supported'.format(action))
224 224
225 225 if not sa_session:
226 226 sa_session = meta.Session()
227 227
228 228 try:
229 229 username = getattr(user, 'username', None)
230 230 if not username:
231 231 pass
232 232
233 233 user_id = getattr(user, 'user_id', None)
234 234 if not user_id:
235 235 # maybe we have a username? Try to figure out user_id from it
236 236 if username:
237 237 user_id = getattr(
238 238 User.get_by_username(username), 'user_id', None)
239 239
240 240 ip_addr = ip_addr or getattr(user, 'ip_addr', None)
241 241 if not ip_addr:
242 242 pass
243 243
244 244 if not user_data:
245 245 # try to get this from the auth user
246 246 if isinstance(user, AuthUser):
247 247 user_data = {
248 248 'username': user.username,
249 249 'email': user.email,
250 250 }
251 251
252 252 repository_name = getattr(repo, 'repo_name', None)
253 253 repository_id = getattr(repo, 'repo_id', None)
254 254 if not repository_id:
255 255 # maybe we have a repo_name? Try to figure out repo_id from it
256 256 if repository_name:
257 257 repository_id = getattr(
258 258 Repository.get_by_repo_name(repository_name), 'repo_id', None)
259 259
260 260 action_name = safe_unicode(action)
261 261 ip_address = safe_unicode(ip_addr)
262 262
263 263 with sa_session.no_autoflush:
264 264
265 265 user_log = _store_log(
266 266 action_name=action_name,
267 267 action_data=action_data or {},
268 268 user_id=user_id,
269 269 username=username,
270 270 user_data=user_data or {},
271 271 ip_address=ip_address,
272 272 repository_id=repository_id,
273 273 repository_name=repository_name
274 274 )
275 275
276 276 sa_session.add(user_log)
277 277 if commit:
278 278 sa_session.commit()
279 279 entry_id = user_log.entry_id or ''
280 280
281 281 update_user_last_activity(sa_session, user_id)
282 282
283 283 if commit:
284 284 sa_session.commit()
285 285
286 286 log.info('AUDIT[%s]: Logging action: `%s` by user:id:%s[%s] ip:%s',
287 287 entry_id, action_name, user_id, username, ip_address,
288 288 extra={"entry_id": entry_id, "action": action_name,
289 289 "user_id": user_id, "ip": ip_address})
290 290
291 291 except Exception:
292 292 log.exception('AUDIT: failed to store audit log')
293 293
294 294
295 295 def update_user_last_activity(sa_session, user_id):
296 296 _last_activity = datetime.datetime.now()
297 297 try:
298 298 sa_session.query(User).filter(User.user_id == user_id).update(
299 299 {"last_activity": _last_activity})
300 300 log.debug(
301 301 'updated user `%s` last activity to:%s', user_id, _last_activity)
302 302 except Exception:
303 303 log.exception("Failed last activity update for user_id: %s", user_id)
304 304 sa_session.rollback()
305 305
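For completeness, a hedged sketch of reading entries back with the same session machinery the logger uses; the filter fields mirror what _store_log sets, but this is illustrative, not an official query API:

from rhodecode.model import meta
from rhodecode.model.db import UserLog

def recent_audit_entries(username, limit=10):
    # newest-first audit trail for one user, matching the UserLog.username
    # and UserLog.action_date columns written by _store_log() above
    sa_session = meta.Session()
    return (sa_session.query(UserLog)
            .filter(UserLog.username == username)
            .order_by(UserLog.action_date.desc())
            .limit(limit)
            .all())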
@@ -1,371 +1,371 b''
1 1 # -*- coding: utf-8 -*-
2 2
3 3 # Copyright (C) 2016-2020 RhodeCode GmbH
4 4 #
5 5 # This program is free software: you can redistribute it and/or modify
6 6 # it under the terms of the GNU Affero General Public License, version 3
7 7 # (only), as published by the Free Software Foundation.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU Affero General Public License
15 15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 16 #
17 17 # This program is dual-licensed. If you wish to learn more about the
18 18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20 20
21 21 import os
22 22 import hashlib
23 23 import itsdangerous
24 24 import logging
25 25 import requests
26 26 import datetime
27 27
28 28 from dogpile.util.readwrite_lock import ReadWriteMutex
29 29 from pyramid.threadlocal import get_current_registry
30 30
31 31 import rhodecode.lib.helpers as h
32 32 from rhodecode.lib.auth import HasRepoPermissionAny
33 33 from rhodecode.lib.ext_json import json
34 34 from rhodecode.model.db import User
35 35
36 36 log = logging.getLogger(__name__)
37 37
38 38 LOCK = ReadWriteMutex()
39 39
40 40 USER_STATE_PUBLIC_KEYS = [
41 41 'id', 'username', 'first_name', 'last_name',
42 42 'icon_link', 'display_name', 'display_link']
43 43
44 44
45 45 class ChannelstreamException(Exception):
46 46 pass
47 47
48 48
49 49 class ChannelstreamConnectionException(ChannelstreamException):
50 50 pass
51 51
52 52
53 53 class ChannelstreamPermissionException(ChannelstreamException):
54 54 pass
55 55
56 56
57 57 def get_channelstream_server_url(config, endpoint):
58 58 return 'http://{}{}'.format(config['server'], endpoint)
59 59
60 60
61 61 def channelstream_request(config, payload, endpoint, raise_exc=True):
62 62 signer = itsdangerous.TimestampSigner(config['secret'])
63 63 sig_for_server = signer.sign(endpoint)
64 64 secret_headers = {'x-channelstream-secret': sig_for_server,
65 65 'x-channelstream-endpoint': endpoint,
66 66 'Content-Type': 'application/json'}
67 67 req_url = get_channelstream_server_url(config, endpoint)
68 68
69 69 log.debug('Sending a channelstream request to endpoint: `%s`', req_url)
70 70 response = None
71 71 try:
72 72 response = requests.post(req_url, data=json.dumps(payload),
73 73 headers=secret_headers).json()
74 74 except requests.ConnectionError:
75 75 log.exception('ConnectionError occurred for endpoint %s', req_url)
76 76 if raise_exc:
77 77 raise ChannelstreamConnectionException(req_url)
78 78 except Exception:
79 79 log.exception('Exception related to Channelstream happened')
80 80 if raise_exc:
81 81 raise ChannelstreamConnectionException()
82 82 log.debug('Got channelstream response: %s', response)
83 83 return response
84 84
85 85
86 86 def get_user_data(user_id):
87 87 user = User.get(user_id)
88 88 return {
89 89 'id': user.user_id,
90 90 'username': user.username,
91 91 'first_name': user.first_name,
92 92 'last_name': user.last_name,
93 93 'icon_link': h.gravatar_url(user.email, 60),
94 94 'display_name': h.person(user, 'username_or_name_or_email'),
95 95 'display_link': h.link_to_user(user),
96 96 'notifications': user.user_data.get('notification_status', True)
97 97 }
98 98
99 99
100 100 def broadcast_validator(channel_name):
101 101 """ checks if user can access the broadcast channel """
102 102 if channel_name == 'broadcast':
103 103 return True
104 104
105 105
106 106 def repo_validator(channel_name):
107 107 """ checks if user can access the broadcast channel """
108 108 channel_prefix = '/repo$'
109 109 if channel_name.startswith(channel_prefix):
110 110 elements = channel_name[len(channel_prefix):].split('$')
111 111 repo_name = elements[0]
112 112 can_access = HasRepoPermissionAny(
113 113 'repository.read',
114 114 'repository.write',
115 115 'repository.admin')(repo_name)
116 116 log.debug(
117 117 'permission check for %s channel resulted in %s',
118 118 repo_name, can_access)
119 119 if can_access:
120 120 return True
121 121 return False
122 122
123 123
124 124 def check_channel_permissions(channels, plugin_validators, should_raise=True):
125 125 valid_channels = []
126 126
127 127 validators = [broadcast_validator, repo_validator]
128 128 if plugin_validators:
129 129 validators.extend(plugin_validators)
130 130 for channel_name in channels:
131 131 is_valid = False
132 132 for validator in validators:
133 133 if validator(channel_name):
134 134 is_valid = True
135 135 break
136 136 if is_valid:
137 137 valid_channels.append(channel_name)
138 138 else:
139 139 if should_raise:
140 140 raise ChannelstreamPermissionException()
141 141 return valid_channels
142 142
143 143
144 144 def get_channels_info(config, channels):
145 145 payload = {'channels': channels}
146 146 # gather persistence info
147 147 return channelstream_request(config, payload, '/info')
148 148
149 149
150 150 def parse_channels_info(info_result, include_channel_info=None):
151 151 """
152 152 Returns data that contains only information that is safe to
153 153 present to clients
154 154 """
155 155 include_channel_info = include_channel_info or []
156 156
157 157 user_state_dict = {}
158 158 for userinfo in info_result['users']:
159 159 user_state_dict[userinfo['user']] = {
160 160 k: v for k, v in userinfo['state'].items()
161 161 if k in USER_STATE_PUBLIC_KEYS
162 162 }
163 163
164 164 channels_info = {}
165 165
166 166 for c_name, c_info in info_result['channels'].items():
167 167 if c_name not in include_channel_info:
168 168 continue
169 169 connected_list = []
170 170 for username in c_info['users']:
171 171 connected_list.append({
172 172 'user': username,
173 173 'state': user_state_dict[username]
174 174 })
175 175 channels_info[c_name] = {'users': connected_list,
176 176 'history': c_info['history']}
177 177
178 178 return channels_info
179 179
180 180
181 181 def log_filepath(history_location, channel_name):
182 182 hasher = hashlib.sha256()
183 183 hasher.update(channel_name.encode('utf8'))
184 184 filename = '{}.log'.format(hasher.hexdigest())
185 185 filepath = os.path.join(history_location, filename)
186 186 return filepath
187 187
188 188
189 189 def read_history(history_location, channel_name):
190 190 filepath = log_filepath(history_location, channel_name)
191 191 if not os.path.exists(filepath):
192 192 return []
193 193 history_lines_limit = -100
194 194 history = []
195 195 with open(filepath, 'rb') as f:
196 196 for line in f.readlines()[history_lines_limit:]:
197 197 try:
198 198 history.append(json.loads(line))
199 199 except Exception:
200 200 log.exception('Failed to load history')
201 201 return history
202 202
203 203
204 204 def update_history_from_logs(config, channels, payload):
205 205 history_location = config.get('history.location')
206 206 for channel in channels:
207 207 history = read_history(history_location, channel)
208 208 payload['channels_info'][channel]['history'] = history
209 209
210 210
211 211 def write_history(config, message):
212 212 """ writes a message to a base64encoded filename """
213 213 history_location = config.get('history.location')
214 214 if not os.path.exists(history_location):
215 215 return
216 216 try:
217 217 LOCK.acquire_write_lock()
218 218 filepath = log_filepath(history_location, message['channel'])
219 219 json_message = json.dumps(message)
220 220 with open(filepath, 'ab') as f:
221 221 f.write(json_message.encode('utf8'))
222 222 f.write(b'\n')
223 223 finally:
224 224 LOCK.release_write_lock()
225 225
226 226
227 227 def get_connection_validators(registry):
228 228 validators = []
229 229 for k, config in registry.rhodecode_plugins.items():
230 230 validator = config.get('channelstream', {}).get('connect_validator')
231 231 if validator:
232 232 validators.append(validator)
233 233 return validators
234 234
235 235
236 236 def get_channelstream_config(registry=None):
237 237 if not registry:
238 238 registry = get_current_registry()
239 239
240 240 rhodecode_plugins = getattr(registry, 'rhodecode_plugins', {})
241 241 channelstream_config = rhodecode_plugins.get('channelstream', {})
242 242 return channelstream_config
243 243
244 244
245 245 def post_message(channel, message, username, registry=None):
246 246 channelstream_config = get_channelstream_config(registry)
247 247 if not channelstream_config.get('enabled'):
248 248 return
249 249
250 250 message_obj = message
251 251 if isinstance(message, str):
252 252 message_obj = {
253 253 'message': message,
254 254 'level': 'success',
255 255 'topic': '/notifications'
256 256 }
257 257
258 258 log.debug('Channelstream: sending notification to channel %s', channel)
259 259 payload = {
260 260 'type': 'message',
261 261 'timestamp': datetime.datetime.utcnow(),
262 262 'user': 'system',
263 263 'exclude_users': [username],
264 264 'channel': channel,
265 265 'message': message_obj
266 266 }
267 267
268 268 try:
269 269 return channelstream_request(
270 270 channelstream_config, [payload], '/message',
271 271 raise_exc=False)
272 272 except ChannelstreamException:
273 273 log.exception('Failed to send channelstream data')
274 274 raise
275 275
276 276
277 277 def _reload_link(label):
278 278 return (
279 279 '<a onclick="window.location.reload()">'
280 280 '<strong>{}</strong>'
281 281 '</a>'.format(label)
282 282 )
283 283
284 284
285 285 def pr_channel(pull_request):
286 286 repo_name = pull_request.target_repo.repo_name
287 287 pull_request_id = pull_request.pull_request_id
288 288 channel = '/repo${}$/pr/{}'.format(repo_name, pull_request_id)
289 289 log.debug('Getting pull-request channelstream broadcast channel: %s', channel)
290 290 return channel
291 291
292 292
293 293 def comment_channel(repo_name, commit_obj=None, pull_request_obj=None):
294 294 channel = None
295 295 if commit_obj:
296 channel = u'/repo${}$/commit/{}'.format(
296 channel = '/repo${}$/commit/{}'.format(
297 297 repo_name, commit_obj.raw_id
298 298 )
299 299 elif pull_request_obj:
300 channel = u'/repo${}$/pr/{}'.format(
300 channel = '/repo${}$/pr/{}'.format(
301 301 repo_name, pull_request_obj.pull_request_id
302 302 )
303 303 log.debug('Getting comment channelstream broadcast channel: %s', channel)
304 304
305 305 return channel
306 306
307 307
308 308 def pr_update_channelstream_push(request, pr_broadcast_channel, user, msg, **kwargs):
309 309 """
310 310 Channel push on pull request update
311 311 """
312 312 if not pr_broadcast_channel:
313 313 return
314 314
315 315 _ = request.translate
316 316
317 317 message = '{} {}'.format(
318 318 msg,
319 319 _reload_link(_(' Reload page to load changes')))
320 320
321 321 message_obj = {
322 322 'message': message,
323 323 'level': 'success',
324 324 'topic': '/notifications'
325 325 }
326 326
327 327 post_message(
328 328 pr_broadcast_channel, message_obj, user.username,
329 329 registry=request.registry)
330 330
331 331
332 332 def comment_channelstream_push(request, comment_broadcast_channel, user, msg, **kwargs):
333 333 """
334 334 Channelstream push on comment action, on commit, or pull-request
335 335 """
336 336 if not comment_broadcast_channel:
337 337 return
338 338
339 339 _ = request.translate
340 340
341 341 comment_data = kwargs.pop('comment_data', {})
342 342 user_data = kwargs.pop('user_data', {})
343 343 comment_id = list(comment_data.keys())[0] if comment_data else ''
344 344
345 345 message = '<strong>{}</strong> {} #{}'.format(
346 346 user.username,
347 347 msg,
348 348 comment_id,
349 349 )
350 350
351 351 message_obj = {
352 352 'message': message,
353 353 'level': 'success',
354 354 'topic': '/notifications'
355 355 }
356 356
357 357 post_message(
358 358 comment_broadcast_channel, message_obj, user.username,
359 359 registry=request.registry)
360 360
361 361 message_obj = {
362 362 'message': None,
363 363 'user': user.username,
364 364 'comment_id': comment_id,
365 365 'comment_data': comment_data,
366 366 'user_data': user_data,
367 367 'topic': '/comment'
368 368 }
369 369 post_message(
370 370 comment_broadcast_channel, message_obj, user.username,
371 371 registry=request.registry)
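The `/repo$<repo_name>$<suffix>` channel naming used by pr_channel and comment_channel is exactly what repo_validator splits apart. A small standalone sketch of that parsing step (permission checking is omitted, since HasRepoPermissionAny needs a live request context):

def parse_repo_channel(channel_name):
    # mirrors the prefix stripping in repo_validator(); returns the repo name
    # a channel refers to, or None for non-repo channels such as 'broadcast'
    prefix = '/repo$'
    if not channel_name.startswith(prefix):
        return None
    return channel_name[len(prefix):].split('$')[0]

assert parse_repo_channel('/repo$some-repo$/pr/42') == 'some-repo'
assert parse_repo_channel('broadcast') is None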
@@ -1,797 +1,797 b''
1 1 # -*- coding: utf-8 -*-
2 2
3 3 # Copyright (C) 2011-2020 RhodeCode GmbH
4 4 #
5 5 # This program is free software: you can redistribute it and/or modify
6 6 # it under the terms of the GNU Affero General Public License, version 3
7 7 # (only), as published by the Free Software Foundation.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU Affero General Public License
15 15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 16 #
17 17 # This program is dual-licensed. If you wish to learn more about the
18 18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20 20
21 21 import logging
22 22 import difflib
23 23 from itertools import groupby
24 24
25 25 from pygments import lex
26 26 from pygments.formatters.html import _get_ttype_class as pygment_token_class
27 27 from pygments.lexers.special import TextLexer, Token
28 28 from pygments.lexers import get_lexer_by_name
29 29
30 30 from rhodecode.lib.helpers import (
31 31 get_lexer_for_filenode, html_escape, get_custom_lexer)
32 32 from rhodecode.lib.utils2 import AttributeDict, StrictAttributeDict, safe_unicode
33 33 from rhodecode.lib.vcs.nodes import FileNode
34 34 from rhodecode.lib.vcs.exceptions import VCSError, NodeDoesNotExistError
35 35 from rhodecode.lib.diff_match_patch import diff_match_patch
36 36 from rhodecode.lib.diffs import LimitedDiffContainer, DEL_FILENODE, BIN_FILENODE
37 37
38 38
39 39 plain_text_lexer = get_lexer_by_name(
40 40 'text', stripall=False, stripnl=False, ensurenl=False)
41 41
42 42
43 43 log = logging.getLogger(__name__)
44 44
45 45
46 46 def filenode_as_lines_tokens(filenode, lexer=None):
47 47 org_lexer = lexer
48 48 lexer = lexer or get_lexer_for_filenode(filenode)
49 49 log.debug('Generating file node pygment tokens for %s, %s, org_lexer:%s',
50 50 lexer, filenode, org_lexer)
51 51 content = filenode.content
52 52 tokens = tokenize_string(content, lexer)
53 53 lines = split_token_stream(tokens, content)
54 54 rv = list(lines)
55 55 return rv
56 56
57 57
58 58 def tokenize_string(content, lexer):
59 59 """
60 60 Use pygments to tokenize some content based on a lexer
61 61 ensuring all original newlines and whitespace are preserved
62 62 """
63 63
64 64 lexer.stripall = False
65 65 lexer.stripnl = False
66 66 lexer.ensurenl = False
67 67
68 68 if isinstance(lexer, TextLexer):
69 69 lexed = [(Token.Text, content)]
70 70 else:
71 71 lexed = lex(content, lexer)
72 72
73 73 for token_type, token_text in lexed:
74 74 yield pygment_token_class(token_type), token_text
75 75
76 76
77 77 def split_token_stream(tokens, content):
78 78 """
79 79 Take a list of (TokenType, text) tuples and split them on newlines,
80 80 yielding one list of tokens per line:
81 81 split_token_stream([(TEXT, 'some\ntext'), (TEXT, 'more\n')])
82 82 [[(TEXT, 'some')],
83 83 [(TEXT, 'text'), (TEXT, 'more')], [(TEXT, '')]]
84 84 """
85 85
86 86 token_buffer = []
87 87 for token_class, token_text in tokens:
88 88 parts = token_text.split('\n')
89 89 for part in parts[:-1]:
90 90 token_buffer.append((token_class, part))
91 91 yield token_buffer
92 92 token_buffer = []
93 93
94 94 token_buffer.append((token_class, parts[-1]))
95 95
96 96 if token_buffer:
97 97 yield token_buffer
98 98 elif content:
99 99 # this is a special case, we have the content, but tokenization didn't produce
100 100 # any results. This can happen if known file extensions like .css have some bogus
101 101 # unicode content without any newline characters
102 102 yield [(pygment_token_class(Token.Text), content)]
103 103
104 104
105 105 def filenode_as_annotated_lines_tokens(filenode):
106 106 """
107 107 Take a file node and return a list of annotation => lines pairs; if no
108 108 annotation is found, it will be None.
109 109
110 110 eg:
111 111
112 112 [
113 113 (annotation1, [
114 114 (1, line1_tokens_list),
115 115 (2, line2_tokens_list),
116 116 ]),
117 117 (annotation2, [
118 118 (3, line1_tokens_list),
119 119 ]),
120 120 (None, [
121 121 (4, line1_tokens_list),
122 122 ]),
123 123 (annotation1, [
124 124 (5, line1_tokens_list),
125 125 (6, line2_tokens_list),
126 126 ])
127 127 ]
128 128 """
129 129
130 130 commit_cache = {} # cache commit_getter lookups
131 131
132 132 def _get_annotation(commit_id, commit_getter):
133 133 if commit_id not in commit_cache:
134 134 commit_cache[commit_id] = commit_getter()
135 135 return commit_cache[commit_id]
136 136
137 137 annotation_lookup = {
138 138 line_no: _get_annotation(commit_id, commit_getter)
139 139 for line_no, commit_id, commit_getter, line_content
140 140 in filenode.annotate
141 141 }
142 142
143 143 annotations_lines = ((annotation_lookup.get(line_no), line_no, tokens)
144 144 for line_no, tokens
145 145 in enumerate(filenode_as_lines_tokens(filenode), 1))
146 146
147 147 grouped_annotations_lines = groupby(annotations_lines, lambda x: x[0])
148 148
149 149 for annotation, group in grouped_annotations_lines:
150 150 yield (
151 151 annotation, [(line_no, tokens)
152 152 for (_, line_no, tokens) in group]
153 153 )
154 154
155 155
156 156 def render_tokenstream(tokenstream):
157 157 result = []
158 158 for token_class, token_ops_texts in rollup_tokenstream(tokenstream):
159 159
160 160 if token_class:
161 result.append(u'<span class="%s">' % token_class)
161 result.append('<span class="%s">' % token_class)
162 162 else:
163 result.append(u'<span>')
163 result.append('<span>')
164 164
165 165 for op_tag, token_text in token_ops_texts:
166 166
167 167 if op_tag:
168 result.append(u'<%s>' % op_tag)
168 result.append('<%s>' % op_tag)
169 169
170 170 # NOTE(marcink): in some cases of mixed encodings, we might run into
171 171 # troubles in the html_escape; in this case we force token_text to unicode
172 172 # to ensure "correct" data, at the cost of how it renders
173 173 try:
174 174 escaped_text = html_escape(token_text)
175 175 except TypeError:
176 176 escaped_text = html_escape(safe_unicode(token_text))
177 177
178 178 # TODO: dan: investigate showing hidden characters like space/nl/tab
179 179 # escaped_text = escaped_text.replace(' ', '<sp> </sp>')
180 180 # escaped_text = escaped_text.replace('\n', '<nl>\n</nl>')
181 181 # escaped_text = escaped_text.replace('\t', '<tab>\t</tab>')
182 182
183 183 result.append(escaped_text)
184 184
185 185 if op_tag:
186 result.append(u'</%s>' % op_tag)
186 result.append('</%s>' % op_tag)
187 187
188 result.append(u'</span>')
188 result.append('</span>')
189 189
190 190 html = ''.join(result)
191 191 return html
192 192
193 193
194 194 def rollup_tokenstream(tokenstream):
195 195 """
196 196 Group a token stream of the format:
197 197
198 198 ('class', 'op', 'text')
199 199 or
200 200 ('class', 'text')
201 201
202 202 into
203 203
204 204 [('class1',
205 205 [('op1', 'text'),
206 206 ('op2', 'text')]),
207 207 ('class2',
208 208 [('op3', 'text')])]
209 209
210 210 This is used to emit the minimal set of tags necessary when
211 211 rendering to html, e.g. for a token stream:
212 212
213 213 <span class="A"><ins>he</ins>llo</span>
214 214 vs
215 215 <span class="A"><ins>he</ins></span><span class="A">llo</span>
216 216
217 217 If a 2 tuple is passed in, the output op will be an empty string.
218 218
219 219 eg:
220 220
221 221 >>> rollup_tokenstream([('classA', '', 'h'),
222 222 ('classA', 'del', 'ell'),
223 223 ('classA', '', 'o'),
224 224 ('classB', '', ' '),
225 225 ('classA', '', 'the'),
226 226 ('classA', '', 're'),
227 227 ])
228 228
229 229 [('classA', [('', 'h'), ('del', 'ell'), ('', 'o')]),
230 230 ('classB', [('', ' ')]),
231 231 ('classA', [('', 'there')])]
232 232
233 233 """
234 234 if tokenstream and len(tokenstream[0]) == 2:
235 235 tokenstream = ((t[0], '', t[1]) for t in tokenstream)
236 236
237 237 result = []
238 238 for token_class, op_list in groupby(tokenstream, lambda t: t[0]):
239 239 ops = []
240 240 for token_op, token_text_list in groupby(op_list, lambda o: o[1]):
241 241 text_buffer = []
242 242 for t_class, t_op, t_text in token_text_list:
243 243 text_buffer.append(t_text)
244 244 ops.append((token_op, ''.join(text_buffer)))
245 245 result.append((token_class, ops))
246 246 return result
247 247
248 248
249 249 def tokens_diff(old_tokens, new_tokens, use_diff_match_patch=True):
250 250 """
251 251 Converts a list of (token_class, token_text) tuples to a list of
252 252 (token_class, token_op, token_text) tuples where token_op is one of
253 253 ('ins', 'del', '')
254 254
255 255 :param old_tokens: list of (token_class, token_text) tuples of old line
256 256 :param new_tokens: list of (token_class, token_text) tuples of new line
257 257 :param use_diff_match_patch: boolean, will use google's diff match patch
258 258 library which has options to 'smooth' out the character by character
259 259 differences making nicer ins/del blocks
260 260 """
261 261
262 262 old_tokens_result = []
263 263 new_tokens_result = []
264 264
265 265 similarity = difflib.SequenceMatcher(None,
266 266 ''.join(token_text for token_class, token_text in old_tokens),
267 267 ''.join(token_text for token_class, token_text in new_tokens)
268 268 ).ratio()
269 269
270 270 if similarity < 0.6: # return, the blocks are too different
271 271 for token_class, token_text in old_tokens:
272 272 old_tokens_result.append((token_class, '', token_text))
273 273 for token_class, token_text in new_tokens:
274 274 new_tokens_result.append((token_class, '', token_text))
275 275 return old_tokens_result, new_tokens_result, similarity
276 276
277 277 token_sequence_matcher = difflib.SequenceMatcher(None,
278 278 [x[1] for x in old_tokens],
279 279 [x[1] for x in new_tokens])
280 280
281 281 for tag, o1, o2, n1, n2 in token_sequence_matcher.get_opcodes():
282 282 # check the differences by token block types first to give a more
283 283 # nicer "block" level replacement vs character diffs
284 284
285 285 if tag == 'equal':
286 286 for token_class, token_text in old_tokens[o1:o2]:
287 287 old_tokens_result.append((token_class, '', token_text))
288 288 for token_class, token_text in new_tokens[n1:n2]:
289 289 new_tokens_result.append((token_class, '', token_text))
290 290 elif tag == 'delete':
291 291 for token_class, token_text in old_tokens[o1:o2]:
292 292 old_tokens_result.append((token_class, 'del', token_text))
293 293 elif tag == 'insert':
294 294 for token_class, token_text in new_tokens[n1:n2]:
295 295 new_tokens_result.append((token_class, 'ins', token_text))
296 296 elif tag == 'replace':
297 297 # if same type token blocks must be replaced, do a diff on the
298 298 # characters in the token blocks to show individual changes
299 299
300 300 old_char_tokens = []
301 301 new_char_tokens = []
302 302 for token_class, token_text in old_tokens[o1:o2]:
303 303 for char in token_text:
304 304 old_char_tokens.append((token_class, char))
305 305
306 306 for token_class, token_text in new_tokens[n1:n2]:
307 307 for char in token_text:
308 308 new_char_tokens.append((token_class, char))
309 309
310 310 old_string = ''.join([token_text for
311 311 token_class, token_text in old_char_tokens])
312 312 new_string = ''.join([token_text for
313 313 token_class, token_text in new_char_tokens])
314 314
315 315 char_sequence = difflib.SequenceMatcher(
316 316 None, old_string, new_string)
317 317 copcodes = char_sequence.get_opcodes()
318 318 obuffer, nbuffer = [], []
319 319
320 320 if use_diff_match_patch:
321 321 dmp = diff_match_patch()
322 322 dmp.Diff_EditCost = 11 # TODO: dan: extract this to a setting
323 323 reps = dmp.diff_main(old_string, new_string)
324 324 dmp.diff_cleanupEfficiency(reps)
325 325
326 326 a, b = 0, 0
327 327 for op, rep in reps:
328 328 l = len(rep)
329 329 if op == 0:
330 330 for i, c in enumerate(rep):
331 331 obuffer.append((old_char_tokens[a+i][0], '', c))
332 332 nbuffer.append((new_char_tokens[b+i][0], '', c))
333 333 a += l
334 334 b += l
335 335 elif op == -1:
336 336 for i, c in enumerate(rep):
337 337 obuffer.append((old_char_tokens[a+i][0], 'del', c))
338 338 a += l
339 339 elif op == 1:
340 340 for i, c in enumerate(rep):
341 341 nbuffer.append((new_char_tokens[b+i][0], 'ins', c))
342 342 b += l
343 343 else:
344 344 for ctag, co1, co2, cn1, cn2 in copcodes:
345 345 if ctag == 'equal':
346 346 for token_class, token_text in old_char_tokens[co1:co2]:
347 347 obuffer.append((token_class, '', token_text))
348 348 for token_class, token_text in new_char_tokens[cn1:cn2]:
349 349 nbuffer.append((token_class, '', token_text))
350 350 elif ctag == 'delete':
351 351 for token_class, token_text in old_char_tokens[co1:co2]:
352 352 obuffer.append((token_class, 'del', token_text))
353 353 elif ctag == 'insert':
354 354 for token_class, token_text in new_char_tokens[cn1:cn2]:
355 355 nbuffer.append((token_class, 'ins', token_text))
356 356 elif ctag == 'replace':
357 357 for token_class, token_text in old_char_tokens[co1:co2]:
358 358 obuffer.append((token_class, 'del', token_text))
359 359 for token_class, token_text in new_char_tokens[cn1:cn2]:
360 360 nbuffer.append((token_class, 'ins', token_text))
361 361
362 362 old_tokens_result.extend(obuffer)
363 363 new_tokens_result.extend(nbuffer)
364 364
365 365 return old_tokens_result, new_tokens_result, similarity
366 366
367 367
368 368 def diffset_node_getter(commit):
369 369 def get_node(fname):
370 370 try:
371 371 return commit.get_node(fname)
372 372 except NodeDoesNotExistError:
373 373 return None
374 374
375 375 return get_node
376 376
377 377
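Before DiffSet wires everything together below, the helpers above can be exercised in isolation. A sketch assuming it runs inside this module (so tokenize_string, tokens_diff and render_tokenstream are in scope), using the plain-text lexer to avoid needing a file node:

from pygments.lexers import get_lexer_by_name

lexer = get_lexer_by_name('text', stripall=False, stripnl=False, ensurenl=False)

old_line = list(tokenize_string('hello there', lexer))
new_line = list(tokenize_string('hallo there', lexer))

# char-level diff with '', 'del' or 'ins' op markers plus a 0..1 similarity ratio
old_ops, new_ops, similarity = tokens_diff(old_line, new_line)

print(similarity)                   # high, the two lines differ by one character
print(render_tokenstream(new_ops))  # e.g. '<span>h<ins>a</ins>llo there</span>'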
378 378 class DiffSet(object):
379 379 """
380 380 An object for parsing the diff result from diffs.DiffProcessor and
381 381 adding highlighting, side by side/unified renderings and line diffs
382 382 """
383 383
384 384 HL_REAL = 'REAL' # highlights using original file, slow
385 385 HL_FAST = 'FAST' # highlights using just the line, fast but not correct
386 386 # in the case of multiline code
387 387 HL_NONE = 'NONE' # no highlighting, fastest
388 388
389 389 def __init__(self, highlight_mode=HL_REAL, repo_name=None,
390 390 source_repo_name=None,
391 391 source_node_getter=lambda filename: None,
392 392 target_repo_name=None,
393 393 target_node_getter=lambda filename: None,
394 394 source_nodes=None, target_nodes=None,
395 395 # files over this size will use fast highlighting
396 396 max_file_size_limit=150 * 1024,
397 397 ):
398 398
399 399 self.highlight_mode = highlight_mode
400 400 self.highlighted_filenodes = {
401 401 'before': {},
402 402 'after': {}
403 403 }
404 404 self.source_node_getter = source_node_getter
405 405 self.target_node_getter = target_node_getter
406 406 self.source_nodes = source_nodes or {}
407 407 self.target_nodes = target_nodes or {}
408 408 self.repo_name = repo_name
409 409 self.target_repo_name = target_repo_name or repo_name
410 410 self.source_repo_name = source_repo_name or repo_name
411 411 self.max_file_size_limit = max_file_size_limit
412 412
413 413 def render_patchset(self, patchset, source_ref=None, target_ref=None):
414 414 diffset = AttributeDict(dict(
415 415 lines_added=0,
416 416 lines_deleted=0,
417 417 changed_files=0,
418 418 files=[],
419 419 file_stats={},
420 420 limited_diff=isinstance(patchset, LimitedDiffContainer),
421 421 repo_name=self.repo_name,
422 422 target_repo_name=self.target_repo_name,
423 423 source_repo_name=self.source_repo_name,
424 424 source_ref=source_ref,
425 425 target_ref=target_ref,
426 426 ))
427 427 for patch in patchset:
428 428 diffset.file_stats[patch['filename']] = patch['stats']
429 429 filediff = self.render_patch(patch)
430 430 filediff.diffset = StrictAttributeDict(dict(
431 431 source_ref=diffset.source_ref,
432 432 target_ref=diffset.target_ref,
433 433 repo_name=diffset.repo_name,
434 434 source_repo_name=diffset.source_repo_name,
435 435 target_repo_name=diffset.target_repo_name,
436 436 ))
437 437 diffset.files.append(filediff)
438 438 diffset.changed_files += 1
439 439 if not patch['stats']['binary']:
440 440 diffset.lines_added += patch['stats']['added']
441 441 diffset.lines_deleted += patch['stats']['deleted']
442 442
443 443 return diffset
444 444
445 445 _lexer_cache = {}
446 446
447 447 def _get_lexer_for_filename(self, filename, filenode=None):
448 448 # cached because we might need to call it twice for source/target
449 449 if filename not in self._lexer_cache:
450 450 if filenode:
451 451 lexer = filenode.lexer
452 452 extension = filenode.extension
453 453 else:
454 454 lexer = FileNode.get_lexer(filename=filename)
455 455 extension = filename.split('.')[-1]
456 456
457 457 lexer = get_custom_lexer(extension) or lexer
458 458 self._lexer_cache[filename] = lexer
459 459 return self._lexer_cache[filename]
460 460
461 461 def render_patch(self, patch):
462 462 log.debug('rendering diff for %r', patch['filename'])
463 463
464 464 source_filename = patch['original_filename']
465 465 target_filename = patch['filename']
466 466
467 467 source_lexer = plain_text_lexer
468 468 target_lexer = plain_text_lexer
469 469
470 470 if not patch['stats']['binary']:
471 471 node_hl_mode = self.HL_NONE if patch['chunks'] == [] else None
472 472 hl_mode = node_hl_mode or self.highlight_mode
473 473
474 474 if hl_mode == self.HL_REAL:
475 475 if (source_filename and patch['operation'] in ('D', 'M')
476 476 and source_filename not in self.source_nodes):
477 477 self.source_nodes[source_filename] = (
478 478 self.source_node_getter(source_filename))
479 479
480 480 if (target_filename and patch['operation'] in ('A', 'M')
481 481 and target_filename not in self.target_nodes):
482 482 self.target_nodes[target_filename] = (
483 483 self.target_node_getter(target_filename))
484 484
485 485 elif hl_mode == self.HL_FAST:
486 486 source_lexer = self._get_lexer_for_filename(source_filename)
487 487 target_lexer = self._get_lexer_for_filename(target_filename)
488 488
489 489 source_file = self.source_nodes.get(source_filename, source_filename)
490 490 target_file = self.target_nodes.get(target_filename, target_filename)
491 491 raw_id_uid = ''
492 492 if self.source_nodes.get(source_filename):
493 493 raw_id_uid = self.source_nodes[source_filename].commit.raw_id
494 494
495 495 if not raw_id_uid and self.target_nodes.get(target_filename):
496 496 # in case this is a new file we only have it in target
497 497 raw_id_uid = self.target_nodes[target_filename].commit.raw_id
498 498
499 499 source_filenode, target_filenode = None, None
500 500
501 501 # TODO: dan: FileNode.lexer works on the content of the file - which
502 502 # can be slow - issue #4289 explains a lexer clean up - which once
503 503 # done can allow caching a lexer for a filenode to avoid the file lookup
504 504 if isinstance(source_file, FileNode):
505 505 source_filenode = source_file
506 506 #source_lexer = source_file.lexer
507 507 source_lexer = self._get_lexer_for_filename(source_filename)
508 508 source_file.lexer = source_lexer
509 509
510 510 if isinstance(target_file, FileNode):
511 511 target_filenode = target_file
512 512 #target_lexer = target_file.lexer
513 513 target_lexer = self._get_lexer_for_filename(target_filename)
514 514 target_file.lexer = target_lexer
515 515
516 516 source_file_path, target_file_path = None, None
517 517
518 518 if source_filename != '/dev/null':
519 519 source_file_path = source_filename
520 520 if target_filename != '/dev/null':
521 521 target_file_path = target_filename
522 522
523 523 source_file_type = source_lexer.name
524 524 target_file_type = target_lexer.name
525 525
526 526 filediff = AttributeDict({
527 527 'source_file_path': source_file_path,
528 528 'target_file_path': target_file_path,
529 529 'source_filenode': source_filenode,
530 530 'target_filenode': target_filenode,
531 531             'source_file_type': source_file_type,
532 532             'target_file_type': target_file_type,
533 533 'patch': {'filename': patch['filename'], 'stats': patch['stats']},
534 534 'operation': patch['operation'],
535 535 'source_mode': patch['stats']['old_mode'],
536 536 'target_mode': patch['stats']['new_mode'],
537 537 'limited_diff': patch['is_limited_diff'],
538 538 'hunks': [],
539 539 'hunk_ops': None,
540 540 'diffset': self,
541 541 'raw_id': raw_id_uid,
542 542 })
543 543
544 544 file_chunks = patch['chunks'][1:]
545 545 for i, hunk in enumerate(file_chunks, 1):
546 546 hunkbit = self.parse_hunk(hunk, source_file, target_file)
547 547 hunkbit.source_file_path = source_file_path
548 548 hunkbit.target_file_path = target_file_path
549 549 hunkbit.index = i
550 550 filediff.hunks.append(hunkbit)
551 551
552 552         # Simulate a hunk for OPS-type lines, which don't really contain any diff;
553 553         # this allows commenting on those lines
554 554 if not file_chunks:
555 555 actions = []
556 556 for op_id, op_text in filediff.patch['stats']['ops'].items():
557 557 if op_id == DEL_FILENODE:
558 actions.append(u'file was removed')
558 actions.append('file was removed')
559 559 elif op_id == BIN_FILENODE:
560 actions.append(u'binary diff hidden')
560 actions.append('binary diff hidden')
561 561 else:
562 562 actions.append(safe_unicode(op_text))
563 action_line = u'NO CONTENT: ' + \
564 u', '.join(actions) or u'UNDEFINED_ACTION'
563         action_line = 'NO CONTENT: ' + \
564             (', '.join(actions) or 'UNDEFINED_ACTION')
565 565
566 566 hunk_ops = {'source_length': 0, 'source_start': 0,
567 567 'lines': [
568 568 {'new_lineno': 0, 'old_lineno': 1,
569 569 'action': 'unmod-no-hl', 'line': action_line}
570 570 ],
571 'section_header': u'', 'target_start': 1, 'target_length': 1}
571 'section_header': '', 'target_start': 1, 'target_length': 1}
572 572
573 573 hunkbit = self.parse_hunk(hunk_ops, source_file, target_file)
574 574 hunkbit.source_file_path = source_file_path
575 575 hunkbit.target_file_path = target_file_path
576 576 filediff.hunk_ops = hunkbit
577 577 return filediff
578 578
579 579 def parse_hunk(self, hunk, source_file, target_file):
580 580 result = AttributeDict(dict(
581 581 source_start=hunk['source_start'],
582 582 source_length=hunk['source_length'],
583 583 target_start=hunk['target_start'],
584 584 target_length=hunk['target_length'],
585 585 section_header=hunk['section_header'],
586 586 lines=[],
587 587 ))
588 588 before, after = [], []
589 589
590 590 for line in hunk['lines']:
591 591 if line['action'] in ['unmod', 'unmod-no-hl']:
592 592 no_hl = line['action'] == 'unmod-no-hl'
593 593 result.lines.extend(
594 594 self.parse_lines(before, after, source_file, target_file, no_hl=no_hl))
595 595 after.append(line)
596 596 before.append(line)
597 597 elif line['action'] == 'add':
598 598 after.append(line)
599 599 elif line['action'] == 'del':
600 600 before.append(line)
601 601 elif line['action'] == 'old-no-nl':
602 602 before.append(line)
603 603 elif line['action'] == 'new-no-nl':
604 604 after.append(line)
605 605
606 606 all_actions = [x['action'] for x in after] + [x['action'] for x in before]
607 607         no_hl = set(all_actions) == {'unmod-no-hl'}
608 608 result.lines.extend(
609 609 self.parse_lines(before, after, source_file, target_file, no_hl=no_hl))
610 610 # NOTE(marcink): we must keep list() call here so we can cache the result...
611 611 result.unified = list(self.as_unified(result.lines))
612 612 result.sideside = result.lines
613 613
614 614 return result
615 615
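parse_hunk above buffers runs of deleted and added lines and flushes them as side-by-side pairs whenever an unmodified line (or the end of the hunk) is reached. A minimal sketch of that buffering, using hypothetical (action, text) tuples in place of the real hunk lines:

import itertools

def pair_runs(hunk_lines):
    before, after, rows = [], [], []

    def flush():
        # pair up the buffered del/add runs; uneven runs are padded with None
        for old, new in itertools.zip_longest(before, after):
            rows.append((old, new))
        before.clear()
        after.clear()

    for action, text in hunk_lines:
        if action == 'unmod':
            flush()  # a context line closes the current run
            rows.append((text, text))
        elif action == 'del':
            before.append(text)
        elif action == 'add':
            after.append(text)
    flush()
    return rows

pair_runs([('unmod', 'a'), ('del', 'b'), ('add', 'B'), ('unmod', 'c')])
# -> [('a', 'a'), ('b', 'B'), ('c', 'c')]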
616 616 def parse_lines(self, before_lines, after_lines, source_file, target_file,
617 617 no_hl=False):
618 618 # TODO: dan: investigate doing the diff comparison and fast highlighting
619 619 # on the entire before and after buffered block lines rather than by
620 620 # line, this means we can get better 'fast' highlighting if the context
621 621 # allows it - eg.
622 622 # line 4: """
623 623 # line 5: this gets highlighted as a string
624 624 # line 6: """
625 625
626 626 lines = []
627 627
628 628 before_newline = AttributeDict()
629 629 after_newline = AttributeDict()
630 630 if before_lines and before_lines[-1]['action'] == 'old-no-nl':
631 631 before_newline_line = before_lines.pop(-1)
632 632 before_newline.content = '\n {}'.format(
633 633 render_tokenstream(
634 634 [(x[0], '', x[1])
635 635 for x in [('nonl', before_newline_line['line'])]]))
636 636
637 637 if after_lines and after_lines[-1]['action'] == 'new-no-nl':
638 638 after_newline_line = after_lines.pop(-1)
639 639 after_newline.content = '\n {}'.format(
640 640 render_tokenstream(
641 641 [(x[0], '', x[1])
642 642 for x in [('nonl', after_newline_line['line'])]]))
643 643
644 644 while before_lines or after_lines:
645 645 before, after = None, None
646 646 before_tokens, after_tokens = None, None
647 647
648 648 if before_lines:
649 649 before = before_lines.pop(0)
650 650 if after_lines:
651 651 after = after_lines.pop(0)
652 652
653 653 original = AttributeDict()
654 654 modified = AttributeDict()
655 655
656 656 if before:
657 657 if before['action'] == 'old-no-nl':
658 658 before_tokens = [('nonl', before['line'])]
659 659 else:
660 660 before_tokens = self.get_line_tokens(
661 661 line_text=before['line'], line_number=before['old_lineno'],
662 662 input_file=source_file, no_hl=no_hl, source='before')
663 663 original.lineno = before['old_lineno']
664 664 original.content = before['line']
665 665 original.action = self.action_to_op(before['action'])
666 666
667 667 original.get_comment_args = (
668 668 source_file, 'o', before['old_lineno'])
669 669
670 670 if after:
671 671 if after['action'] == 'new-no-nl':
672 672 after_tokens = [('nonl', after['line'])]
673 673 else:
674 674 after_tokens = self.get_line_tokens(
675 675 line_text=after['line'], line_number=after['new_lineno'],
676 676 input_file=target_file, no_hl=no_hl, source='after')
677 677 modified.lineno = after['new_lineno']
678 678 modified.content = after['line']
679 679 modified.action = self.action_to_op(after['action'])
680 680
681 681 modified.get_comment_args = (target_file, 'n', after['new_lineno'])
682 682
683 683 # diff the lines
684 684 if before_tokens and after_tokens:
685 685 o_tokens, m_tokens, similarity = tokens_diff(
686 686 before_tokens, after_tokens)
687 687 original.content = render_tokenstream(o_tokens)
688 688 modified.content = render_tokenstream(m_tokens)
689 689 elif before_tokens:
690 690 original.content = render_tokenstream(
691 691 [(x[0], '', x[1]) for x in before_tokens])
692 692 elif after_tokens:
693 693 modified.content = render_tokenstream(
694 694 [(x[0], '', x[1]) for x in after_tokens])
695 695
696 696 if not before_lines and before_newline:
697 697 original.content += before_newline.content
698 698 before_newline = None
699 699 if not after_lines and after_newline:
700 700 modified.content += after_newline.content
701 701 after_newline = None
702 702
703 703 lines.append(AttributeDict({
704 704 'original': original,
705 705 'modified': modified,
706 706 }))
707 707
708 708 return lines
709 709
710 710 def get_line_tokens(self, line_text, line_number, input_file=None, no_hl=False, source=''):
711 711 filenode = None
712 712 filename = None
713 713
714 714 if isinstance(input_file, str):
715 715 filename = input_file
716 716 elif isinstance(input_file, FileNode):
717 717 filenode = input_file
718 718 filename = input_file.unicode_path
719 719
720 720 hl_mode = self.HL_NONE if no_hl else self.highlight_mode
721 721 if hl_mode == self.HL_REAL and filenode:
722 722 lexer = self._get_lexer_for_filename(filename)
723 723 file_size_allowed = input_file.size < self.max_file_size_limit
724 724 if line_number and file_size_allowed:
725 725 return self.get_tokenized_filenode_line(input_file, line_number, lexer, source)
726 726
727 727 if hl_mode in (self.HL_REAL, self.HL_FAST) and filename:
728 728 lexer = self._get_lexer_for_filename(filename)
729 729 return list(tokenize_string(line_text, lexer))
730 730
731 731 return list(tokenize_string(line_text, plain_text_lexer))
732 732
733 733 def get_tokenized_filenode_line(self, filenode, line_number, lexer=None, source=''):
734 734
735 735 def tokenize(_filenode):
736 736 self.highlighted_filenodes[source][filenode] = filenode_as_lines_tokens(filenode, lexer)
737 737
738 738 if filenode not in self.highlighted_filenodes[source]:
739 739 tokenize(filenode)
740 740
741 741 try:
742 742 return self.highlighted_filenodes[source][filenode][line_number - 1]
743 743 except Exception:
744 744 log.exception('diff rendering error')
745 return [('', u'L{}: rhodecode diff rendering error'.format(line_number))]
745 return [('', 'L{}: rhodecode diff rendering error'.format(line_number))]
746 746
747 747 def action_to_op(self, action):
748 748 return {
749 749 'add': '+',
750 750 'del': '-',
751 751 'unmod': ' ',
752 752 'unmod-no-hl': ' ',
753 753 'old-no-nl': ' ',
754 754 'new-no-nl': ' ',
755 755 }.get(action, action)
756 756
757 757 def as_unified(self, lines):
758 758 """
759 759 Return a generator that yields the lines of a diff in unified order
760 760 """
761 761 def generator():
762 762 buf = []
763 763 for line in lines:
764 764
765 765                 if buf and (not line.original or line.original.action == ' '):
766 766 for b in buf:
767 767 yield b
768 768 buf = []
769 769
770 770 if line.original:
771 771 if line.original.action == ' ':
772 772 yield (line.original.lineno, line.modified.lineno,
773 773 line.original.action, line.original.content,
774 774 line.original.get_comment_args)
775 775 continue
776 776
777 777 if line.original.action == '-':
778 778 yield (line.original.lineno, None,
779 779 line.original.action, line.original.content,
780 780 line.original.get_comment_args)
781 781
782 782 if line.modified.action == '+':
783 783 buf.append((
784 784 None, line.modified.lineno,
785 785 line.modified.action, line.modified.content,
786 786 line.modified.get_comment_args))
787 787 continue
788 788
789 789 if line.modified:
790 790 yield (None, line.modified.lineno,
791 791 line.modified.action, line.modified.content,
792 792 line.modified.get_comment_args)
793 793
794 794 for b in buf:
795 795 yield b
796 796
797 797 return generator()
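as_unified above turns side-by-side rows back into unified order: deletions are emitted immediately, while additions are buffered and flushed just before the next context line (or at the end). A condensed sketch over hypothetical (old, new) content pairs:

def unified_order(rows):
    buf = []
    for old, new in rows:
        if old is not None and old == new:  # context line
            yield from buf                  # flush pending additions first
            buf = []
            yield (' ', old)
            continue
        if old is not None:
            yield ('-', old)                # deletions come out immediately
        if new is not None:
            buf.append(('+', new))          # additions wait for the flush
    yield from buf

list(unified_order([('a', 'a'), ('b', 'B'), ('c', 'c')]))
# -> [(' ', 'a'), ('-', 'b'), ('+', 'B'), (' ', 'c')]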
@@ -1,680 +1,680 b''
1 1 # -*- coding: utf-8 -*-
2 2
3 3 # Copyright (C) 2010-2020 RhodeCode GmbH
4 4 #
5 5 # This program is free software: you can redistribute it and/or modify
6 6 # it under the terms of the GNU Affero General Public License, version 3
7 7 # (only), as published by the Free Software Foundation.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU Affero General Public License
15 15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 16 #
17 17 # This program is dual-licensed. If you wish to learn more about the
18 18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20 20
21 21 """
22 22 Database creation and setup module for RhodeCode Enterprise. Used for creating
23 23 the database as well as for migration operations.
24 24 """
25 25
26 26 import os
27 27 import sys
28 28 import time
29 29 import uuid
30 30 import logging
31 31 import getpass
32 32 from os.path import dirname as dn, join as jn
33 33
34 34 from sqlalchemy.engine import create_engine
35 35
36 36 from rhodecode import __dbversion__
37 37 from rhodecode.model import init_model
38 38 from rhodecode.model.user import UserModel
39 39 from rhodecode.model.db import (
40 40 User, Permission, RhodeCodeUi, RhodeCodeSetting, UserToPerm,
41 41 DbMigrateVersion, RepoGroup, UserRepoGroupToPerm, CacheKey, Repository)
42 42 from rhodecode.model.meta import Session, Base
43 43 from rhodecode.model.permission import PermissionModel
44 44 from rhodecode.model.repo import RepoModel
45 45 from rhodecode.model.repo_group import RepoGroupModel
46 46 from rhodecode.model.settings import SettingsModel
47 47
48 48
49 49 log = logging.getLogger(__name__)
50 50
51 51
52 52 def notify(msg):
53 53 """
54 54     Notification for migration messages
55 55 """
56 56 ml = len(msg) + (4 * 2)
57 57 print(('\n%s\n*** %s ***\n%s' % ('*' * ml, msg, '*' * ml)).upper())
58 58
59 59
60 60 class DbManage(object):
61 61
62 62 def __init__(self, log_sql, dbconf, root, tests=False,
63 63 SESSION=None, cli_args=None):
64 64 self.dbname = dbconf.split('/')[-1]
65 65 self.tests = tests
66 66 self.root = root
67 67 self.dburi = dbconf
68 68 self.log_sql = log_sql
69 69 self.cli_args = cli_args or {}
70 70 self.init_db(SESSION=SESSION)
71 71 self.ask_ok = self.get_ask_ok_func(self.cli_args.get('force_ask'))
72 72
73 73 def db_exists(self):
74 74 if not self.sa:
75 75 self.init_db()
76 76 try:
77 77 self.sa.query(RhodeCodeUi)\
78 78 .filter(RhodeCodeUi.ui_key == '/')\
79 79 .scalar()
80 80 return True
81 81 except Exception:
82 82 return False
83 83 finally:
84 84 self.sa.rollback()
85 85
86 86 def get_ask_ok_func(self, param):
87 87         if param is not None:
88 88             # return a function that ignores its arguments and returns param
89 89 return lambda *args, **kwargs: param
90 90 else:
91 91 from rhodecode.lib.utils import ask_ok
92 92 return ask_ok
93 93
94 94 def init_db(self, SESSION=None):
95 95 if SESSION:
96 96 self.sa = SESSION
97 97 else:
98 98 # init new sessions
99 99 engine = create_engine(self.dburi, echo=self.log_sql)
100 100 init_model(engine)
101 101 self.sa = Session()
102 102
103 103 def create_tables(self, override=False):
104 104 """
105 105         Create an auth database
106 106 """
107 107
108 108 log.info("Existing database with the same name is going to be destroyed.")
109 109 log.info("Setup command will run DROP ALL command on that database.")
110 110 if self.tests:
111 111 destroy = True
112 112 else:
113 113 destroy = self.ask_ok('Are you sure that you want to destroy the old database? [y/n]')
114 114 if not destroy:
115 115 log.info('db tables bootstrap: Nothing done.')
116 116 sys.exit(0)
117 117 if destroy:
118 118 Base.metadata.drop_all()
119 119
120 120 checkfirst = not override
121 121 Base.metadata.create_all(checkfirst=checkfirst)
122 122 log.info('Created tables for %s', self.dbname)
123 123
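create_tables passes checkfirst=not override to SQLAlchemy, so an explicit override skips the existence check. A small illustration of the flag against a throwaway in-memory database:

from sqlalchemy import create_engine, MetaData, Table, Column, Integer

engine = create_engine('sqlite://')  # throwaway in-memory database
metadata = MetaData()
Table('demo', metadata, Column('id', Integer, primary_key=True))

metadata.create_all(engine, checkfirst=True)   # creates 'demo'
metadata.create_all(engine, checkfirst=True)   # no-op, table already exists
# with checkfirst=False the CREATE TABLE would be re-issued and fail, which
# is why create_tables() drops everything first before creating unconditionally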
124 124 def set_db_version(self):
125 125 ver = DbMigrateVersion()
126 126 ver.version = __dbversion__
127 127 ver.repository_id = 'rhodecode_db_migrations'
128 128 ver.repository_path = 'versions'
129 129 self.sa.add(ver)
130 130 log.info('db version set to: %s', __dbversion__)
131 131
132 132 def run_post_migration_tasks(self):
133 133 """
133 133         Run various tasks after migrations have completed
135 135 """
136 136 # delete cache keys on each upgrade
137 137 total = CacheKey.query().count()
138 138 log.info("Deleting (%s) cache keys now...", total)
139 139 CacheKey.delete_all_cache()
140 140
141 141 def upgrade(self, version=None):
142 142 """
143 143         Upgrades the given database schema to the given revision, following
144 144         all steps needed to perform the upgrade.
145 145
146 146 """
147 147
148 148 from rhodecode.lib.dbmigrate.migrate.versioning import api
149 149 from rhodecode.lib.dbmigrate.migrate.exceptions import \
150 150 DatabaseNotControlledError
151 151
152 152 if 'sqlite' in self.dburi:
153 153 print(
154 154 '********************** WARNING **********************\n'
155 155 'Make sure your version of sqlite is at least 3.7.X. \n'
156 156 'Earlier versions are known to fail on some migrations\n'
157 157 '*****************************************************\n')
158 158
159 159 upgrade = self.ask_ok(
160 160 'You are about to perform a database upgrade. Make '
161 161 'sure you have backed up your database. '
162 162 'Continue ? [y/n]')
163 163 if not upgrade:
164 164 log.info('No upgrade performed')
165 165 sys.exit(0)
166 166
167 167 repository_path = jn(dn(dn(dn(os.path.realpath(__file__)))),
168 168 'rhodecode/lib/dbmigrate')
169 169 db_uri = self.dburi
170 170
171 171 if version:
172 172 DbMigrateVersion.set_version(version)
173 173
174 174 try:
175 175 curr_version = api.db_version(db_uri, repository_path)
176 176 msg = ('Found current database db_uri under version '
177 177 'control with version {}'.format(curr_version))
178 178
179 179 except (RuntimeError, DatabaseNotControlledError):
180 180 curr_version = 1
181 181 msg = ('Current database is not under version control. Setting '
182 182 'as version %s' % curr_version)
183 183 api.version_control(db_uri, repository_path, curr_version)
184 184
185 185 notify(msg)
186 186
187 187
188 188 if curr_version == __dbversion__:
189 189 log.info('This database is already at the newest version')
190 190 sys.exit(0)
191 191
192 192 upgrade_steps = range(curr_version + 1, __dbversion__ + 1)
193 193 notify('attempting to upgrade database from '
194 194 'version %s to version %s' % (curr_version, __dbversion__))
195 195
196 196 # CALL THE PROPER ORDER OF STEPS TO PERFORM FULL UPGRADE
197 197 _step = None
198 198 for step in upgrade_steps:
199 199 notify('performing upgrade step %s' % step)
200 200 time.sleep(0.5)
201 201
202 202 api.upgrade(db_uri, repository_path, step)
203 203 self.sa.rollback()
204 204 notify('schema upgrade for step %s completed' % (step,))
205 205
206 206 _step = step
207 207
208 208 self.run_post_migration_tasks()
209 209 notify('upgrade to version %s successful' % _step)
210 210
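The loop above applies one api.upgrade() call per schema step between the detected version and __dbversion__. For illustration, with hypothetical versions 105 (current) and 108 (target):

curr_version, target_version = 105, 108  # hypothetical version numbers
upgrade_steps = range(curr_version + 1, target_version + 1)
print(list(upgrade_steps))  # -> [106, 107, 108]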
211 211 def fix_repo_paths(self):
212 212 """
213 213         Fixes an old RhodeCode version path into a new one without a '*'
214 214 """
215 215
216 216 paths = self.sa.query(RhodeCodeUi)\
217 217 .filter(RhodeCodeUi.ui_key == '/')\
218 218 .scalar()
219 219
220 220 paths.ui_value = paths.ui_value.replace('*', '')
221 221
222 222 try:
223 223 self.sa.add(paths)
224 224 self.sa.commit()
225 225 except Exception:
226 226 self.sa.rollback()
227 227 raise
228 228
229 229 def fix_default_user(self):
230 230 """
231 231 Fixes an old default user with some 'nicer' default values,
232 232 used mostly for anonymous access
233 233 """
234 234 def_user = self.sa.query(User)\
235 235 .filter(User.username == User.DEFAULT_USER)\
236 236 .one()
237 237
238 238 def_user.name = 'Anonymous'
239 239 def_user.lastname = 'User'
240 240 def_user.email = User.DEFAULT_USER_EMAIL
241 241
242 242 try:
243 243 self.sa.add(def_user)
244 244 self.sa.commit()
245 245 except Exception:
246 246 self.sa.rollback()
247 247 raise
248 248
249 249 def fix_settings(self):
250 250 """
251 251 Fixes rhodecode settings and adds ga_code key for google analytics
252 252 """
253 253
254 254 hgsettings3 = RhodeCodeSetting('ga_code', '')
255 255
256 256 try:
257 257 self.sa.add(hgsettings3)
258 258 self.sa.commit()
259 259 except Exception:
260 260 self.sa.rollback()
261 261 raise
262 262
263 263 def create_admin_and_prompt(self):
264 264
265 265 # defaults
266 266 defaults = self.cli_args
267 267 username = defaults.get('username')
268 268 password = defaults.get('password')
269 269 email = defaults.get('email')
270 270
271 271 if username is None:
272 272             username = input('Specify admin username:')
273 273 if password is None:
274 274 password = self._get_admin_password()
275 275 if not password:
276 276 # second try
277 277 password = self._get_admin_password()
278 278 if not password:
279 279 sys.exit()
280 280 if email is None:
281 281             email = input('Specify admin email:')
282 282 api_key = self.cli_args.get('api_key')
283 283 self.create_user(username, password, email, True,
284 284 strict_creation_check=False,
285 285 api_key=api_key)
286 286
287 287 def _get_admin_password(self):
288 288 password = getpass.getpass('Specify admin password '
289 289 '(min 6 chars):')
290 290 confirm = getpass.getpass('Confirm password:')
291 291
292 292 if password != confirm:
293 293 log.error('passwords mismatch')
294 294 return False
295 295 if len(password) < 6:
296 296 log.error('password is too short - use at least 6 characters')
297 297 return False
298 298
299 299 return password
300 300
301 301 def create_test_admin_and_users(self):
302 302 log.info('creating admin and regular test users')
303 303 from rhodecode.tests import TEST_USER_ADMIN_LOGIN, \
304 304 TEST_USER_ADMIN_PASS, TEST_USER_ADMIN_EMAIL, \
305 305 TEST_USER_REGULAR_LOGIN, TEST_USER_REGULAR_PASS, \
306 306 TEST_USER_REGULAR_EMAIL, TEST_USER_REGULAR2_LOGIN, \
307 307 TEST_USER_REGULAR2_PASS, TEST_USER_REGULAR2_EMAIL
308 308
309 309 self.create_user(TEST_USER_ADMIN_LOGIN, TEST_USER_ADMIN_PASS,
310 310 TEST_USER_ADMIN_EMAIL, True, api_key=True)
311 311
312 312 self.create_user(TEST_USER_REGULAR_LOGIN, TEST_USER_REGULAR_PASS,
313 313 TEST_USER_REGULAR_EMAIL, False, api_key=True)
314 314
315 315 self.create_user(TEST_USER_REGULAR2_LOGIN, TEST_USER_REGULAR2_PASS,
316 316 TEST_USER_REGULAR2_EMAIL, False, api_key=True)
317 317
318 318 def create_ui_settings(self, repo_store_path):
319 319 """
320 320 Creates ui settings, fills out hooks
321 321 and disables dotencode
322 322 """
323 323 settings_model = SettingsModel(sa=self.sa)
324 324 from rhodecode.lib.vcs.backends.hg import largefiles_store
325 325 from rhodecode.lib.vcs.backends.git import lfs_store
326 326
327 327 # Build HOOKS
328 328 hooks = [
329 329 (RhodeCodeUi.HOOK_REPO_SIZE, 'python:vcsserver.hooks.repo_size'),
330 330
331 331 # HG
332 332 (RhodeCodeUi.HOOK_PRE_PULL, 'python:vcsserver.hooks.pre_pull'),
333 333 (RhodeCodeUi.HOOK_PULL, 'python:vcsserver.hooks.log_pull_action'),
334 334 (RhodeCodeUi.HOOK_PRE_PUSH, 'python:vcsserver.hooks.pre_push'),
335 335 (RhodeCodeUi.HOOK_PRETX_PUSH, 'python:vcsserver.hooks.pre_push'),
336 336 (RhodeCodeUi.HOOK_PUSH, 'python:vcsserver.hooks.log_push_action'),
337 337 (RhodeCodeUi.HOOK_PUSH_KEY, 'python:vcsserver.hooks.key_push'),
338 338
339 339 ]
340 340
341 341 for key, value in hooks:
342 342 hook_obj = settings_model.get_ui_by_key(key)
343 343 hooks2 = hook_obj if hook_obj else RhodeCodeUi()
344 344 hooks2.ui_section = 'hooks'
345 345 hooks2.ui_key = key
346 346 hooks2.ui_value = value
347 347 self.sa.add(hooks2)
348 348
349 349 # enable largefiles
350 350 largefiles = RhodeCodeUi()
351 351 largefiles.ui_section = 'extensions'
352 352 largefiles.ui_key = 'largefiles'
353 353 largefiles.ui_value = ''
354 354 self.sa.add(largefiles)
355 355
356 356 # set default largefiles cache dir, defaults to
357 357 # /repo_store_location/.cache/largefiles
358 358 largefiles = RhodeCodeUi()
359 359 largefiles.ui_section = 'largefiles'
360 360 largefiles.ui_key = 'usercache'
361 361 largefiles.ui_value = largefiles_store(repo_store_path)
362 362
363 363 self.sa.add(largefiles)
364 364
365 365 # set default lfs cache dir, defaults to
366 366 # /repo_store_location/.cache/lfs_store
367 367 lfsstore = RhodeCodeUi()
368 368 lfsstore.ui_section = 'vcs_git_lfs'
369 369 lfsstore.ui_key = 'store_location'
370 370 lfsstore.ui_value = lfs_store(repo_store_path)
371 371
372 372 self.sa.add(lfsstore)
373 373
374 374         # register the hgsubversion extension, disabled by default
375 375 hgsubversion = RhodeCodeUi()
376 376 hgsubversion.ui_section = 'extensions'
377 377 hgsubversion.ui_key = 'hgsubversion'
378 378 hgsubversion.ui_value = ''
379 379 hgsubversion.ui_active = False
380 380 self.sa.add(hgsubversion)
381 381
382 382         # register the hgevolve extension, disabled by default
383 383 hgevolve = RhodeCodeUi()
384 384 hgevolve.ui_section = 'extensions'
385 385 hgevolve.ui_key = 'evolve'
386 386 hgevolve.ui_value = ''
387 387 hgevolve.ui_active = False
388 388 self.sa.add(hgevolve)
389 389
390 390 hgevolve = RhodeCodeUi()
391 391 hgevolve.ui_section = 'experimental'
392 392 hgevolve.ui_key = 'evolution'
393 393 hgevolve.ui_value = ''
394 394 hgevolve.ui_active = False
395 395 self.sa.add(hgevolve)
396 396
397 397 hgevolve = RhodeCodeUi()
398 398 hgevolve.ui_section = 'experimental'
399 399 hgevolve.ui_key = 'evolution.exchange'
400 400 hgevolve.ui_value = ''
401 401 hgevolve.ui_active = False
402 402 self.sa.add(hgevolve)
403 403
404 404 hgevolve = RhodeCodeUi()
405 405 hgevolve.ui_section = 'extensions'
406 406 hgevolve.ui_key = 'topic'
407 407 hgevolve.ui_value = ''
408 408 hgevolve.ui_active = False
409 409 self.sa.add(hgevolve)
410 410
411 411         # register the hggit extension, disabled by default
412 412 hggit = RhodeCodeUi()
413 413 hggit.ui_section = 'extensions'
414 414 hggit.ui_key = 'hggit'
415 415 hggit.ui_value = ''
416 416 hggit.ui_active = False
417 417 self.sa.add(hggit)
418 418
419 419 # set svn branch defaults
420 420 branches = ["/branches/*", "/trunk"]
421 421 tags = ["/tags/*"]
422 422
423 423 for branch in branches:
424 424 settings_model.create_ui_section_value(
425 425 RhodeCodeUi.SVN_BRANCH_ID, branch)
426 426
427 427 for tag in tags:
428 428 settings_model.create_ui_section_value(RhodeCodeUi.SVN_TAG_ID, tag)
429 429
430 430 def create_auth_plugin_options(self, skip_existing=False):
431 431 """
432 432         Create default auth plugin settings, and make them active
433 433
434 434         :param skip_existing: skip options that already exist in settings
435 435 """
436 436 defaults = [
437 437 ('auth_plugins',
438 438 'egg:rhodecode-enterprise-ce#token,egg:rhodecode-enterprise-ce#rhodecode',
439 439 'list'),
440 440
441 441 ('auth_authtoken_enabled',
442 442 'True',
443 443 'bool'),
444 444
445 445 ('auth_rhodecode_enabled',
446 446 'True',
447 447 'bool'),
448 448 ]
449 449 for k, v, t in defaults:
450 450 if (skip_existing and
451 451 SettingsModel().get_setting_by_name(k) is not None):
452 452 log.debug('Skipping option %s', k)
453 453 continue
454 454 setting = RhodeCodeSetting(k, v, t)
455 455 self.sa.add(setting)
456 456
457 457 def create_default_options(self, skip_existing=False):
458 458 """Creates default settings"""
459 459
460 460 for k, v, t in [
461 461 ('default_repo_enable_locking', False, 'bool'),
462 462 ('default_repo_enable_downloads', False, 'bool'),
463 463 ('default_repo_enable_statistics', False, 'bool'),
464 464 ('default_repo_private', False, 'bool'),
465 465 ('default_repo_type', 'hg', 'unicode')]:
466 466
467 467 if (skip_existing and
468 468 SettingsModel().get_setting_by_name(k) is not None):
469 469 log.debug('Skipping option %s', k)
470 470 continue
471 471 setting = RhodeCodeSetting(k, v, t)
472 472 self.sa.add(setting)
473 473
474 474 def fixup_groups(self):
475 475 def_usr = User.get_default_user()
476 476 for g in RepoGroup.query().all():
477 477 g.group_name = g.get_new_name(g.name)
478 478 self.sa.add(g)
479 479 # get default perm
480 480 default = UserRepoGroupToPerm.query()\
481 481 .filter(UserRepoGroupToPerm.group == g)\
482 482 .filter(UserRepoGroupToPerm.user == def_usr)\
483 483 .scalar()
484 484
485 485 if default is None:
486 486                 log.debug('missing default permission for group %s, adding it', g)
487 487 perm_obj = RepoGroupModel()._create_default_perms(g)
488 488 self.sa.add(perm_obj)
489 489
490 490 def reset_permissions(self, username):
491 491 """
492 492         Resets permissions to the default state. Useful when old systems had
493 493         bad permissions that we must clean up.
494 494
495 495 :param username:
496 496 """
497 497 default_user = User.get_by_username(username)
498 498 if not default_user:
499 499 return
500 500
501 501 u2p = UserToPerm.query()\
502 502 .filter(UserToPerm.user == default_user).all()
503 503 fixed = False
504 504 if len(u2p) != len(Permission.DEFAULT_USER_PERMISSIONS):
505 505 for p in u2p:
506 506 Session().delete(p)
507 507 fixed = True
508 508 self.populate_default_permissions()
509 509 return fixed
510 510
511 511 def config_prompt(self, test_repo_path='', retries=3):
512 512 defaults = self.cli_args
513 513 _path = defaults.get('repos_location')
514 514 if retries == 3:
515 515 log.info('Setting up repositories config')
516 516
517 517 if _path is not None:
518 518 path = _path
519 519 elif not self.tests and not test_repo_path:
520 520             path = input(
521 521                 'Enter a valid absolute path to store repositories. '
522 522                 'All repositories in that path will be added automatically:'
523 523             )
524 524 else:
525 525 path = test_repo_path
526 526 path_ok = True
527 527
528 528 # check proper dir
529 529 if not os.path.isdir(path):
530 530 path_ok = False
531 531 log.error('Given path %s is not a valid directory', path)
532 532
533 533 elif not os.path.isabs(path):
534 534 path_ok = False
535 535 log.error('Given path %s is not an absolute path', path)
536 536
537 537 # check if path is at least readable.
538 538 if not os.access(path, os.R_OK):
539 539 path_ok = False
540 540 log.error('Given path %s is not readable', path)
541 541
542 542 # check write access, warn user about non writeable paths
543 543 elif not os.access(path, os.W_OK) and path_ok:
544 544 log.warning('No write permission to given path %s', path)
545 545
546 546             q = ('Given path %s is not writable, do you want to '
547 547                  'continue in read-only mode? [y/n]' % (path,))
548 548 if not self.ask_ok(q):
549 549 log.error('Canceled by user')
550 550 sys.exit(-1)
551 551
552 552 if retries == 0:
553 553 sys.exit('max retries reached')
554 554 if not path_ok:
555 555 retries -= 1
556 556 return self.config_prompt(test_repo_path, retries)
557 557
558 558 real_path = os.path.normpath(os.path.realpath(path))
559 559
560 560 if real_path != os.path.normpath(path):
561 561             q = ('Path looks like a symlink, RhodeCode Enterprise will store '
562 562                  'the given path as %s? [y/n]') % (real_path,)
563 563 if not self.ask_ok(q):
564 564 log.error('Canceled by user')
565 565 sys.exit(-1)
566 566
567 567 return real_path
568 568
569 569 def create_settings(self, path):
570 570
571 571 self.create_ui_settings(path)
572 572
573 573 ui_config = [
574 574 ('web', 'push_ssl', 'False'),
575 575 ('web', 'allow_archive', 'gz zip bz2'),
576 576 ('web', 'allow_push', '*'),
577 577 ('web', 'baseurl', '/'),
578 578 ('paths', '/', path),
579 579 ('phases', 'publish', 'True')
580 580 ]
581 581 for section, key, value in ui_config:
582 582 ui_conf = RhodeCodeUi()
583 583 setattr(ui_conf, 'ui_section', section)
584 584 setattr(ui_conf, 'ui_key', key)
585 585 setattr(ui_conf, 'ui_value', value)
586 586 self.sa.add(ui_conf)
587 587
588 588 # rhodecode app settings
589 589 settings = [
590 590 ('realm', 'RhodeCode', 'unicode'),
591 591 ('title', '', 'unicode'),
592 592 ('pre_code', '', 'unicode'),
593 593 ('post_code', '', 'unicode'),
594 594
595 595 # Visual
596 596 ('show_public_icon', True, 'bool'),
597 597 ('show_private_icon', True, 'bool'),
598 598 ('stylify_metatags', True, 'bool'),
599 599 ('dashboard_items', 100, 'int'),
600 600 ('admin_grid_items', 25, 'int'),
601 601
602 602 ('markup_renderer', 'markdown', 'unicode'),
603 603
604 604 ('repository_fields', True, 'bool'),
605 605 ('show_version', True, 'bool'),
606 606 ('show_revision_number', True, 'bool'),
607 607 ('show_sha_length', 12, 'int'),
608 608
609 609 ('use_gravatar', False, 'bool'),
610 610 ('gravatar_url', User.DEFAULT_GRAVATAR_URL, 'unicode'),
611 611
612 612 ('clone_uri_tmpl', Repository.DEFAULT_CLONE_URI, 'unicode'),
613 613 ('clone_uri_id_tmpl', Repository.DEFAULT_CLONE_URI_ID, 'unicode'),
614 614 ('clone_uri_ssh_tmpl', Repository.DEFAULT_CLONE_URI_SSH, 'unicode'),
615 615 ('support_url', '', 'unicode'),
616 616 ('update_url', RhodeCodeSetting.DEFAULT_UPDATE_URL, 'unicode'),
617 617
618 618 # VCS Settings
619 619 ('pr_merge_enabled', True, 'bool'),
620 620 ('use_outdated_comments', True, 'bool'),
621 621 ('diff_cache', True, 'bool'),
622 622 ]
623 623
624 624 for key, val, type_ in settings:
625 625 sett = RhodeCodeSetting(key, val, type_)
626 626 self.sa.add(sett)
627 627
628 628 self.create_auth_plugin_options()
629 629 self.create_default_options()
630 630
631 631 log.info('created ui config')
632 632
633 633 def create_user(self, username, password, email='', admin=False,
634 634 strict_creation_check=True, api_key=None):
635 635 log.info('creating user `%s`', username)
636 636 user = UserModel().create_or_update(
637 username, password, email, firstname=u'RhodeCode', lastname=u'Admin',
637 username, password, email, firstname='RhodeCode', lastname='Admin',
638 638 active=True, admin=admin, extern_type="rhodecode",
639 639 strict_creation_check=strict_creation_check)
640 640
641 641 if api_key:
642 642 log.info('setting a new default auth token for user `%s`', username)
643 643 UserModel().add_auth_token(
644 644 user=user, lifetime_minutes=-1,
645 645 role=UserModel.auth_token_role.ROLE_ALL,
646 description=u'BUILTIN TOKEN')
646 description='BUILTIN TOKEN')
647 647
648 648 def create_default_user(self):
649 649 log.info('creating default user')
650 650 # create default user for handling default permissions.
651 651 user = UserModel().create_or_update(username=User.DEFAULT_USER,
652 652 password=str(uuid.uuid1())[:20],
653 653 email=User.DEFAULT_USER_EMAIL,
654 firstname=u'Anonymous',
655 lastname=u'User',
654 firstname='Anonymous',
655 lastname='User',
656 656 strict_creation_check=False)
657 657 # based on configuration options activate/de-activate this user which
658 # controlls anonymous access
658 # controls anonymous access
659 659 if self.cli_args.get('public_access') is False:
660 660 log.info('Public access disabled')
661 661 user.active = False
662 662 Session().add(user)
663 663 Session().commit()
664 664
665 665 def create_permissions(self):
666 666 """
667 667 Creates all permissions defined in the system
668 668 """
669 669 # module.(access|create|change|delete)_[name]
670 670 # module.(none|read|write|admin)
671 671 log.info('creating permissions')
672 672 PermissionModel(self.sa).create_permissions()
673 673
674 674 def populate_default_permissions(self):
675 675 """
676 676 Populate default permissions. It will create only the default
677 677 permissions that are missing, and not alter already defined ones
678 678 """
679 679 log.info('creating default user permissions')
680 680 PermissionModel(self.sa).create_default_user_permissions(user=User.DEFAULT_USER)
@@ -1,2031 +1,2031 b''
1 1 """Diff Match and Patch
2 2
3 3 Copyright 2006 Google Inc.
4 4 http://code.google.com/p/google-diff-match-patch/
5 5
6 6 Licensed under the Apache License, Version 2.0 (the "License");
7 7 you may not use this file except in compliance with the License.
8 8 You may obtain a copy of the License at
9 9
10 10 http://www.apache.org/licenses/LICENSE-2.0
11 11
12 12 Unless required by applicable law or agreed to in writing, software
13 13 distributed under the License is distributed on an "AS IS" BASIS,
14 14 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 15 See the License for the specific language governing permissions and
16 16 limitations under the License.
17 17 """
18 18
19 19 """Functions for diff, match and patch.
20 20
21 21 Computes the difference between two texts to create a patch.
22 22 Applies the patch onto another text, allowing for errors.
23 23 """
24 24
25 25 __author__ = "fraser@google.com (Neil Fraser)"
26 26
27 27 import math
28 28 import re
29 29 import sys
30 30 import time
31 31 import urllib.request, urllib.parse, urllib.error
32 32
33 33
34 34 class diff_match_patch:
35 35 """Class containing the diff, match and patch methods.
36 36
37 37 Also contains the behaviour settings.
38 38 """
39 39
40 40 def __init__(self):
41 41 """Inits a diff_match_patch object with default settings.
42 42 Redefine these in your program to override the defaults.
43 43 """
44 44
45 45 # Number of seconds to map a diff before giving up (0 for infinity).
46 46 self.Diff_Timeout = 1.0
47 47 # Cost of an empty edit operation in terms of edit characters.
48 48 self.Diff_EditCost = 4
49 49 # At what point is no match declared (0.0 = perfection, 1.0 = very loose).
50 50 self.Match_Threshold = 0.5
51 51 # How far to search for a match (0 = exact location, 1000+ = broad match).
52 52 # A match this many characters away from the expected location will add
53 53 # 1.0 to the score (0.0 is a perfect match).
54 54 self.Match_Distance = 1000
55 55 # When deleting a large block of text (over ~64 characters), how close do
56 56 # the contents have to be to match the expected contents. (0.0 = perfection,
57 57 # 1.0 = very loose). Note that Match_Threshold controls how closely the
58 58 # end points of a delete need to match.
59 59 self.Patch_DeleteThreshold = 0.5
60 60 # Chunk size for context length.
61 61 self.Patch_Margin = 4
62 62
63 63 # The number of bits in an int.
64 64 # Python has no maximum, thus to disable patch splitting set to 0.
65 65 # However to avoid long patches in certain pathological cases, use 32.
66 66 # Multiple short patches (using native ints) are much faster than long ones.
67 67 self.Match_MaxBits = 32
68 68
69 69 # DIFF FUNCTIONS
70 70
71 71 # The data structure representing a diff is an array of tuples:
72 72 # [(DIFF_DELETE, "Hello"), (DIFF_INSERT, "Goodbye"), (DIFF_EQUAL, " world.")]
73 73 # which means: delete "Hello", add "Goodbye" and keep " world."
74 74 DIFF_DELETE = -1
75 75 DIFF_INSERT = 1
76 76 DIFF_EQUAL = 0
77 77
78 78 def diff_main(self, text1, text2, checklines=True, deadline=None):
79 79 """Find the differences between two texts. Simplifies the problem by
80 80 stripping any common prefix or suffix off the texts before diffing.
81 81
82 82 Args:
83 83 text1: Old string to be diffed.
84 84 text2: New string to be diffed.
85 85 checklines: Optional speedup flag. If present and false, then don't run
86 86 a line-level diff first to identify the changed areas.
87 87 Defaults to true, which does a faster, slightly less optimal diff.
88 88 deadline: Optional time when the diff should be complete by. Used
89 89 internally for recursive calls. Users should set DiffTimeout instead.
90 90
91 91 Returns:
92 92 Array of changes.
93 93 """
94 94 # Set a deadline by which time the diff must be complete.
95 95 if deadline is None:
96 96 # Unlike in most languages, Python counts time in seconds.
97 97 if self.Diff_Timeout <= 0:
98 98 deadline = sys.maxsize
99 99 else:
100 100 deadline = time.time() + self.Diff_Timeout
101 101
102 102 # Check for null inputs.
103 103 if text1 is None or text2 is None:
104 104 raise ValueError("Null inputs. (diff_main)")
105 105
106 106 # Check for equality (speedup).
107 107 if text1 == text2:
108 108 if text1:
109 109 return [(self.DIFF_EQUAL, text1)]
110 110 return []
111 111
112 112 # Trim off common prefix (speedup).
113 113 commonlength = self.diff_commonPrefix(text1, text2)
114 114 commonprefix = text1[:commonlength]
115 115 text1 = text1[commonlength:]
116 116 text2 = text2[commonlength:]
117 117
118 118 # Trim off common suffix (speedup).
119 119 commonlength = self.diff_commonSuffix(text1, text2)
120 120 if commonlength == 0:
121 121 commonsuffix = ""
122 122 else:
123 123 commonsuffix = text1[-commonlength:]
124 124 text1 = text1[:-commonlength]
125 125 text2 = text2[:-commonlength]
126 126
127 127 # Compute the diff on the middle block.
128 128 diffs = self.diff_compute(text1, text2, checklines, deadline)
129 129
130 130 # Restore the prefix and suffix.
131 131 if commonprefix:
132 132 diffs[:0] = [(self.DIFF_EQUAL, commonprefix)]
133 133 if commonsuffix:
134 134 diffs.append((self.DIFF_EQUAL, commonsuffix))
135 135 self.diff_cleanupMerge(diffs)
136 136 return diffs
137 137
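A quick usage example of diff_main; the result can be traced through the code above (the common prefix 'ab' is trimmed, leaving a single-character replace):

dmp = diff_match_patch()
print(dmp.diff_main("abc", "abd"))
# -> [(0, 'ab'), (-1, 'c'), (1, 'd')]
# i.e. (DIFF_EQUAL, ...), (DIFF_DELETE, ...), (DIFF_INSERT, ...)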
138 138 def diff_compute(self, text1, text2, checklines, deadline):
139 139 """Find the differences between two texts. Assumes that the texts do not
140 140 have any common prefix or suffix.
141 141
142 142 Args:
143 143 text1: Old string to be diffed.
144 144 text2: New string to be diffed.
145 145 checklines: Speedup flag. If false, then don't run a line-level diff
146 146 first to identify the changed areas.
147 147 If true, then run a faster, slightly less optimal diff.
148 148 deadline: Time when the diff should be complete by.
149 149
150 150 Returns:
151 151 Array of changes.
152 152 """
153 153 if not text1:
154 154 # Just add some text (speedup).
155 155 return [(self.DIFF_INSERT, text2)]
156 156
157 157 if not text2:
158 158 # Just delete some text (speedup).
159 159 return [(self.DIFF_DELETE, text1)]
160 160
161 161 if len(text1) > len(text2):
162 162 (longtext, shorttext) = (text1, text2)
163 163 else:
164 164 (shorttext, longtext) = (text1, text2)
165 165 i = longtext.find(shorttext)
166 166 if i != -1:
167 167 # Shorter text is inside the longer text (speedup).
168 168 diffs = [
169 169 (self.DIFF_INSERT, longtext[:i]),
170 170 (self.DIFF_EQUAL, shorttext),
171 171 (self.DIFF_INSERT, longtext[i + len(shorttext) :]),
172 172 ]
173 173 # Swap insertions for deletions if diff is reversed.
174 174 if len(text1) > len(text2):
175 175 diffs[0] = (self.DIFF_DELETE, diffs[0][1])
176 176 diffs[2] = (self.DIFF_DELETE, diffs[2][1])
177 177 return diffs
178 178
179 179 if len(shorttext) == 1:
180 180 # Single character string.
181 181 # After the previous speedup, the character can't be an equality.
182 182 return [(self.DIFF_DELETE, text1), (self.DIFF_INSERT, text2)]
183 183
184 184 # Check to see if the problem can be split in two.
185 185 hm = self.diff_halfMatch(text1, text2)
186 186 if hm:
187 187 # A half-match was found, sort out the return data.
188 188 (text1_a, text1_b, text2_a, text2_b, mid_common) = hm
189 189 # Send both pairs off for separate processing.
190 190 diffs_a = self.diff_main(text1_a, text2_a, checklines, deadline)
191 191 diffs_b = self.diff_main(text1_b, text2_b, checklines, deadline)
192 192 # Merge the results.
193 193 return diffs_a + [(self.DIFF_EQUAL, mid_common)] + diffs_b
194 194
195 195 if checklines and len(text1) > 100 and len(text2) > 100:
196 196 return self.diff_lineMode(text1, text2, deadline)
197 197
198 198 return self.diff_bisect(text1, text2, deadline)
199 199
200 200 def diff_lineMode(self, text1, text2, deadline):
201 201 """Do a quick line-level diff on both strings, then rediff the parts for
202 202 greater accuracy.
203 203 This speedup can produce non-minimal diffs.
204 204
205 205 Args:
206 206 text1: Old string to be diffed.
207 207 text2: New string to be diffed.
208 208 deadline: Time when the diff should be complete by.
209 209
210 210 Returns:
211 211 Array of changes.
212 212 """
213 213
214 214 # Scan the text on a line-by-line basis first.
215 215 (text1, text2, linearray) = self.diff_linesToChars(text1, text2)
216 216
217 217 diffs = self.diff_main(text1, text2, False, deadline)
218 218
219 219 # Convert the diff back to original text.
220 220 self.diff_charsToLines(diffs, linearray)
221 221 # Eliminate freak matches (e.g. blank lines)
222 222 self.diff_cleanupSemantic(diffs)
223 223
224 224 # Rediff any replacement blocks, this time character-by-character.
225 225 # Add a dummy entry at the end.
226 226 diffs.append((self.DIFF_EQUAL, ""))
227 227 pointer = 0
228 228 count_delete = 0
229 229 count_insert = 0
230 230 text_delete = ""
231 231 text_insert = ""
232 232 while pointer < len(diffs):
233 233 if diffs[pointer][0] == self.DIFF_INSERT:
234 234 count_insert += 1
235 235 text_insert += diffs[pointer][1]
236 236 elif diffs[pointer][0] == self.DIFF_DELETE:
237 237 count_delete += 1
238 238 text_delete += diffs[pointer][1]
239 239 elif diffs[pointer][0] == self.DIFF_EQUAL:
240 240 # Upon reaching an equality, check for prior redundancies.
241 241 if count_delete >= 1 and count_insert >= 1:
242 242 # Delete the offending records and add the merged ones.
243 243 a = self.diff_main(text_delete, text_insert, False, deadline)
244 244 diffs[pointer - count_delete - count_insert : pointer] = a
245 245 pointer = pointer - count_delete - count_insert + len(a)
246 246 count_insert = 0
247 247 count_delete = 0
248 248 text_delete = ""
249 249 text_insert = ""
250 250
251 251 pointer += 1
252 252
253 253 diffs.pop() # Remove the dummy entry at the end.
254 254
255 255 return diffs
256 256
257 257 def diff_bisect(self, text1, text2, deadline):
258 258 """Find the 'middle snake' of a diff, split the problem in two
259 259 and return the recursively constructed diff.
260 260 See Myers 1986 paper: An O(ND) Difference Algorithm and Its Variations.
261 261
262 262 Args:
263 263 text1: Old string to be diffed.
264 264 text2: New string to be diffed.
265 265 deadline: Time at which to bail if not yet complete.
266 266
267 267 Returns:
268 268 Array of diff tuples.
269 269 """
270 270
271 271 # Cache the text lengths to prevent multiple calls.
272 272 text1_length = len(text1)
273 273 text2_length = len(text2)
274 274 max_d = (text1_length + text2_length + 1) // 2
275 275 v_offset = max_d
276 276 v_length = 2 * max_d
277 277 v1 = [-1] * v_length
278 278 v1[v_offset + 1] = 0
279 279 v2 = v1[:]
280 280 delta = text1_length - text2_length
281 281 # If the total number of characters is odd, then the front path will
282 282 # collide with the reverse path.
283 283 front = delta % 2 != 0
284 284 # Offsets for start and end of k loop.
285 285 # Prevents mapping of space beyond the grid.
286 286 k1start = 0
287 287 k1end = 0
288 288 k2start = 0
289 289 k2end = 0
290 290 for d in range(max_d):
291 291 # Bail out if deadline is reached.
292 292 if time.time() > deadline:
293 293 break
294 294
295 295 # Walk the front path one step.
296 296 for k1 in range(-d + k1start, d + 1 - k1end, 2):
297 297 k1_offset = v_offset + k1
298 298 if k1 == -d or (k1 != d and v1[k1_offset - 1] < v1[k1_offset + 1]):
299 299 x1 = v1[k1_offset + 1]
300 300 else:
301 301 x1 = v1[k1_offset - 1] + 1
302 302 y1 = x1 - k1
303 303 while (
304 304 x1 < text1_length and y1 < text2_length and text1[x1] == text2[y1]
305 305 ):
306 306 x1 += 1
307 307 y1 += 1
308 308 v1[k1_offset] = x1
309 309 if x1 > text1_length:
310 310 # Ran off the right of the graph.
311 311 k1end += 2
312 312 elif y1 > text2_length:
313 313 # Ran off the bottom of the graph.
314 314 k1start += 2
315 315 elif front:
316 316 k2_offset = v_offset + delta - k1
317 317 if k2_offset >= 0 and k2_offset < v_length and v2[k2_offset] != -1:
318 318 # Mirror x2 onto top-left coordinate system.
319 319 x2 = text1_length - v2[k2_offset]
320 320 if x1 >= x2:
321 321 # Overlap detected.
322 322 return self.diff_bisectSplit(text1, text2, x1, y1, deadline)
323 323
324 324 # Walk the reverse path one step.
325 325 for k2 in range(-d + k2start, d + 1 - k2end, 2):
326 326 k2_offset = v_offset + k2
327 327 if k2 == -d or (k2 != d and v2[k2_offset - 1] < v2[k2_offset + 1]):
328 328 x2 = v2[k2_offset + 1]
329 329 else:
330 330 x2 = v2[k2_offset - 1] + 1
331 331 y2 = x2 - k2
332 332 while (
333 333 x2 < text1_length
334 334 and y2 < text2_length
335 335 and text1[-x2 - 1] == text2[-y2 - 1]
336 336 ):
337 337 x2 += 1
338 338 y2 += 1
339 339 v2[k2_offset] = x2
340 340 if x2 > text1_length:
341 341 # Ran off the left of the graph.
342 342 k2end += 2
343 343 elif y2 > text2_length:
344 344 # Ran off the top of the graph.
345 345 k2start += 2
346 346 elif not front:
347 347 k1_offset = v_offset + delta - k2
348 348 if k1_offset >= 0 and k1_offset < v_length and v1[k1_offset] != -1:
349 349 x1 = v1[k1_offset]
350 350 y1 = v_offset + x1 - k1_offset
351 351 # Mirror x2 onto top-left coordinate system.
352 352 x2 = text1_length - x2
353 353 if x1 >= x2:
354 354 # Overlap detected.
355 355 return self.diff_bisectSplit(text1, text2, x1, y1, deadline)
356 356
357 357 # Diff took too long and hit the deadline or
358 358 # number of diffs equals number of characters, no commonality at all.
359 359 return [(self.DIFF_DELETE, text1), (self.DIFF_INSERT, text2)]
360 360
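diff_bisect can also be exercised directly with an explicit deadline; for 'cat' vs 'map' the forward and reverse paths meet at the shared 'a', and the two halves are diffed recursively:

import time

dmp = diff_match_patch()
print(dmp.diff_bisect("cat", "map", time.time() + 1))
# -> [(-1, 'c'), (1, 'm'), (0, 'a'), (-1, 't'), (1, 'p')]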
361 361 def diff_bisectSplit(self, text1, text2, x, y, deadline):
362 362 """Given the location of the 'middle snake', split the diff in two parts
363 363 and recurse.
364 364
365 365 Args:
366 366 text1: Old string to be diffed.
367 367 text2: New string to be diffed.
368 368 x: Index of split point in text1.
369 369 y: Index of split point in text2.
370 370 deadline: Time at which to bail if not yet complete.
371 371
372 372 Returns:
373 373 Array of diff tuples.
374 374 """
375 375 text1a = text1[:x]
376 376 text2a = text2[:y]
377 377 text1b = text1[x:]
378 378 text2b = text2[y:]
379 379
380 380 # Compute both diffs serially.
381 381 diffs = self.diff_main(text1a, text2a, False, deadline)
382 382 diffsb = self.diff_main(text1b, text2b, False, deadline)
383 383
384 384 return diffs + diffsb
385 385
386 386 def diff_linesToChars(self, text1, text2):
387 387 """Split two texts into an array of strings. Reduce the texts to a string
388 388 of hashes where each Unicode character represents one line.
389 389
390 390 Args:
391 391 text1: First string.
392 392 text2: Second string.
393 393
394 394 Returns:
395 395 Three element tuple, containing the encoded text1, the encoded text2 and
396 396 the array of unique strings. The zeroth element of the array of unique
397 397 strings is intentionally blank.
398 398 """
399 399 lineArray = [] # e.g. lineArray[4] == "Hello\n"
400 400 lineHash = {} # e.g. lineHash["Hello\n"] == 4
401 401
402 402 # "\x00" is a valid character, but various debuggers don't like it.
403 403 # So we'll insert a junk entry to avoid generating a null character.
404 404 lineArray.append("")
405 405
406 406 def diff_linesToCharsMunge(text):
407 407 """Split a text into an array of strings. Reduce the texts to a string
408 408 of hashes where each Unicode character represents one line.
409 409 Modifies linearray and linehash through being a closure.
410 410
411 411 Args:
412 412 text: String to encode.
413 413
414 414 Returns:
415 415 Encoded string.
416 416 """
417 417 chars = []
418 418 # Walk the text, pulling out a substring for each line.
419 419             # text.split('\n') would temporarily double our memory footprint.
420 420 # Modifying text would create many large strings to garbage collect.
421 421 lineStart = 0
422 422 lineEnd = -1
423 423 while lineEnd < len(text) - 1:
424 424 lineEnd = text.find("\n", lineStart)
425 425 if lineEnd == -1:
426 426 lineEnd = len(text) - 1
427 427 line = text[lineStart : lineEnd + 1]
428 428 lineStart = lineEnd + 1
429 429
430 430 if line in lineHash:
431 431 chars.append(chr(lineHash[line]))
432 432 else:
433 433 lineArray.append(line)
434 434 lineHash[line] = len(lineArray) - 1
435 435 chars.append(chr(len(lineArray) - 1))
436 436 return "".join(chars)
437 437
438 438 chars1 = diff_linesToCharsMunge(text1)
439 439 chars2 = diff_linesToCharsMunge(text2)
440 440 return (chars1, chars2, lineArray)
441 441
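diff_linesToChars compresses each distinct line to a single character, so diff_main can run a line-level diff over two short strings:

dmp = diff_match_patch()
chars1, chars2, lines = dmp.diff_linesToChars("alpha\nbeta\nalpha\n",
                                              "beta\nalpha\nbeta\n")
print([ord(c) for c in chars1])  # -> [1, 2, 1]
print([ord(c) for c in chars2])  # -> [2, 1, 2]
print(lines)                     # -> ['', 'alpha\n', 'beta\n']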
442 442 def diff_charsToLines(self, diffs, lineArray):
443 443 """Rehydrate the text in a diff from a string of line hashes to real lines
444 444 of text.
445 445
446 446 Args:
447 447 diffs: Array of diff tuples.
448 448 lineArray: Array of unique strings.
449 449 """
450 450 for x in range(len(diffs)):
451 451 text = []
452 452 for char in diffs[x][1]:
453 453 text.append(lineArray[ord(char)])
454 454 diffs[x] = (diffs[x][0], "".join(text))
455 455
456 456 def diff_commonPrefix(self, text1, text2):
457 457 """Determine the common prefix of two strings.
458 458
459 459 Args:
460 460 text1: First string.
461 461 text2: Second string.
462 462
463 463 Returns:
464 464 The number of characters common to the start of each string.
465 465 """
466 466 # Quick check for common null cases.
467 467 if not text1 or not text2 or text1[0] != text2[0]:
468 468 return 0
469 469 # Binary search.
470 470 # Performance analysis: http://neil.fraser.name/news/2007/10/09/
471 471 pointermin = 0
472 472 pointermax = min(len(text1), len(text2))
473 473 pointermid = pointermax
474 474 pointerstart = 0
475 475 while pointermin < pointermid:
476 476 if text1[pointerstart:pointermid] == text2[pointerstart:pointermid]:
477 477 pointermin = pointermid
478 478 pointerstart = pointermin
479 479 else:
480 480 pointermax = pointermid
481 481 pointermid = (pointermax - pointermin) // 2 + pointermin
482 482 return pointermid
483 483
484 484 def diff_commonSuffix(self, text1, text2):
485 485 """Determine the common suffix of two strings.
486 486
487 487 Args:
488 488 text1: First string.
489 489 text2: Second string.
490 490
491 491 Returns:
492 492 The number of characters common to the end of each string.
493 493 """
494 494 # Quick check for common null cases.
495 495 if not text1 or not text2 or text1[-1] != text2[-1]:
496 496 return 0
497 497 # Binary search.
498 498 # Performance analysis: http://neil.fraser.name/news/2007/10/09/
499 499 pointermin = 0
500 500 pointermax = min(len(text1), len(text2))
501 501 pointermid = pointermax
502 502 pointerend = 0
503 503 while pointermin < pointermid:
504 504 if (
505 505 text1[-pointermid : len(text1) - pointerend]
506 506 == text2[-pointermid : len(text2) - pointerend]
507 507 ):
508 508 pointermin = pointermid
509 509 pointerend = pointermin
510 510 else:
511 511 pointermax = pointermid
512 512 pointermid = (pointermax - pointermin) // 2 + pointermin
513 513 return pointermid
514 514
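Both helpers binary-search the length of the shared region rather than scanning it character by character:

dmp = diff_match_patch()
print(dmp.diff_commonPrefix("1234abcdef", "1234xyz"))  # -> 4
print(dmp.diff_commonSuffix("abcdef1234", "xyz1234"))  # -> 4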
515 515 def diff_commonOverlap(self, text1, text2):
516 516 """Determine if the suffix of one string is the prefix of another.
517 517
518 518 Args:
519 519       text1: First string.
520 520       text2: Second string.
521 521
522 522 Returns:
523 523 The number of characters common to the end of the first
524 524 string and the start of the second string.
525 525 """
526 526 # Cache the text lengths to prevent multiple calls.
527 527 text1_length = len(text1)
528 528 text2_length = len(text2)
529 529 # Eliminate the null case.
530 530 if text1_length == 0 or text2_length == 0:
531 531 return 0
532 532 # Truncate the longer string.
533 533 if text1_length > text2_length:
534 534 text1 = text1[-text2_length:]
535 535 elif text1_length < text2_length:
536 536 text2 = text2[:text1_length]
537 537 text_length = min(text1_length, text2_length)
538 538 # Quick check for the worst case.
539 539 if text1 == text2:
540 540 return text_length
541 541
542 542 # Start by looking for a single character match
543 543 # and increase length until no match is found.
544 544 # Performance analysis: http://neil.fraser.name/news/2010/11/04/
545 545 best = 0
546 546 length = 1
547 547 while True:
548 548 pattern = text1[-length:]
549 549 found = text2.find(pattern)
550 550 if found == -1:
551 551 return best
552 552 length += found
553 553 if found == 0 or text1[-length:] == text2[:length]:
554 554 best = length
555 555 length += 1
556 556
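diff_commonOverlap reports how many trailing characters of the first string are also the leading characters of the second:

dmp = diff_match_patch()
print(dmp.diff_commonOverlap("123456xxx", "xxxabcd"))  # -> 3
print(dmp.diff_commonOverlap("abc", "abcd"))           # -> 3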
557 557 def diff_halfMatch(self, text1, text2):
558 558 """Do the two texts share a substring which is at least half the length of
559 559 the longer text?
560 560 This speedup can produce non-minimal diffs.
561 561
562 562 Args:
563 563 text1: First string.
564 564 text2: Second string.
565 565
566 566 Returns:
567 567 Five element Array, containing the prefix of text1, the suffix of text1,
568 568 the prefix of text2, the suffix of text2 and the common middle. Or None
569 569 if there was no match.
570 570 """
571 571 if self.Diff_Timeout <= 0:
572 572 # Don't risk returning a non-optimal diff if we have unlimited time.
573 573 return None
574 574 if len(text1) > len(text2):
575 575 (longtext, shorttext) = (text1, text2)
576 576 else:
577 577 (shorttext, longtext) = (text1, text2)
578 578 if len(longtext) < 4 or len(shorttext) * 2 < len(longtext):
579 579 return None # Pointless.
580 580
581 581 def diff_halfMatchI(longtext, shorttext, i):
582 582 """Does a substring of shorttext exist within longtext such that the
583 583 substring is at least half the length of longtext?
584 584 Closure, but does not reference any external variables.
585 585
586 586 Args:
587 587 longtext: Longer string.
588 588 shorttext: Shorter string.
589 589 i: Start index of quarter length substring within longtext.
590 590
591 591 Returns:
592 592 Five element Array, containing the prefix of longtext, the suffix of
593 593 longtext, the prefix of shorttext, the suffix of shorttext and the
594 594 common middle. Or None if there was no match.
595 595 """
596 596 seed = longtext[i : i + len(longtext) // 4]
597 597 best_common = ""
598 598 j = shorttext.find(seed)
599 599 while j != -1:
600 600 prefixLength = self.diff_commonPrefix(longtext[i:], shorttext[j:])
601 601 suffixLength = self.diff_commonSuffix(longtext[:i], shorttext[:j])
602 602 if len(best_common) < suffixLength + prefixLength:
603 603 best_common = (
604 604 shorttext[j - suffixLength : j]
605 605 + shorttext[j : j + prefixLength]
606 606 )
607 607 best_longtext_a = longtext[: i - suffixLength]
608 608 best_longtext_b = longtext[i + prefixLength :]
609 609 best_shorttext_a = shorttext[: j - suffixLength]
610 610 best_shorttext_b = shorttext[j + prefixLength :]
611 611 j = shorttext.find(seed, j + 1)
612 612
613 613 if len(best_common) * 2 >= len(longtext):
614 614 return (
615 615 best_longtext_a,
616 616 best_longtext_b,
617 617 best_shorttext_a,
618 618 best_shorttext_b,
619 619 best_common,
620 620 )
621 621 else:
622 622 return None
623 623
624 624 # First check if the second quarter is the seed for a half-match.
625 625 hm1 = diff_halfMatchI(longtext, shorttext, (len(longtext) + 3) // 4)
626 626 # Check again based on the third quarter.
627 627 hm2 = diff_halfMatchI(longtext, shorttext, (len(longtext) + 1) // 2)
628 628 if not hm1 and not hm2:
629 629 return None
630 630 elif not hm2:
631 631 hm = hm1
632 632 elif not hm1:
633 633 hm = hm2
634 634 else:
635 635 # Both matched. Select the longest.
636 636 if len(hm1[4]) > len(hm2[4]):
637 637 hm = hm1
638 638 else:
639 639 hm = hm2
640 640
641 641 # A half-match was found, sort out the return data.
642 642 if len(text1) > len(text2):
643 643 (text1_a, text1_b, text2_a, text2_b, mid_common) = hm
644 644 else:
645 645 (text2_a, text2_b, text1_a, text1_b, mid_common) = hm
646 646 return (text1_a, text1_b, text2_a, text2_b, mid_common)
647 647
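# Illustrative usage (editorial addition), assuming dmp = diff_match_patch()
# with the default Diff_Timeout > 0; "345678" spans more than half of the
# longer text, so a half-match should be found:
#   dmp.diff_halfMatch("1234567890", "a345678z")
#   # -> ('12', '90', 'a', 'z', '345678')
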
648 648 def diff_cleanupSemantic(self, diffs):
649 649 """Reduce the number of edits by eliminating semantically trivial
650 650 equalities.
651 651
652 652 Args:
653 653 diffs: Array of diff tuples.
654 654 """
655 655 changes = False
656 656 equalities = [] # Stack of indices where equalities are found.
657 657 lastequality = None # Always equal to diffs[equalities[-1]][1]
658 658 pointer = 0 # Index of current position.
659 659 # Number of chars that changed prior to the equality.
660 660 length_insertions1, length_deletions1 = 0, 0
661 661 # Number of chars that changed after the equality.
662 662 length_insertions2, length_deletions2 = 0, 0
663 663 while pointer < len(diffs):
664 664 if diffs[pointer][0] == self.DIFF_EQUAL: # Equality found.
665 665 equalities.append(pointer)
666 666 length_insertions1, length_insertions2 = length_insertions2, 0
667 667 length_deletions1, length_deletions2 = length_deletions2, 0
668 668 lastequality = diffs[pointer][1]
669 669 else: # An insertion or deletion.
670 670 if diffs[pointer][0] == self.DIFF_INSERT:
671 671 length_insertions2 += len(diffs[pointer][1])
672 672 else:
673 673 length_deletions2 += len(diffs[pointer][1])
674 674 # Eliminate an equality that is smaller or equal to the edits on both
675 675 # sides of it.
676 676 if (
677 677 lastequality
678 678 and (
679 679 len(lastequality) <= max(length_insertions1, length_deletions1)
680 680 )
681 681 and (
682 682 len(lastequality) <= max(length_insertions2, length_deletions2)
683 683 )
684 684 ):
685 685 # Duplicate record.
686 686 diffs.insert(equalities[-1], (self.DIFF_DELETE, lastequality))
687 687 # Change second copy to insert.
688 688 diffs[equalities[-1] + 1] = (
689 689 self.DIFF_INSERT,
690 690 diffs[equalities[-1] + 1][1],
691 691 )
692 692 # Throw away the equality we just deleted.
693 693 equalities.pop()
694 694 # Throw away the previous equality (it needs to be reevaluated).
695 695 if len(equalities):
696 696 equalities.pop()
697 697 if len(equalities):
698 698 pointer = equalities[-1]
699 699 else:
700 700 pointer = -1
701 701 # Reset the counters.
702 702 length_insertions1, length_deletions1 = 0, 0
703 703 length_insertions2, length_deletions2 = 0, 0
704 704 lastequality = None
705 705 changes = True
706 706 pointer += 1
707 707
708 708 # Normalize the diff.
709 709 if changes:
710 710 self.diff_cleanupMerge(diffs)
711 711 self.diff_cleanupSemanticLossless(diffs)
712 712
713 713 # Find any overlaps between deletions and insertions.
714 714 # e.g: <del>abcxxx</del><ins>xxxdef</ins>
715 715 # -> <del>abc</del>xxx<ins>def</ins>
716 716 # e.g: <del>xxxabc</del><ins>defxxx</ins>
717 717 # -> <ins>def</ins>xxx<del>abc</del>
718 718 # Only extract an overlap if it is as big as the edit ahead or behind it.
719 719 pointer = 1
720 720 while pointer < len(diffs):
721 721 if (
722 722 diffs[pointer - 1][0] == self.DIFF_DELETE
723 723 and diffs[pointer][0] == self.DIFF_INSERT
724 724 ):
725 725 deletion = diffs[pointer - 1][1]
726 726 insertion = diffs[pointer][1]
727 727 overlap_length1 = self.diff_commonOverlap(deletion, insertion)
728 728 overlap_length2 = self.diff_commonOverlap(insertion, deletion)
729 729 if overlap_length1 >= overlap_length2:
730 730 if (
731 731 overlap_length1 >= len(deletion) / 2.0
732 732 or overlap_length1 >= len(insertion) / 2.0
733 733 ):
734 734 # Overlap found. Insert an equality and trim the surrounding edits.
735 735 diffs.insert(
736 736 pointer, (self.DIFF_EQUAL, insertion[:overlap_length1])
737 737 )
738 738 diffs[pointer - 1] = (
739 739 self.DIFF_DELETE,
740 740 deletion[: len(deletion) - overlap_length1],
741 741 )
742 742 diffs[pointer + 1] = (
743 743 self.DIFF_INSERT,
744 744 insertion[overlap_length1:],
745 745 )
746 746 pointer += 1
747 747 else:
748 748 if (
749 749 overlap_length2 >= len(deletion) / 2.0
750 750 or overlap_length2 >= len(insertion) / 2.0
751 751 ):
752 752 # Reverse overlap found.
753 753 # Insert an equality and swap and trim the surrounding edits.
754 754 diffs.insert(
755 755 pointer, (self.DIFF_EQUAL, deletion[:overlap_length2])
756 756 )
757 757 diffs[pointer - 1] = (
758 758 self.DIFF_INSERT,
759 759 insertion[: len(insertion) - overlap_length2],
760 760 )
761 761 diffs[pointer + 1] = (
762 762 self.DIFF_DELETE,
763 763 deletion[overlap_length2:],
764 764 )
765 765 pointer += 1
766 766 pointer += 1
767 767 pointer += 1
768 768
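# Illustrative usage (editorial addition), assuming dmp = diff_match_patch():
# the one-character equality "b" is no longer than the edits around it, so
# it is folded into the surrounding delete and re-inserted:
#   diffs = [(dmp.DIFF_DELETE, "a"), (dmp.DIFF_EQUAL, "b"), (dmp.DIFF_DELETE, "c")]
#   dmp.diff_cleanupSemantic(diffs)
#   # diffs -> [(dmp.DIFF_DELETE, "abc"), (dmp.DIFF_INSERT, "b")]
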
769 769 def diff_cleanupSemanticLossless(self, diffs):
770 770 """Look for single edits surrounded on both sides by equalities
771 771 which can be shifted sideways to align the edit to a word boundary.
772 772 e.g: The c<ins>at c</ins>ame. -> The <ins>cat </ins>came.
773 773
774 774 Args:
775 775 diffs: Array of diff tuples.
776 776 """
777 777
778 778 def diff_cleanupSemanticScore(one, two):
779 779 """Given two strings, compute a score representing whether the
780 780 internal boundary falls on logical boundaries.
781 781 Scores range from 6 (best) to 0 (worst).
782 782 Closure, but does not reference any external variables.
783 783
784 784 Args:
785 785 one: First string.
786 786 two: Second string.
787 787
788 788 Returns:
789 789 The score.
790 790 """
791 791 if not one or not two:
792 792 # Edges are the best.
793 793 return 6
794 794
795 795 # Each port of this function behaves slightly differently due to
796 796 # subtle differences in each language's definition of things like
797 797 # 'whitespace'. Since this function's purpose is largely cosmetic,
798 798 # the choice has been made to use each language's native features
799 799 # rather than force total conformity.
800 800 char1 = one[-1]
801 801 char2 = two[0]
802 802 nonAlphaNumeric1 = not char1.isalnum()
803 803 nonAlphaNumeric2 = not char2.isalnum()
804 804 whitespace1 = nonAlphaNumeric1 and char1.isspace()
805 805 whitespace2 = nonAlphaNumeric2 and char2.isspace()
806 806 lineBreak1 = whitespace1 and (char1 == "\r" or char1 == "\n")
807 807 lineBreak2 = whitespace2 and (char2 == "\r" or char2 == "\n")
808 808 blankLine1 = lineBreak1 and self.BLANKLINEEND.search(one)
809 809 blankLine2 = lineBreak2 and self.BLANKLINESTART.match(two)
810 810
811 811 if blankLine1 or blankLine2:
812 812 # Five points for blank lines.
813 813 return 5
814 814 elif lineBreak1 or lineBreak2:
815 815 # Four points for line breaks.
816 816 return 4
817 817 elif nonAlphaNumeric1 and not whitespace1 and whitespace2:
818 818 # Three points for end of sentences.
819 819 return 3
820 820 elif whitespace1 or whitespace2:
821 821 # Two points for whitespace.
822 822 return 2
823 823 elif nonAlphaNumeric1 or nonAlphaNumeric2:
824 824 # One point for non-alphanumeric.
825 825 return 1
826 826 return 0
827 827
828 828 pointer = 1
829 829 # Intentionally ignore the first and last element (don't need checking).
830 830 while pointer < len(diffs) - 1:
831 831 if (
832 832 diffs[pointer - 1][0] == self.DIFF_EQUAL
833 833 and diffs[pointer + 1][0] == self.DIFF_EQUAL
834 834 ):
835 835 # This is a single edit surrounded by equalities.
836 836 equality1 = diffs[pointer - 1][1]
837 837 edit = diffs[pointer][1]
838 838 equality2 = diffs[pointer + 1][1]
839 839
840 840 # First, shift the edit as far left as possible.
841 841 commonOffset = self.diff_commonSuffix(equality1, edit)
842 842 if commonOffset:
843 843 commonString = edit[-commonOffset:]
844 844 equality1 = equality1[:-commonOffset]
845 845 edit = commonString + edit[:-commonOffset]
846 846 equality2 = commonString + equality2
847 847
848 848 # Second, step character by character right, looking for the best fit.
849 849 bestEquality1 = equality1
850 850 bestEdit = edit
851 851 bestEquality2 = equality2
852 852 bestScore = diff_cleanupSemanticScore(
853 853 equality1, edit
854 854 ) + diff_cleanupSemanticScore(edit, equality2)
855 855 while edit and equality2 and edit[0] == equality2[0]:
856 856 equality1 += edit[0]
857 857 edit = edit[1:] + equality2[0]
858 858 equality2 = equality2[1:]
859 859 score = diff_cleanupSemanticScore(
860 860 equality1, edit
861 861 ) + diff_cleanupSemanticScore(edit, equality2)
862 862 # The >= encourages trailing rather than leading whitespace on edits.
863 863 if score >= bestScore:
864 864 bestScore = score
865 865 bestEquality1 = equality1
866 866 bestEdit = edit
867 867 bestEquality2 = equality2
868 868
869 869 if diffs[pointer - 1][1] != bestEquality1:
870 870 # We have an improvement, save it back to the diff.
871 871 if bestEquality1:
872 872 diffs[pointer - 1] = (diffs[pointer - 1][0], bestEquality1)
873 873 else:
874 874 del diffs[pointer - 1]
875 875 pointer -= 1
876 876 diffs[pointer] = (diffs[pointer][0], bestEdit)
877 877 if bestEquality2:
878 878 diffs[pointer + 1] = (diffs[pointer + 1][0], bestEquality2)
879 879 else:
880 880 del diffs[pointer + 1]
881 881 pointer -= 1
882 882 pointer += 1
883 883
884 884 # Define some regex patterns for matching boundaries.
885 885 BLANKLINEEND = re.compile(r"\n\r?\n$")
886 886 BLANKLINESTART = re.compile(r"^\r?\n\r?\n")
887 887
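# Illustrative usage (editorial addition), assuming dmp = diff_match_patch():
# the insertion is slid sideways so both boundaries land on word starts:
#   diffs = [(dmp.DIFF_EQUAL, "The c"), (dmp.DIFF_INSERT, "ow and the c"),
#            (dmp.DIFF_EQUAL, "at.")]
#   dmp.diff_cleanupSemanticLossless(diffs)
#   # diffs -> [(dmp.DIFF_EQUAL, "The "), (dmp.DIFF_INSERT, "cow and the "),
#   #           (dmp.DIFF_EQUAL, "cat.")]
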
888 888 def diff_cleanupEfficiency(self, diffs):
889 889 """Reduce the number of edits by eliminating operationally trivial
890 890 equalities.
891 891
892 892 Args:
893 893 diffs: Array of diff tuples.
894 894 """
895 895 changes = False
896 896 equalities = [] # Stack of indices where equalities are found.
897 897 lastequality = None # Always equal to diffs[equalities[-1]][1]
898 898 pointer = 0 # Index of current position.
899 899 pre_ins = False # Is there an insertion operation before the last equality.
900 900 pre_del = False # Is there a deletion operation before the last equality.
901 901 post_ins = False # Is there an insertion operation after the last equality.
902 902 post_del = False # Is there a deletion operation after the last equality.
903 903 while pointer < len(diffs):
904 904 if diffs[pointer][0] == self.DIFF_EQUAL: # Equality found.
905 905 if len(diffs[pointer][1]) < self.Diff_EditCost and (
906 906 post_ins or post_del
907 907 ):
908 908 # Candidate found.
909 909 equalities.append(pointer)
910 910 pre_ins = post_ins
911 911 pre_del = post_del
912 912 lastequality = diffs[pointer][1]
913 913 else:
914 914 # Not a candidate, and can never become one.
915 915 equalities = []
916 916 lastequality = None
917 917
918 918 post_ins = post_del = False
919 919 else: # An insertion or deletion.
920 920 if diffs[pointer][0] == self.DIFF_DELETE:
921 921 post_del = True
922 922 else:
923 923 post_ins = True
924 924
925 925 # Five types to be split:
926 926 # <ins>A</ins><del>B</del>XY<ins>C</ins><del>D</del>
927 927 # <ins>A</ins>X<ins>C</ins><del>D</del>
928 928 # <ins>A</ins><del>B</del>X<ins>C</ins>
929 929 # <del>A</del>X<ins>C</ins><del>D</del>
930 930 # <ins>A</ins><del>B</del>X<del>C</del>
931 931
932 932 if lastequality and (
933 933 (pre_ins and pre_del and post_ins and post_del)
934 934 or (
935 935 (len(lastequality) < self.Diff_EditCost / 2)
936 936 and (pre_ins + pre_del + post_ins + post_del) == 3
937 937 )
938 938 ):
939 939 # Duplicate record.
940 940 diffs.insert(equalities[-1], (self.DIFF_DELETE, lastequality))
941 941 # Change second copy to insert.
942 942 diffs[equalities[-1] + 1] = (
943 943 self.DIFF_INSERT,
944 944 diffs[equalities[-1] + 1][1],
945 945 )
946 946 equalities.pop() # Throw away the equality we just deleted.
947 947 lastequality = None
948 948 if pre_ins and pre_del:
949 949 # No changes made which could affect previous entry, keep going.
950 950 post_ins = post_del = True
951 951 equalities = []
952 952 else:
953 953 if len(equalities):
954 954 equalities.pop() # Throw away the previous equality.
955 955 if len(equalities):
956 956 pointer = equalities[-1]
957 957 else:
958 958 pointer = -1
959 959 post_ins = post_del = False
960 960 changes = True
961 961 pointer += 1
962 962
963 963 if changes:
964 964 self.diff_cleanupMerge(diffs)
965 965
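# Illustrative usage (editorial addition), assuming dmp = diff_match_patch()
# with the default Diff_EditCost = 4: an equality shorter than Diff_EditCost
# that has insertions and deletions on both sides is folded into the edits:
#   diffs = [(dmp.DIFF_DELETE, "ab"), (dmp.DIFF_INSERT, "12"),
#            (dmp.DIFF_EQUAL, "xyz"), (dmp.DIFF_DELETE, "cd"),
#            (dmp.DIFF_INSERT, "34")]
#   dmp.diff_cleanupEfficiency(diffs)
#   # diffs -> [(dmp.DIFF_DELETE, "abxyzcd"), (dmp.DIFF_INSERT, "12xyz34")]
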
966 966 def diff_cleanupMerge(self, diffs):
967 967 """Reorder and merge like edit sections. Merge equalities.
968 968 Any edit section can move as long as it doesn't cross an equality.
969 969
970 970 Args:
971 971 diffs: Array of diff tuples.
972 972 """
973 973 diffs.append((self.DIFF_EQUAL, "")) # Add a dummy entry at the end.
974 974 pointer = 0
975 975 count_delete = 0
976 976 count_insert = 0
977 977 text_delete = ""
978 978 text_insert = ""
979 979 while pointer < len(diffs):
980 980 if diffs[pointer][0] == self.DIFF_INSERT:
981 981 count_insert += 1
982 982 text_insert += diffs[pointer][1]
983 983 pointer += 1
984 984 elif diffs[pointer][0] == self.DIFF_DELETE:
985 985 count_delete += 1
986 986 text_delete += diffs[pointer][1]
987 987 pointer += 1
988 988 elif diffs[pointer][0] == self.DIFF_EQUAL:
989 989 # Upon reaching an equality, check for prior redundancies.
990 990 if count_delete + count_insert > 1:
991 991 if count_delete != 0 and count_insert != 0:
992 992 # Factor out any common prefixes.
993 993 commonlength = self.diff_commonPrefix(text_insert, text_delete)
994 994 if commonlength != 0:
995 995 x = pointer - count_delete - count_insert - 1
996 996 if x >= 0 and diffs[x][0] == self.DIFF_EQUAL:
997 997 diffs[x] = (
998 998 diffs[x][0],
999 999 diffs[x][1] + text_insert[:commonlength],
1000 1000 )
1001 1001 else:
1002 1002 diffs.insert(
1003 1003 0, (self.DIFF_EQUAL, text_insert[:commonlength])
1004 1004 )
1005 1005 pointer += 1
1006 1006 text_insert = text_insert[commonlength:]
1007 1007 text_delete = text_delete[commonlength:]
1008 1008 # Factor out any common suffixes.
1009 1009 commonlength = self.diff_commonSuffix(text_insert, text_delete)
1010 1010 if commonlength != 0:
1011 1011 diffs[pointer] = (
1012 1012 diffs[pointer][0],
1013 1013 text_insert[-commonlength:] + diffs[pointer][1],
1014 1014 )
1015 1015 text_insert = text_insert[:-commonlength]
1016 1016 text_delete = text_delete[:-commonlength]
1017 1017 # Delete the offending records and add the merged ones.
1018 1018 if count_delete == 0:
1019 1019 diffs[pointer - count_insert : pointer] = [
1020 1020 (self.DIFF_INSERT, text_insert)
1021 1021 ]
1022 1022 elif count_insert == 0:
1023 1023 diffs[pointer - count_delete : pointer] = [
1024 1024 (self.DIFF_DELETE, text_delete)
1025 1025 ]
1026 1026 else:
1027 1027 diffs[pointer - count_delete - count_insert : pointer] = [
1028 1028 (self.DIFF_DELETE, text_delete),
1029 1029 (self.DIFF_INSERT, text_insert),
1030 1030 ]
1031 1031 pointer = pointer - count_delete - count_insert + 1
1032 1032 if count_delete != 0:
1033 1033 pointer += 1
1034 1034 if count_insert != 0:
1035 1035 pointer += 1
1036 1036 elif pointer != 0 and diffs[pointer - 1][0] == self.DIFF_EQUAL:
1037 1037 # Merge this equality with the previous one.
1038 1038 diffs[pointer - 1] = (
1039 1039 diffs[pointer - 1][0],
1040 1040 diffs[pointer - 1][1] + diffs[pointer][1],
1041 1041 )
1042 1042 del diffs[pointer]
1043 1043 else:
1044 1044 pointer += 1
1045 1045
1046 1046 count_insert = 0
1047 1047 count_delete = 0
1048 1048 text_delete = ""
1049 1049 text_insert = ""
1050 1050
1051 1051 if diffs[-1][1] == "":
1052 1052 diffs.pop() # Remove the dummy entry at the end.
1053 1053
1054 1054 # Second pass: look for single edits surrounded on both sides by equalities
1055 1055 # which can be shifted sideways to eliminate an equality.
1056 1056 # e.g: A<ins>BA</ins>C -> <ins>AB</ins>AC
1057 1057 changes = False
1058 1058 pointer = 1
1059 1059 # Intentionally ignore the first and last element (don't need checking).
1060 1060 while pointer < len(diffs) - 1:
1061 1061 if (
1062 1062 diffs[pointer - 1][0] == self.DIFF_EQUAL
1063 1063 and diffs[pointer + 1][0] == self.DIFF_EQUAL
1064 1064 ):
1065 1065 # This is a single edit surrounded by equalities.
1066 1066 if diffs[pointer][1].endswith(diffs[pointer - 1][1]):
1067 1067 # Shift the edit over the previous equality.
1068 1068 diffs[pointer] = (
1069 1069 diffs[pointer][0],
1070 1070 diffs[pointer - 1][1]
1071 1071 + diffs[pointer][1][: -len(diffs[pointer - 1][1])],
1072 1072 )
1073 1073 diffs[pointer + 1] = (
1074 1074 diffs[pointer + 1][0],
1075 1075 diffs[pointer - 1][1] + diffs[pointer + 1][1],
1076 1076 )
1077 1077 del diffs[pointer - 1]
1078 1078 changes = True
1079 1079 elif diffs[pointer][1].startswith(diffs[pointer + 1][1]):
1080 1080 # Shift the edit over the next equality.
1081 1081 diffs[pointer - 1] = (
1082 1082 diffs[pointer - 1][0],
1083 1083 diffs[pointer - 1][1] + diffs[pointer + 1][1],
1084 1084 )
1085 1085 diffs[pointer] = (
1086 1086 diffs[pointer][0],
1087 1087 diffs[pointer][1][len(diffs[pointer + 1][1]) :]
1088 1088 + diffs[pointer + 1][1],
1089 1089 )
1090 1090 del diffs[pointer + 1]
1091 1091 changes = True
1092 1092 pointer += 1
1093 1093
1094 1094 # If shifts were made, the diff needs reordering and another shift sweep.
1095 1095 if changes:
1096 1096 self.diff_cleanupMerge(diffs)
1097 1097
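# Illustrative usage (editorial addition), assuming dmp = diff_match_patch():
# common prefixes/suffixes are factored out and like operations merged:
#   diffs = [(dmp.DIFF_DELETE, "a"), (dmp.DIFF_INSERT, "abc"),
#            (dmp.DIFF_DELETE, "dc")]
#   dmp.diff_cleanupMerge(diffs)
#   # diffs -> [(dmp.DIFF_EQUAL, "a"), (dmp.DIFF_DELETE, "d"),
#   #           (dmp.DIFF_INSERT, "b"), (dmp.DIFF_EQUAL, "c")]
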
1098 1098 def diff_xIndex(self, diffs, loc):
1099 1099 """loc is a location in text1, compute and return the equivalent location
1100 1100 in text2. e.g. "The cat" vs "The big cat", 1->1, 5->8
1101 1101
1102 1102 Args:
1103 1103 diffs: Array of diff tuples.
1104 1104 loc: Location within text1.
1105 1105
1106 1106 Returns:
1107 1107 Location within text2.
1108 1108 """
1109 1109 chars1 = 0
1110 1110 chars2 = 0
1111 1111 last_chars1 = 0
1112 1112 last_chars2 = 0
1113 1113 for x in range(len(diffs)):
1114 1114 (op, text) = diffs[x]
1115 1115 if op != self.DIFF_INSERT: # Equality or deletion.
1116 1116 chars1 += len(text)
1117 1117 if op != self.DIFF_DELETE: # Equality or insertion.
1118 1118 chars2 += len(text)
1119 1119 if chars1 > loc: # Overshot the location.
1120 1120 break
1121 1121 last_chars1 = chars1
1122 1122 last_chars2 = chars2
1123 1123
1124 1124 if len(diffs) != x and diffs[x][0] == self.DIFF_DELETE:
1125 1125 # The location was deleted.
1126 1126 return last_chars2
1127 1127 # Add the remaining character length.
1128 1128 return last_chars2 + (loc - last_chars1)
1129 1129
1130 1130 def diff_prettyHtml(self, diffs):
1131 1131 """Convert a diff array into a pretty HTML report.
1132 1132
1133 1133 Args:
1134 1134 diffs: Array of diff tuples.
1135 1135
1136 1136 Returns:
1137 1137 HTML representation.
1138 1138 """
1139 1139 html = []
1140 1140 for op, data in diffs:
1141 1141 text = (
1142 1142 data.replace("&", "&amp;")
1143 1143 .replace("<", "&lt;")
1144 1144 .replace(">", "&gt;")
1145 1145 .replace("\n", "&para;<br>")
1146 1146 )
1147 1147 if op == self.DIFF_INSERT:
1148 1148 html.append('<ins style="background:#e6ffe6;">%s</ins>' % text)
1149 1149 elif op == self.DIFF_DELETE:
1150 1150 html.append('<del style="background:#ffe6e6;">%s</del>' % text)
1151 1151 elif op == self.DIFF_EQUAL:
1152 1152 html.append("<span>%s</span>" % text)
1153 1153 return "".join(html)
1154 1154
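# Illustrative usage (editorial addition), assuming dmp = diff_match_patch():
#   dmp.diff_prettyHtml([(dmp.DIFF_EQUAL, "a\n"), (dmp.DIFF_DELETE, "<B>b</B>"),
#                        (dmp.DIFF_INSERT, "c&d")])
#   # -> '<span>a&para;<br></span>'
#   #    '<del style="background:#ffe6e6;">&lt;B&gt;b&lt;/B&gt;</del>'
#   #    '<ins style="background:#e6ffe6;">c&amp;d</ins>'
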
1155 1155 def diff_text1(self, diffs):
1156 1156 """Compute and return the source text (all equalities and deletions).
1157 1157
1158 1158 Args:
1159 1159 diffs: Array of diff tuples.
1160 1160
1161 1161 Returns:
1162 1162 Source text.
1163 1163 """
1164 1164 text = []
1165 1165 for op, data in diffs:
1166 1166 if op != self.DIFF_INSERT:
1167 1167 text.append(data)
1168 1168 return "".join(text)
1169 1169
1170 1170 def diff_text2(self, diffs):
1171 1171 """Compute and return the destination text (all equalities and insertions).
1172 1172
1173 1173 Args:
1174 1174 diffs: Array of diff tuples.
1175 1175
1176 1176 Returns:
1177 1177 Destination text.
1178 1178 """
1179 1179 text = []
1180 1180 for op, data in diffs:
1181 1181 if op != self.DIFF_DELETE:
1182 1182 text.append(data)
1183 1183 return "".join(text)
1184 1184
1185 1185 def diff_levenshtein(self, diffs):
1186 1186 """Compute the Levenshtein distance; the number of inserted, deleted or
1187 1187 substituted characters.
1188 1188
1189 1189 Args:
1190 1190 diffs: Array of diff tuples.
1191 1191
1192 1192 Returns:
1193 1193 Number of changes.
1194 1194 """
1195 1195 levenshtein = 0
1196 1196 insertions = 0
1197 1197 deletions = 0
1198 1198 for op, data in diffs:
1199 1199 if op == self.DIFF_INSERT:
1200 1200 insertions += len(data)
1201 1201 elif op == self.DIFF_DELETE:
1202 1202 deletions += len(data)
1203 1203 elif op == self.DIFF_EQUAL:
1204 1204 # A deletion and an insertion is one substitution.
1205 1205 levenshtein += max(insertions, deletions)
1206 1206 insertions = 0
1207 1207 deletions = 0
1208 1208 levenshtein += max(insertions, deletions)
1209 1209 return levenshtein
1210 1210
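# Illustrative usage (editorial addition), assuming dmp = diff_match_patch():
# a deletion paired with an insertion counts as a substitution, so the
# larger of the two lengths is taken:
#   dmp.diff_levenshtein([(dmp.DIFF_DELETE, "abc"), (dmp.DIFF_INSERT, "1234"),
#                         (dmp.DIFF_EQUAL, "xyz")])  # -> 4
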
1211 1211 def diff_toDelta(self, diffs):
1212 1212 """Crush the diff into an encoded string which describes the operations
1213 1213 required to transform text1 into text2.
1214 1214 E.g. =3\t-2\t+ing -> Keep 3 chars, delete 2 chars, insert 'ing'.
1215 1215 Operations are tab-separated. Inserted text is escaped using %xx notation.
1216 1216
1217 1217 Args:
1218 1218 diffs: Array of diff tuples.
1219 1219
1220 1220 Returns:
1221 1221 Delta text.
1222 1222 """
1223 1223 text = []
1224 1224 for op, data in diffs:
1225 1225 if op == self.DIFF_INSERT:
1226 1226 # Encode to UTF-8 bytes so urllib.parse.quote percent-encodes non-ASCII safely.
1227 1227 data = data.encode("utf-8")
1228 1228 text.append("+" + urllib.parse.quote(data, "!~*'();/?:@&=+$,# "))
1229 1229 elif op == self.DIFF_DELETE:
1230 1230 text.append("-%d" % len(data))
1231 1231 elif op == self.DIFF_EQUAL:
1232 1232 text.append("=%d" % len(data))
1233 1233 return "\t".join(text)
1234 1234
1235 1235 def diff_fromDelta(self, text1, delta):
1236 1236 """Given the original text1, and an encoded string which describes the
1237 1237 operations required to transform text1 into text2, compute the full diff.
1238 1238
1239 1239 Args:
1240 1240 text1: Source string for the diff.
1241 1241 delta: Delta text.
1242 1242
1243 1243 Returns:
1244 1244 Array of diff tuples.
1245 1245
1246 1246 Raises:
1247 1247 ValueError: If invalid input.
1248 1248 """
1249 1249 if type(delta) == str:
1250 1250 # Deltas should be composed of a subset of ascii chars, Unicode not
1251 1251 # required. If this encode raises UnicodeEncodeError, delta is invalid.
1252 1252 delta.encode("ascii")  # Validate only; keep delta a str so str.split works below.
1253 1253 diffs = []
1254 1254 pointer = 0 # Cursor in text1
1255 1255 tokens = delta.split("\t")
1256 1256 for token in tokens:
1257 1257 if token == "":
1258 1258 # Blank tokens are ok (from a trailing \t).
1259 1259 continue
1260 1260 # Each token begins with a one character parameter which specifies the
1261 1261 # operation of this token (delete, insert, equality).
1262 1262 param = token[1:]
1263 1263 if token[0] == "+":
1264 1264 param = urllib.parse.unquote(param)
1265 1265 diffs.append((self.DIFF_INSERT, param))
1266 1266 elif token[0] == "-" or token[0] == "=":
1267 1267 try:
1268 1268 n = int(param)
1269 1269 except ValueError:
1270 1270 raise ValueError("Invalid number in diff_fromDelta: " + param)
1271 1271 if n < 0:
1272 1272 raise ValueError("Negative number in diff_fromDelta: " + param)
1273 1273 text = text1[pointer : pointer + n]
1274 1274 pointer += n
1275 1275 if token[0] == "=":
1276 1276 diffs.append((self.DIFF_EQUAL, text))
1277 1277 else:
1278 1278 diffs.append((self.DIFF_DELETE, text))
1279 1279 else:
1280 1280 # Anything else is an error.
1281 1281 raise ValueError(
1282 1282 "Invalid diff operation in diff_fromDelta: " + token[0]
1283 1283 )
1284 1284 if pointer != len(text1):
1285 1285 raise ValueError(
1286 1286 "Delta length (%d) does not equal source text length (%d)."
1287 1287 % (pointer, len(text1))
1288 1288 )
1289 1289 return diffs
1290 1290
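# Illustrative round trip (editorial addition), assuming dmp = diff_match_patch():
#   diffs = [(dmp.DIFF_EQUAL, "jump"), (dmp.DIFF_DELETE, "s"),
#            (dmp.DIFF_INSERT, "ed")]
#   delta = dmp.diff_toDelta(diffs)      # -> "=4\t-1\t+ed"
#   dmp.diff_fromDelta("jumps", delta)   # -> the same three diff tuples
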
1291 1291 # MATCH FUNCTIONS
1292 1292
1293 1293 def match_main(self, text, pattern, loc):
1294 1294 """Locate the best instance of 'pattern' in 'text' near 'loc'.
1295 1295
1296 1296 Args:
1297 1297 text: The text to search.
1298 1298 pattern: The pattern to search for.
1299 1299 loc: The location to search around.
1300 1300
1301 1301 Returns:
1302 1302 Best match index or -1.
1303 1303 """
1304 1304 # Check for null inputs.
1305 1305 if text is None or pattern is None:
1306 1306 raise ValueError("Null inputs. (match_main)")
1307 1307
1308 1308 loc = max(0, min(loc, len(text)))
1309 1309 if text == pattern:
1310 1310 # Shortcut (potentially not guaranteed by the algorithm)
1311 1311 return 0
1312 1312 elif not text:
1313 1313 # Nothing to match.
1314 1314 return -1
1315 1315 elif text[loc : loc + len(pattern)] == pattern:
1316 1316 # Perfect match at the perfect spot! (Includes case of null pattern)
1317 1317 return loc
1318 1318 else:
1319 1319 # Do a fuzzy compare.
1320 1320 match = self.match_bitap(text, pattern, loc)
1321 1321 return match
1322 1322
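# Illustrative usage (editorial addition), assuming dmp = diff_match_patch():
#   dmp.match_main("abcdef", "de", 3)   # -> 3 (exact match at the expected spot)
#   dmp.match_main("", "abcdef", 1)     # -> -1 (nothing to match)
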
1323 1323 def match_bitap(self, text, pattern, loc):
1324 1324 """Locate the best instance of 'pattern' in 'text' near 'loc' using the
1325 1325 Bitap algorithm.
1326 1326
1327 1327 Args:
1328 1328 text: The text to search.
1329 1329 pattern: The pattern to search for.
1330 1330 loc: The location to search around.
1331 1331
1332 1332 Returns:
1333 1333 Best match index or -1.
1334 1334 """
1335 1335 # Python doesn't have a maxint limit, so ignore this check.
1336 1336 # if self.Match_MaxBits != 0 and len(pattern) > self.Match_MaxBits:
1337 1337 # raise ValueError("Pattern too long for this application.")
1338 1338
1339 1339 # Initialise the alphabet.
1340 1340 s = self.match_alphabet(pattern)
1341 1341
1342 1342 def match_bitapScore(e, x):
1343 1343 """Compute and return the score for a match with e errors and x location.
1344 1344 Accesses loc and pattern through being a closure.
1345 1345
1346 1346 Args:
1347 1347 e: Number of errors in match.
1348 1348 x: Location of match.
1349 1349
1350 1350 Returns:
1351 1351 Overall score for match (0.0 = good, 1.0 = bad).
1352 1352 """
1353 1353 accuracy = float(e) / len(pattern)
1354 1354 proximity = abs(loc - x)
1355 1355 if not self.Match_Distance:
1356 1356 # Dodge divide by zero error.
1357 1357 return 1.0 if proximity else accuracy
1358 1358 return accuracy + (proximity / float(self.Match_Distance))
1359 1359
1360 1360 # Highest score beyond which we give up.
1361 1361 score_threshold = self.Match_Threshold
1362 1362 # Is there a nearby exact match? (speedup)
1363 1363 best_loc = text.find(pattern, loc)
1364 1364 if best_loc != -1:
1365 1365 score_threshold = min(match_bitapScore(0, best_loc), score_threshold)
1366 1366 # What about in the other direction? (speedup)
1367 1367 best_loc = text.rfind(pattern, loc + len(pattern))
1368 1368 if best_loc != -1:
1369 1369 score_threshold = min(match_bitapScore(0, best_loc), score_threshold)
1370 1370
1371 1371 # Initialise the bit arrays.
1372 1372 matchmask = 1 << (len(pattern) - 1)
1373 1373 best_loc = -1
1374 1374
1375 1375 bin_max = len(pattern) + len(text)
1376 1376 # Empty initialization added to appease pychecker.
1377 1377 last_rd = None
1378 1378 for d in range(len(pattern)):
1379 1379 # Scan for the best match each iteration allows for one more error.
1380 1380 # Run a binary search to determine how far from 'loc' we can stray at
1381 1381 # this error level.
1382 1382 bin_min = 0
1383 1383 bin_mid = bin_max
1384 1384 while bin_min < bin_mid:
1385 1385 if match_bitapScore(d, loc + bin_mid) <= score_threshold:
1386 1386 bin_min = bin_mid
1387 1387 else:
1388 1388 bin_max = bin_mid
1389 1389 bin_mid = (bin_max - bin_min) // 2 + bin_min
1390 1390
1391 1391 # Use the result from this iteration as the maximum for the next.
1392 1392 bin_max = bin_mid
1393 1393 start = max(1, loc - bin_mid + 1)
1394 1394 finish = min(loc + bin_mid, len(text)) + len(pattern)
1395 1395
1396 1396 rd = [0] * (finish + 2)
1397 1397 rd[finish + 1] = (1 << d) - 1
1398 1398 for j in range(finish, start - 1, -1):
1399 1399 if len(text) <= j - 1:
1400 1400 # Out of range.
1401 1401 charMatch = 0
1402 1402 else:
1403 1403 charMatch = s.get(text[j - 1], 0)
1404 1404 if d == 0: # First pass: exact match.
1405 1405 rd[j] = ((rd[j + 1] << 1) | 1) & charMatch
1406 1406 else: # Subsequent passes: fuzzy match.
1407 1407 rd[j] = (
1408 1408 (((rd[j + 1] << 1) | 1) & charMatch)
1409 1409 | (((last_rd[j + 1] | last_rd[j]) << 1) | 1)
1410 1410 | last_rd[j + 1]
1411 1411 )
1412 1412 if rd[j] & matchmask:
1413 1413 score = match_bitapScore(d, j - 1)
1414 1414 # This match will almost certainly be better than any existing match.
1415 1415 # But check anyway.
1416 1416 if score <= score_threshold:
1417 1417 # Told you so.
1418 1418 score_threshold = score
1419 1419 best_loc = j - 1
1420 1420 if best_loc > loc:
1421 1421 # When passing loc, don't exceed our current distance from loc.
1422 1422 start = max(1, 2 * loc - best_loc)
1423 1423 else:
1424 1424 # Already passed loc, downhill from here on in.
1425 1425 break
1426 1426 # No hope for a (better) match at greater error levels.
1427 1427 if match_bitapScore(d + 1, loc) > score_threshold:
1428 1428 break
1429 1429 last_rd = rd
1430 1430 return best_loc
1431 1431
1432 1432 def match_alphabet(self, pattern):
1433 1433 """Initialise the alphabet for the Bitap algorithm.
1434 1434
1435 1435 Args:
1436 1436 pattern: The text to encode.
1437 1437
1438 1438 Returns:
1439 1439 Hash of character locations.
1440 1440 """
1441 1441 s = {}
1442 1442 for char in pattern:
1443 1443 s[char] = 0
1444 1444 for i in range(len(pattern)):
1445 1445 s[pattern[i]] |= 1 << (len(pattern) - i - 1)
1446 1446 return s
1447 1447
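# Illustrative usage (editorial addition), assuming dmp = diff_match_patch():
# each character maps to a bitmask of its positions within the pattern:
#   dmp.match_alphabet("abc")  # -> {"a": 4, "b": 2, "c": 1}
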
1448 1448 # PATCH FUNCTIONS
1449 1449
1450 1450 def patch_addContext(self, patch, text):
1451 1451 """Increase the context until it is unique,
1452 1452 but don't let the pattern expand beyond Match_MaxBits.
1453 1453
1454 1454 Args:
1455 1455 patch: The patch to grow.
1456 1456 text: Source text.
1457 1457 """
1458 1458 if len(text) == 0:
1459 1459 return
1460 1460 pattern = text[patch.start2 : patch.start2 + patch.length1]
1461 1461 padding = 0
1462 1462
1463 1463 # Look for the first and last matches of pattern in text. If two different
1464 1464 # matches are found, increase the pattern length.
1465 1465 while text.find(pattern) != text.rfind(pattern) and (
1466 1466 self.Match_MaxBits == 0
1467 1467 or len(pattern) < self.Match_MaxBits - self.Patch_Margin - self.Patch_Margin
1468 1468 ):
1469 1469 padding += self.Patch_Margin
1470 1470 pattern = text[
1471 1471 max(0, patch.start2 - padding) : patch.start2 + patch.length1 + padding
1472 1472 ]
1473 1473 # Add one chunk for good luck.
1474 1474 padding += self.Patch_Margin
1475 1475
1476 1476 # Add the prefix.
1477 1477 prefix = text[max(0, patch.start2 - padding) : patch.start2]
1478 1478 if prefix:
1479 1479 patch.diffs[:0] = [(self.DIFF_EQUAL, prefix)]
1480 1480 # Add the suffix.
1481 1481 suffix = text[
1482 1482 patch.start2 + patch.length1 : patch.start2 + patch.length1 + padding
1483 1483 ]
1484 1484 if suffix:
1485 1485 patch.diffs.append((self.DIFF_EQUAL, suffix))
1486 1486
1487 1487 # Roll back the start points.
1488 1488 patch.start1 -= len(prefix)
1489 1489 patch.start2 -= len(prefix)
1490 1490 # Extend lengths.
1491 1491 patch.length1 += len(prefix) + len(suffix)
1492 1492 patch.length2 += len(prefix) + len(suffix)
1493 1493
1494 1494 def patch_make(self, a, b=None, c=None):
1495 1495 """Compute a list of patches to turn text1 into text2.
1496 1496 Use diffs if provided, otherwise compute it ourselves.
1497 1497 There are four ways to call this function, depending on what data is
1498 1498 available to the caller:
1499 1499 Method 1:
1500 1500 a = text1, b = text2
1501 1501 Method 2:
1502 1502 a = diffs
1503 1503 Method 3 (optimal):
1504 1504 a = text1, b = diffs
1505 1505 Method 4 (deprecated, use method 3):
1506 1506 a = text1, b = text2, c = diffs
1507 1507
1508 1508 Args:
1509 1509 a: text1 (methods 1,3,4) or Array of diff tuples for text1 to
1510 1510 text2 (method 2).
1511 1511 b: text2 (methods 1,4) or Array of diff tuples for text1 to
1512 1512 text2 (method 3) or undefined (method 2).
1513 1513 c: Array of diff tuples for text1 to text2 (method 4) or
1514 1514 undefined (methods 1,2,3).
1515 1515
1516 1516 Returns:
1517 1517 Array of Patch objects.
1518 1518 """
1519 1519 text1 = None
1520 1520 diffs = None
1521 1521 # Note that under Python 3 texts arrive as 'str'.
1522 1522 if isinstance(a, str) and isinstance(b, str) and c is None:
1523 1523 # Method 1: text1, text2
1524 1524 # Compute diffs from text1 and text2.
1525 1525 text1 = a
1526 1526 diffs = self.diff_main(text1, b, True)
1527 1527 if len(diffs) > 2:
1528 1528 self.diff_cleanupSemantic(diffs)
1529 1529 self.diff_cleanupEfficiency(diffs)
1530 1530 elif isinstance(a, list) and b is None and c is None:
1531 1531 # Method 2: diffs
1532 1532 # Compute text1 from diffs.
1533 1533 diffs = a
1534 1534 text1 = self.diff_text1(diffs)
1535 1535 elif isinstance(a, str) and isinstance(b, list) and c is None:
1536 1536 # Method 3: text1, diffs
1537 1537 text1 = a
1538 1538 diffs = b
1539 1539 elif isinstance(a, str) and isinstance(b, str) and isinstance(c, list):
1540 1540 # Method 4: text1, text2, diffs
1541 1541 # text2 is not used.
1542 1542 text1 = a
1543 1543 diffs = c
1544 1544 else:
1545 1545 raise ValueError("Unknown call format to patch_make.")
1546 1546
1547 1547 if not diffs:
1548 1548 return [] # Get rid of the None case.
1549 1549 patches = []
1550 1550 patch = patch_obj()
1551 1551 char_count1 = 0 # Number of characters into the text1 string.
1552 1552 char_count2 = 0 # Number of characters into the text2 string.
1553 1553 prepatch_text = text1 # Recreate the patches to determine context info.
1554 1554 postpatch_text = text1
1555 1555 for x in range(len(diffs)):
1556 1556 (diff_type, diff_text) = diffs[x]
1557 1557 if len(patch.diffs) == 0 and diff_type != self.DIFF_EQUAL:
1558 1558 # A new patch starts here.
1559 1559 patch.start1 = char_count1
1560 1560 patch.start2 = char_count2
1561 1561 if diff_type == self.DIFF_INSERT:
1562 1562 # Insertion
1563 1563 patch.diffs.append(diffs[x])
1564 1564 patch.length2 += len(diff_text)
1565 1565 postpatch_text = (
1566 1566 postpatch_text[:char_count2]
1567 1567 + diff_text
1568 1568 + postpatch_text[char_count2:]
1569 1569 )
1570 1570 elif diff_type == self.DIFF_DELETE:
1571 1571 # Deletion.
1572 1572 patch.length1 += len(diff_text)
1573 1573 patch.diffs.append(diffs[x])
1574 1574 postpatch_text = (
1575 1575 postpatch_text[:char_count2]
1576 1576 + postpatch_text[char_count2 + len(diff_text) :]
1577 1577 )
1578 1578 elif (
1579 1579 diff_type == self.DIFF_EQUAL
1580 1580 and len(diff_text) <= 2 * self.Patch_Margin
1581 1581 and len(patch.diffs) != 0
1582 1582 and len(diffs) != x + 1
1583 1583 ):
1584 1584 # Small equality inside a patch.
1585 1585 patch.diffs.append(diffs[x])
1586 1586 patch.length1 += len(diff_text)
1587 1587 patch.length2 += len(diff_text)
1588 1588
1589 1589 if diff_type == self.DIFF_EQUAL and len(diff_text) >= 2 * self.Patch_Margin:
1590 1590 # Time for a new patch.
1591 1591 if len(patch.diffs) != 0:
1592 1592 self.patch_addContext(patch, prepatch_text)
1593 1593 patches.append(patch)
1594 1594 patch = patch_obj()
1595 1595 # Unlike Unidiff, our patch lists have a rolling context.
1596 1596 # http://code.google.com/p/google-diff-match-patch/wiki/Unidiff
1597 1597 # Update prepatch text & pos to reflect the application of the
1598 1598 # just completed patch.
1599 1599 prepatch_text = postpatch_text
1600 1600 char_count1 = char_count2
1601 1601
1602 1602 # Update the current character count.
1603 1603 if diff_type != self.DIFF_INSERT:
1604 1604 char_count1 += len(diff_text)
1605 1605 if diff_type != self.DIFF_DELETE:
1606 1606 char_count2 += len(diff_text)
1607 1607
1608 1608 # Pick up the leftover patch if not empty.
1609 1609 if len(patch.diffs) != 0:
1610 1610 self.patch_addContext(patch, prepatch_text)
1611 1611 patches.append(patch)
1612 1612 return patches
1613 1613
1614 1614 def patch_deepCopy(self, patches):
1615 1615 """Given an array of patches, return another array that is identical.
1616 1616
1617 1617 Args:
1618 1618 patches: Array of Patch objects.
1619 1619
1620 1620 Returns:
1621 1621 Array of Patch objects.
1622 1622 """
1623 1623 patchesCopy = []
1624 1624 for patch in patches:
1625 1625 patchCopy = patch_obj()
1626 1626 # No need to deep copy the tuples since they are immutable.
1627 1627 patchCopy.diffs = patch.diffs[:]
1628 1628 patchCopy.start1 = patch.start1
1629 1629 patchCopy.start2 = patch.start2
1630 1630 patchCopy.length1 = patch.length1
1631 1631 patchCopy.length2 = patch.length2
1632 1632 patchesCopy.append(patchCopy)
1633 1633 return patchesCopy
1634 1634
1635 1635 def patch_apply(self, patches, text):
1636 1636 """Merge a set of patches onto the text. Return a patched text, as well
1637 1637 as a list of true/false values indicating which patches were applied.
1638 1638
1639 1639 Args:
1640 1640 patches: Array of Patch objects.
1641 1641 text: Old text.
1642 1642
1643 1643 Returns:
1644 1644 Two element Array, containing the new text and an array of boolean values.
1645 1645 """
1646 1646 if not patches:
1647 1647 return (text, [])
1648 1648
1649 1649 # Deep copy the patches so that no changes are made to originals.
1650 1650 patches = self.patch_deepCopy(patches)
1651 1651
1652 1652 nullPadding = self.patch_addPadding(patches)
1653 1653 text = nullPadding + text + nullPadding
1654 1654 self.patch_splitMax(patches)
1655 1655
1656 1656 # delta keeps track of the offset between the expected and actual location
1657 1657 # of the previous patch. If there are patches expected at positions 10 and
1658 1658 # 20, but the first patch was found at 12, delta is 2 and the second patch
1659 1659 # has an effective expected position of 22.
1660 1660 delta = 0
1661 1661 results = []
1662 1662 for patch in patches:
1663 1663 expected_loc = patch.start2 + delta
1664 1664 text1 = self.diff_text1(patch.diffs)
1665 1665 end_loc = -1
1666 1666 if len(text1) > self.Match_MaxBits:
1667 1667 # patch_splitMax will only provide an oversized pattern in the case of
1668 1668 # a monster delete.
1669 1669 start_loc = self.match_main(
1670 1670 text, text1[: self.Match_MaxBits], expected_loc
1671 1671 )
1672 1672 if start_loc != -1:
1673 1673 end_loc = self.match_main(
1674 1674 text,
1675 1675 text1[-self.Match_MaxBits :],
1676 1676 expected_loc + len(text1) - self.Match_MaxBits,
1677 1677 )
1678 1678 if end_loc == -1 or start_loc >= end_loc:
1679 1679 # Can't find valid trailing context. Drop this patch.
1680 1680 start_loc = -1
1681 1681 else:
1682 1682 start_loc = self.match_main(text, text1, expected_loc)
1683 1683 if start_loc == -1:
1684 1684 # No match found. :(
1685 1685 results.append(False)
1686 1686 # Subtract the delta for this failed patch from subsequent patches.
1687 1687 delta -= patch.length2 - patch.length1
1688 1688 else:
1689 1689 # Found a match. :)
1690 1690 results.append(True)
1691 1691 delta = start_loc - expected_loc
1692 1692 if end_loc == -1:
1693 1693 text2 = text[start_loc : start_loc + len(text1)]
1694 1694 else:
1695 1695 text2 = text[start_loc : end_loc + self.Match_MaxBits]
1696 1696 if text1 == text2:
1697 1697 # Perfect match, just shove the replacement text in.
1698 1698 text = (
1699 1699 text[:start_loc]
1700 1700 + self.diff_text2(patch.diffs)
1701 1701 + text[start_loc + len(text1) :]
1702 1702 )
1703 1703 else:
1704 1704 # Imperfect match.
1705 1705 # Run a diff to get a framework of equivalent indices.
1706 1706 diffs = self.diff_main(text1, text2, False)
1707 1707 if (
1708 1708 len(text1) > self.Match_MaxBits
1709 1709 and self.diff_levenshtein(diffs) / float(len(text1))
1710 1710 > self.Patch_DeleteThreshold
1711 1711 ):
1712 1712 # The end points match, but the content is unacceptably bad.
1713 1713 results[-1] = False
1714 1714 else:
1715 1715 self.diff_cleanupSemanticLossless(diffs)
1716 1716 index1 = 0
1717 1717 for op, data in patch.diffs:
1718 1718 if op != self.DIFF_EQUAL:
1719 1719 index2 = self.diff_xIndex(diffs, index1)
1720 1720 if op == self.DIFF_INSERT: # Insertion
1721 1721 text = (
1722 1722 text[: start_loc + index2]
1723 1723 + data
1724 1724 + text[start_loc + index2 :]
1725 1725 )
1726 1726 elif op == self.DIFF_DELETE: # Deletion
1727 1727 text = (
1728 1728 text[: start_loc + index2]
1729 1729 + text[
1730 1730 start_loc
1731 1731 + self.diff_xIndex(diffs, index1 + len(data)) :
1732 1732 ]
1733 1733 )
1734 1734 if op != self.DIFF_DELETE:
1735 1735 index1 += len(data)
1736 1736 # Strip the padding off.
1737 1737 text = text[len(nullPadding) : -len(nullPadding)]
1738 1738 return (text, results)
1739 1739
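# Illustrative usage (editorial addition), assuming dmp = diff_match_patch():
# applying patches to the exact source text is a perfect match:
#   patches = dmp.patch_make("Hello world.", "Goodbye world.")
#   dmp.patch_apply(patches, "Hello world.")
#   # -> ("Goodbye world.", [True])
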
1740 1740 def patch_addPadding(self, patches):
1741 1741 """Add some padding on text start and end so that edges can match
1742 1742 something. Intended to be called only from within patch_apply.
1743 1743
1744 1744 Args:
1745 1745 patches: Array of Patch objects.
1746 1746
1747 1747 Returns:
1748 1748 The padding string added to each side.
1749 1749 """
1750 1750 paddingLength = self.Patch_Margin
1751 1751 nullPadding = ""
1752 1752 for x in range(1, paddingLength + 1):
1753 1753 nullPadding += chr(x)
1754 1754
1755 1755 # Bump all the patches forward.
1756 1756 for patch in patches:
1757 1757 patch.start1 += paddingLength
1758 1758 patch.start2 += paddingLength
1759 1759
1760 1760 # Add some padding on start of first diff.
1761 1761 patch = patches[0]
1762 1762 diffs = patch.diffs
1763 1763 if not diffs or diffs[0][0] != self.DIFF_EQUAL:
1764 1764 # Add nullPadding equality.
1765 1765 diffs.insert(0, (self.DIFF_EQUAL, nullPadding))
1766 1766 patch.start1 -= paddingLength # Should be 0.
1767 1767 patch.start2 -= paddingLength # Should be 0.
1768 1768 patch.length1 += paddingLength
1769 1769 patch.length2 += paddingLength
1770 1770 elif paddingLength > len(diffs[0][1]):
1771 1771 # Grow first equality.
1772 1772 extraLength = paddingLength - len(diffs[0][1])
1773 1773 newText = nullPadding[len(diffs[0][1]) :] + diffs[0][1]
1774 1774 diffs[0] = (diffs[0][0], newText)
1775 1775 patch.start1 -= extraLength
1776 1776 patch.start2 -= extraLength
1777 1777 patch.length1 += extraLength
1778 1778 patch.length2 += extraLength
1779 1779
1780 1780 # Add some padding on end of last diff.
1781 1781 patch = patches[-1]
1782 1782 diffs = patch.diffs
1783 1783 if not diffs or diffs[-1][0] != self.DIFF_EQUAL:
1784 1784 # Add nullPadding equality.
1785 1785 diffs.append((self.DIFF_EQUAL, nullPadding))
1786 1786 patch.length1 += paddingLength
1787 1787 patch.length2 += paddingLength
1788 1788 elif paddingLength > len(diffs[-1][1]):
1789 1789 # Grow last equality.
1790 1790 extraLength = paddingLength - len(diffs[-1][1])
1791 1791 newText = diffs[-1][1] + nullPadding[:extraLength]
1792 1792 diffs[-1] = (diffs[-1][0], newText)
1793 1793 patch.length1 += extraLength
1794 1794 patch.length2 += extraLength
1795 1795
1796 1796 return nullPadding
1797 1797
1798 1798 def patch_splitMax(self, patches):
1799 1799 """Look through the patches and break up any which are longer than the
1800 1800 maximum limit of the match algorithm.
1801 1801 Intended to be called only from within patch_apply.
1802 1802
1803 1803 Args:
1804 1804 patches: Array of Patch objects.
1805 1805 """
1806 1806 patch_size = self.Match_MaxBits
1807 1807 if patch_size == 0:
1808 1808 # Python has the option of not splitting strings due to its ability
1809 1809 # to handle integers of arbitrary precision.
1810 1810 return
1811 1811 for x in range(len(patches)):
1812 1812 if patches[x].length1 <= patch_size:
1813 1813 continue
1814 1814 bigpatch = patches[x]
1815 1815 # Remove the big old patch.
1816 1816 del patches[x]
1817 1817 x -= 1
1818 1818 start1 = bigpatch.start1
1819 1819 start2 = bigpatch.start2
1820 1820 precontext = ""
1821 1821 while len(bigpatch.diffs) != 0:
1822 1822 # Create one of several smaller patches.
1823 1823 patch = patch_obj()
1824 1824 empty = True
1825 1825 patch.start1 = start1 - len(precontext)
1826 1826 patch.start2 = start2 - len(precontext)
1827 1827 if precontext:
1828 1828 patch.length1 = patch.length2 = len(precontext)
1829 1829 patch.diffs.append((self.DIFF_EQUAL, precontext))
1830 1830
1831 1831 while (
1832 1832 len(bigpatch.diffs) != 0
1833 1833 and patch.length1 < patch_size - self.Patch_Margin
1834 1834 ):
1835 1835 (diff_type, diff_text) = bigpatch.diffs[0]
1836 1836 if diff_type == self.DIFF_INSERT:
1837 1837 # Insertions are harmless.
1838 1838 patch.length2 += len(diff_text)
1839 1839 start2 += len(diff_text)
1840 1840 patch.diffs.append(bigpatch.diffs.pop(0))
1841 1841 empty = False
1842 1842 elif (
1843 1843 diff_type == self.DIFF_DELETE
1844 1844 and len(patch.diffs) == 1
1845 1845 and patch.diffs[0][0] == self.DIFF_EQUAL
1846 1846 and len(diff_text) > 2 * patch_size
1847 1847 ):
1848 1848 # This is a large deletion. Let it pass in one chunk.
1849 1849 patch.length1 += len(diff_text)
1850 1850 start1 += len(diff_text)
1851 1851 empty = False
1852 1852 patch.diffs.append((diff_type, diff_text))
1853 1853 del bigpatch.diffs[0]
1854 1854 else:
1855 1855 # Deletion or equality. Only take as much as we can stomach.
1856 1856 diff_text = diff_text[
1857 1857 : patch_size - patch.length1 - self.Patch_Margin
1858 1858 ]
1859 1859 patch.length1 += len(diff_text)
1860 1860 start1 += len(diff_text)
1861 1861 if diff_type == self.DIFF_EQUAL:
1862 1862 patch.length2 += len(diff_text)
1863 1863 start2 += len(diff_text)
1864 1864 else:
1865 1865 empty = False
1866 1866
1867 1867 patch.diffs.append((diff_type, diff_text))
1868 1868 if diff_text == bigpatch.diffs[0][1]:
1869 1869 del bigpatch.diffs[0]
1870 1870 else:
1871 1871 bigpatch.diffs[0] = (
1872 1872 bigpatch.diffs[0][0],
1873 1873 bigpatch.diffs[0][1][len(diff_text) :],
1874 1874 )
1875 1875
1876 1876 # Compute the head context for the next patch.
1877 1877 precontext = self.diff_text2(patch.diffs)
1878 1878 precontext = precontext[-self.Patch_Margin :]
1879 1879 # Append the end context for this patch.
1880 1880 postcontext = self.diff_text1(bigpatch.diffs)[: self.Patch_Margin]
1881 1881 if postcontext:
1882 1882 patch.length1 += len(postcontext)
1883 1883 patch.length2 += len(postcontext)
1884 1884 if len(patch.diffs) != 0 and patch.diffs[-1][0] == self.DIFF_EQUAL:
1885 1885 patch.diffs[-1] = (
1886 1886 self.DIFF_EQUAL,
1887 1887 patch.diffs[-1][1] + postcontext,
1888 1888 )
1889 1889 else:
1890 1890 patch.diffs.append((self.DIFF_EQUAL, postcontext))
1891 1891
1892 1892 if not empty:
1893 1893 x += 1
1894 1894 patches.insert(x, patch)
1895 1895
1896 1896 def patch_toText(self, patches):
1897 1897 """Take a list of patches and return a textual representation.
1898 1898
1899 1899 Args:
1900 1900 patches: Array of Patch objects.
1901 1901
1902 1902 Returns:
1903 1903 Text representation of patches.
1904 1904 """
1905 1905 text = []
1906 1906 for patch in patches:
1907 1907 text.append(str(patch))
1908 1908 return "".join(text)
1909 1909
1910 1910 def patch_fromText(self, textline):
1911 1911 """Parse a textual representation of patches and return a list of patch
1912 1912 objects.
1913 1913
1914 1914 Args:
1915 1915 textline: Text representation of patches.
1916 1916
1917 1917 Returns:
1918 1918 Array of Patch objects.
1919 1919
1920 1920 Raises:
1921 1921 ValueError: If invalid input.
1922 1922 """
1923 if type(textline) == unicode:
1923 if type(textline) == str:
1924 1924 # Patches should be composed of a subset of ascii chars, Unicode not
1925 1925 # required. If this encode raises UnicodeEncodeError, patch is invalid.
1926 1926 textline.encode("ascii")  # Validate only; keep textline a str so str.split works below.
1927 1927 patches = []
1928 1928 if not textline:
1929 1929 return patches
1930 1930 text = textline.split("\n")
1931 1931 while len(text) != 0:
1932 1932 m = re.match(r"^@@ -(\d+),?(\d*) \+(\d+),?(\d*) @@$", text[0])
1933 1933 if not m:
1934 1934 raise ValueError("Invalid patch string: " + text[0])
1935 1935 patch = patch_obj()
1936 1936 patches.append(patch)
1937 1937 patch.start1 = int(m.group(1))
1938 1938 if m.group(2) == "":
1939 1939 patch.start1 -= 1
1940 1940 patch.length1 = 1
1941 1941 elif m.group(2) == "0":
1942 1942 patch.length1 = 0
1943 1943 else:
1944 1944 patch.start1 -= 1
1945 1945 patch.length1 = int(m.group(2))
1946 1946
1947 1947 patch.start2 = int(m.group(3))
1948 1948 if m.group(4) == "":
1949 1949 patch.start2 -= 1
1950 1950 patch.length2 = 1
1951 1951 elif m.group(4) == "0":
1952 1952 patch.length2 = 0
1953 1953 else:
1954 1954 patch.start2 -= 1
1955 1955 patch.length2 = int(m.group(4))
1956 1956
1957 1957 del text[0]
1958 1958
1959 1959 while len(text) != 0:
1960 1960 if text[0]:
1961 1961 sign = text[0][0]
1962 1962 else:
1963 1963 sign = ""
1964 1964 line = urllib.parse.unquote(text[0][1:])
1965 1965 # urllib.parse.unquote already returns str in Python 3; no decode needed.
1966 1966 if sign == "+":
1967 1967 # Insertion.
1968 1968 patch.diffs.append((self.DIFF_INSERT, line))
1969 1969 elif sign == "-":
1970 1970 # Deletion.
1971 1971 patch.diffs.append((self.DIFF_DELETE, line))
1972 1972 elif sign == " ":
1973 1973 # Minor equality.
1974 1974 patch.diffs.append((self.DIFF_EQUAL, line))
1975 1975 elif sign == "@":
1976 1976 # Start of next patch.
1977 1977 break
1978 1978 elif sign == "":
1979 1979 # Blank line? Whatever.
1980 1980 pass
1981 1981 else:
1982 1982 # Unknown operation sign.
1983 1983 raise ValueError("Invalid patch mode: '%s'\n%s" % (sign, line))
1984 1984 del text[0]
1985 1985 return patches
1986 1986
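# Illustrative round trip (editorial addition), assuming dmp = diff_match_patch():
#   patches = dmp.patch_make("The quick brown fox.", "The slow brown fox.")
#   text = dmp.patch_toText(patches)
#   dmp.patch_fromText(text)  # -> an equivalent list of patch_obj instances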
1987 1987
1988 1988 class patch_obj:
1989 1989 """Class representing one patch operation."""
1990 1990
1991 1991 def __init__(self):
1992 1992 """Initializes with an empty list of diffs."""
1993 1993 self.diffs = []
1994 1994 self.start1 = None
1995 1995 self.start2 = None
1996 1996 self.length1 = 0
1997 1997 self.length2 = 0
1998 1998
1999 1999 def __str__(self):
2000 2000 """Emmulate GNU diff's format.
2001 2001 Header: @@ -382,8 +481,9 @@
2002 2002 Indicies are printed as 1-based, not 0-based.
2003 2003
2004 2004 Returns:
2005 2005 The GNU diff string.
2006 2006 """
2007 2007 if self.length1 == 0:
2008 2008 coords1 = str(self.start1) + ",0"
2009 2009 elif self.length1 == 1:
2010 2010 coords1 = str(self.start1 + 1)
2011 2011 else:
2012 2012 coords1 = str(self.start1 + 1) + "," + str(self.length1)
2013 2013 if self.length2 == 0:
2014 2014 coords2 = str(self.start2) + ",0"
2015 2015 elif self.length2 == 1:
2016 2016 coords2 = str(self.start2 + 1)
2017 2017 else:
2018 2018 coords2 = str(self.start2 + 1) + "," + str(self.length2)
2019 2019 text = ["@@ -", coords1, " +", coords2, " @@\n"]
2020 2020 # Escape the body of the patch with %xx notation.
2021 2021 for op, data in self.diffs:
2022 2022 if op == diff_match_patch.DIFF_INSERT:
2023 2023 text.append("+")
2024 2024 elif op == diff_match_patch.DIFF_DELETE:
2025 2025 text.append("-")
2026 2026 elif op == diff_match_patch.DIFF_EQUAL:
2027 2027 text.append(" ")
2028 2028 # Encode to UTF-8 bytes so urllib.parse.quote percent-encodes non-ASCII safely.
2029 2029 data = data.encode("utf-8")
2030 2030 text.append(urllib.parse.quote(data, "!~*'();/?:@&=+$,# ") + "\n")
2031 2031 return "".join(text)
@@ -1,1272 +1,1271 b''
1 1 # -*- coding: utf-8 -*-
2 2
3 3 # Copyright (C) 2011-2020 RhodeCode GmbH
4 4 #
5 5 # This program is free software: you can redistribute it and/or modify
6 6 # it under the terms of the GNU Affero General Public License, version 3
7 7 # (only), as published by the Free Software Foundation.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU Affero General Public License
15 15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 16 #
17 17 # This program is dual-licensed. If you wish to learn more about the
18 18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20 20
21 21
22 22 """
23 23 Set of diffing helpers, previously part of vcs
24 24 """
25 25
26 26 import os
27 27 import re
28 28 import bz2
29 29 import gzip
30 30 import time
31 31
32 32 import collections
33 33 import difflib
34 34 import logging
35 35 import pickle
36 36 from itertools import tee
37 37
38 38 from rhodecode.lib.vcs.exceptions import VCSError
39 39 from rhodecode.lib.vcs.nodes import FileNode, SubModuleNode
40 40 from rhodecode.lib.utils2 import safe_unicode, safe_str
41 41
42 42 log = logging.getLogger(__name__)
43 43
44 44 # define max context; a file with more than this number of lines is
45 45 # unusable in a browser anyway
46 46 MAX_CONTEXT = 20 * 1024
47 47 DEFAULT_CONTEXT = 3
48 48
49 49
50 50 def get_diff_context(request):
51 51 return MAX_CONTEXT if request.GET.get('fullcontext', '') == '1' else DEFAULT_CONTEXT
52 52
53 53
54 54 def get_diff_whitespace_flag(request):
55 55 return request.GET.get('ignorews', '') == '1'
56 56
57 57
58 58 class OPS(object):
59 59 ADD = 'A'
60 60 MOD = 'M'
61 61 DEL = 'D'
62 62
63 63
64 64 def get_gitdiff(filenode_old, filenode_new, ignore_whitespace=True, context=3):
65 65 """
66 66 Returns git style diff between given ``filenode_old`` and ``filenode_new``.
67 67
68 68 :param ignore_whitespace: ignore whitespaces in diff
69 69 """
70 70 # make sure we pass in default context
71 71 context = context or 3
72 72 # protect against IntOverflow when passing HUGE context
73 73 if context > MAX_CONTEXT:
74 74 context = MAX_CONTEXT
75 75
76 submodules = filter(lambda o: isinstance(o, SubModuleNode),
77 [filenode_new, filenode_old])
76 submodules = [o for o in [filenode_new, filenode_old] if isinstance(o, SubModuleNode)]
78 77 if submodules:
79 78 return ''
80 79
81 80 for filenode in (filenode_old, filenode_new):
82 81 if not isinstance(filenode, FileNode):
83 82 raise VCSError(
84 83 "Given object should be FileNode object, not %s"
85 84 % filenode.__class__)
86 85
87 86 repo = filenode_new.commit.repository
88 87 old_commit = filenode_old.commit or repo.EMPTY_COMMIT
89 88 new_commit = filenode_new.commit
90 89
91 90 vcs_gitdiff = repo.get_diff(
92 91 old_commit, new_commit, filenode_new.path,
93 92 ignore_whitespace, context, path1=filenode_old.path)
94 93 return vcs_gitdiff
95 94
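# Illustrative usage (editorial addition); the file nodes here are
# hypothetical FileNode instances from two commits of the same repository:
#   raw_diff = get_gitdiff(old_file_node, new_file_node,
#                          ignore_whitespace=False, context=3)
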
96 95 NEW_FILENODE = 1
97 96 DEL_FILENODE = 2
98 97 MOD_FILENODE = 3
99 98 RENAMED_FILENODE = 4
100 99 COPIED_FILENODE = 5
101 100 CHMOD_FILENODE = 6
102 101 BIN_FILENODE = 7
103 102
104 103
105 104 class LimitedDiffContainer(object):
106 105
107 106 def __init__(self, diff_limit, cur_diff_size, diff):
108 107 self.diff = diff
109 108 self.diff_limit = diff_limit
110 109 self.cur_diff_size = cur_diff_size
111 110
112 111 def __getitem__(self, key):
113 112 return self.diff.__getitem__(key)
114 113
115 114 def __iter__(self):
116 115 for l in self.diff:
117 116 yield l
118 117
119 118
120 119 class Action(object):
121 120 """
122 121 Contains constants for the action value of the lines in a parsed diff.
123 122 """
124 123
125 124 ADD = 'add'
126 125 DELETE = 'del'
127 126 UNMODIFIED = 'unmod'
128 127
129 128 CONTEXT = 'context'
130 129 OLD_NO_NL = 'old-no-nl'
131 130 NEW_NO_NL = 'new-no-nl'
132 131
133 132
134 133 class DiffProcessor(object):
135 134 """
136 135 Give it a unified or git diff and it returns a list of the files that were
137 136 mentioned in the diff together with a dict of meta information that
138 137 can be used to render it in an HTML template.
139 138
140 139 .. note:: Unicode handling
141 140
142 141 The original diffs are a byte sequence and can contain filenames
143 142 in mixed encodings. This class generally returns `str` objects
144 143 since the result is intended for presentation to the user.
145 144
146 145 """
147 146 _chunk_re = re.compile(r'^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@(.*)')
148 147 _newline_marker = re.compile(r'^\\ No newline at end of file')
149 148
150 149 # used for inline highlighter word split
151 150 _token_re = re.compile(r'()(&gt;|&lt;|&amp;|\W+?)')
152 151
153 152 # collapse ranges of commits over given number
154 153 _collapse_commits_over = 5
155 154
156 155 def __init__(self, diff, format='gitdiff', diff_limit=None,
157 156 file_limit=None, show_full_diff=True):
158 157 """
159 158 :param diff: A `Diff` object representing a diff from a vcs backend
160 159 :param format: format of diff passed, `udiff` or `gitdiff`
161 160 :param diff_limit: defines the size of a diff that is considered "big";
162 161 beyond that limit the cut-off is triggered. Set to None
163 162 to show the full diff
164 163 """
165 164 self._diff = diff
166 165 self._format = format
167 166 self.adds = 0
168 167 self.removes = 0
169 168 # calculate diff size
170 169 self.diff_limit = diff_limit
171 170 self.file_limit = file_limit
172 171 self.show_full_diff = show_full_diff
173 172 self.cur_diff_size = 0
174 173 self.parsed = False
175 174 self.parsed_diff = []
176 175
177 176 log.debug('Initialized DiffProcessor with %s mode', format)
178 177 if format == 'gitdiff':
179 178 self.differ = self._highlight_line_difflib
180 179 self._parser = self._parse_gitdiff
181 180 else:
182 181 self.differ = self._highlight_line_udiff
183 182 self._parser = self._new_parse_gitdiff
184 183
185 184 def _copy_iterator(self):
186 185 """
187 186 make a fresh copy of the generator; we should not iterate through
188 187 the original, as it's needed for repeated operations on
189 188 this instance of DiffProcessor
190 189 """
191 190 self.__udiff, iterator_copy = tee(self.__udiff)
192 191 return iterator_copy
193 192
194 193 def _escaper(self, string):
195 194 """
196 195 Escaper for diffs: escapes special chars and checks the diff limit
197 196
198 197 :param string:
199 198 """
200 199 self.cur_diff_size += len(string)
201 200
202 201 if not self.show_full_diff and (self.cur_diff_size > self.diff_limit):
203 202 raise DiffLimitExceeded('Diff Limit Exceeded')
204 203
205 204 return string \
206 205 .replace('&', '&amp;')\
207 206 .replace('<', '&lt;')\
208 207 .replace('>', '&gt;')
209 208
210 209 def _line_counter(self, l):
211 210 """
212 211 Checks each line and bumps total adds/removes for this diff
213 212
214 213 :param l:
215 214 """
216 215 if l.startswith('+') and not l.startswith('+++'):
217 216 self.adds += 1
218 217 elif l.startswith('-') and not l.startswith('---'):
219 218 self.removes += 1
220 219 return safe_unicode(l)
221 220
222 221 def _highlight_line_difflib(self, line, next_):
223 222 """
224 223 Highlight inline changes in both lines.
225 224 """
226 225
227 226 if line['action'] == Action.DELETE:
228 227 old, new = line, next_
229 228 else:
230 229 old, new = next_, line
231 230
232 231 oldwords = self._token_re.split(old['line'])
233 232 newwords = self._token_re.split(new['line'])
234 233 sequence = difflib.SequenceMatcher(None, oldwords, newwords)
235 234
236 235 oldfragments, newfragments = [], []
237 236 for tag, i1, i2, j1, j2 in sequence.get_opcodes():
238 237 oldfrag = ''.join(oldwords[i1:i2])
239 238 newfrag = ''.join(newwords[j1:j2])
240 239 if tag != 'equal':
241 240 if oldfrag:
242 241 oldfrag = '<del>%s</del>' % oldfrag
243 242 if newfrag:
244 243 newfrag = '<ins>%s</ins>' % newfrag
245 244 oldfragments.append(oldfrag)
246 245 newfragments.append(newfrag)
247 246
248 247 old['line'] = "".join(oldfragments)
249 248 new['line'] = "".join(newfragments)
250 249
251 250 def _highlight_line_udiff(self, line, next_):
252 251 """
253 252 Highlight inline changes in both lines.
254 253 """
255 254 start = 0
256 255 limit = min(len(line['line']), len(next_['line']))
257 256 while start < limit and line['line'][start] == next_['line'][start]:
258 257 start += 1
259 258 end = -1
260 259 limit -= start
261 260 while -end <= limit and line['line'][end] == next_['line'][end]:
262 261 end -= 1
263 262 end += 1
264 263 if start or end:
265 264 def do(l):
266 265 last = end + len(l['line'])
267 266 if l['action'] == Action.ADD:
268 267 tag = 'ins'
269 268 else:
270 269 tag = 'del'
271 270 l['line'] = '%s<%s>%s</%s>%s' % (
272 271 l['line'][:start],
273 272 tag,
274 273 l['line'][start:last],
275 274 tag,
276 275 l['line'][last:]
277 276 )
278 277 do(line)
279 278 do(next_)
280 279
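For illustration, a worked example of the prefix/suffix trimming above (the dicts are hypothetical parsed-diff lines, and `processor` an instance of this class):

    line = {'action': Action.DELETE, 'line': 'color = "red"'}
    next_ = {'action': Action.ADD, 'line': 'color = "blue"'}
    processor._highlight_line_udiff(line, next_)
    # the common prefix 'color = "' and suffix '"' stay unwrapped, so:
    #   line['line'] == 'color = "<del>red</del>"'
    #   next_['line'] == 'color = "<ins>blue</ins>"'
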
281 280 def _clean_line(self, line, command):
282 281 if command in ['+', '-', ' ']:
283 282 # only modify the line if it's actually part of the diff content
284 283 line = line[1:]
285 284 return line
286 285
287 286 def _parse_gitdiff(self, inline_diff=True):
288 287 _files = []
289 288 diff_container = lambda arg: arg
290 289
291 290 for chunk in self._diff.chunks():
292 291 head = chunk.header
293 292
294 293 diff = map(self._escaper, self.diff_splitter(chunk.diff))
295 294 raw_diff = chunk.raw
296 295 limited_diff = False
297 296 exceeds_limit = False
298 297
299 298 op = None
300 299 stats = {
301 300 'added': 0,
302 301 'deleted': 0,
303 302 'binary': False,
304 303 'ops': {},
305 304 }
306 305
307 306 if head['deleted_file_mode']:
308 307 op = OPS.DEL
309 308 stats['binary'] = True
310 309 stats['ops'][DEL_FILENODE] = 'deleted file'
311 310
312 311 elif head['new_file_mode']:
313 312 op = OPS.ADD
314 313 stats['binary'] = True
315 314 stats['ops'][NEW_FILENODE] = 'new file %s' % head['new_file_mode']
316 315 else: # modify operation, can be copy, rename or chmod
317 316
318 317 # CHMOD
319 318 if head['new_mode'] and head['old_mode']:
320 319 op = OPS.MOD
321 320 stats['binary'] = True
322 321 stats['ops'][CHMOD_FILENODE] = (
323 322 'modified file chmod %s => %s' % (
324 323 head['old_mode'], head['new_mode']))
325 324 # RENAME
326 325 if head['rename_from'] != head['rename_to']:
327 326 op = OPS.MOD
328 327 stats['binary'] = True
329 328 stats['ops'][RENAMED_FILENODE] = (
330 329 'file renamed from %s to %s' % (
331 330 head['rename_from'], head['rename_to']))
332 331 # COPY
333 332 if head.get('copy_from') and head.get('copy_to'):
334 333 op = OPS.MOD
335 334 stats['binary'] = True
336 335 stats['ops'][COPIED_FILENODE] = (
337 336 'file copied from %s to %s' % (
338 337 head['copy_from'], head['copy_to']))
339 338
340 339 # If our new parsed headers didn't match anything fall back to
341 340 # old-style detection
342 341 if op is None:
343 342 if not head['a_file'] and head['b_file']:
344 343 op = OPS.ADD
345 344 stats['binary'] = True
346 345 stats['ops'][NEW_FILENODE] = 'new file'
347 346
348 347 elif head['a_file'] and not head['b_file']:
349 348 op = OPS.DEL
350 349 stats['binary'] = True
351 350 stats['ops'][DEL_FILENODE] = 'deleted file'
352 351
353 352 # it's neither ADD nor DELETE
354 353 if op is None:
355 354 op = OPS.MOD
356 355 stats['binary'] = True
357 356 stats['ops'][MOD_FILENODE] = 'modified file'
358 357
359 358 # a real non-binary diff
360 359 if head['a_file'] or head['b_file']:
361 360 try:
362 361 raw_diff, chunks, _stats = self._parse_lines(diff)
363 362 stats['binary'] = False
364 363 stats['added'] = _stats[0]
365 364 stats['deleted'] = _stats[1]
366 365 # explicit mark that it's a modified file
367 366 if op == OPS.MOD:
368 367 stats['ops'][MOD_FILENODE] = 'modified file'
369 368 exceeds_limit = len(raw_diff) > self.file_limit
370 369
371 370 # moved from the _escaper function so we validate the size of
372 371 # each file instead of the whole diff; the diff will hide big
373 372 # files but still show small ones. From testing, big files are
374 373 # fairly safe to parse, but the browser is the bottleneck
376 375 if not self.show_full_diff and exceeds_limit:
377 376 raise DiffLimitExceeded('File Limit Exceeded')
378 377
379 378 except DiffLimitExceeded:
380 379 diff_container = lambda _diff: \
381 380 LimitedDiffContainer(
382 381 self.diff_limit, self.cur_diff_size, _diff)
383 382
384 383 exceeds_limit = len(raw_diff) > self.file_limit
385 384 limited_diff = True
386 385 chunks = []
387 386
388 387 else: # GIT format binary patch, or possibly empty diff
389 388 if head['bin_patch']:
390 389 # the operation is already extracted, but we simply mark
391 390 # it as a diff we won't show for binary files
392 391 stats['ops'][BIN_FILENODE] = 'binary diff hidden'
393 392 chunks = []
394 393
395 394 if chunks and not self.show_full_diff and op == OPS.DEL:
396 395 # if not full diff mode show deleted file contents
397 396 # TODO: anderson: if the view is not too big, there is no way
398 397 # to see the content of the file
399 398 chunks = []
400 399
401 400 chunks.insert(0, [{
402 401 'old_lineno': '',
403 402 'new_lineno': '',
404 403 'action': Action.CONTEXT,
405 404 'line': msg,
406 405 } for _op, msg in stats['ops'].items()
407 406 if _op not in [MOD_FILENODE]])
408 407
409 408 _files.append({
410 409 'filename': safe_unicode(head['b_path']),
411 410 'old_revision': head['a_blob_id'],
412 411 'new_revision': head['b_blob_id'],
413 412 'chunks': chunks,
414 413 'raw_diff': safe_unicode(raw_diff),
415 414 'operation': op,
416 415 'stats': stats,
417 416 'exceeds_limit': exceeds_limit,
418 417 'is_limited_diff': limited_diff,
419 418 })
420 419
421 420 sorter = lambda info: {OPS.ADD: 0, OPS.MOD: 1,
422 421 OPS.DEL: 2}.get(info['operation'])
423 422
424 423 if not inline_diff:
425 424 return diff_container(sorted(_files, key=sorter))
426 425
427 426 # highlight inline changes
428 427 for diff_data in _files:
429 428 for chunk in diff_data['chunks']:
430 429 lineiter = iter(chunk)
431 430 try:
432 431 while 1:
433 432 line = next(lineiter)
434 433 if line['action'] not in (
435 434 Action.UNMODIFIED, Action.CONTEXT):
436 435 nextline = next(lineiter)
438 437 if nextline['action'] in [Action.UNMODIFIED, Action.CONTEXT] or \
438 437 nextline['action'] == line['action']:
439 438 continue
440 439 self.differ(line, nextline)
441 440 except StopIteration:
442 441 pass
443 442
444 443 return diff_container(sorted(_files, key=sorter))
445 444
446 445 def _check_large_diff(self):
447 446 if self.diff_limit:
448 447 log.debug('Checking if diff exceeds current diff_limit of %s', self.diff_limit)
449 448 if not self.show_full_diff and (self.cur_diff_size > self.diff_limit):
450 449 raise DiffLimitExceeded('Diff Limit `%s` Exceeded' % self.diff_limit)
451 450
452 451 # FIXME: NEWDIFFS: dan: this replaces _parse_gitdiff
453 452 def _new_parse_gitdiff(self, inline_diff=True):
454 453 _files = []
455 454
456 455 # this can be overridden later with a LimitedDiffContainer type
457 456 diff_container = lambda arg: arg
458 457
459 458 for chunk in self._diff.chunks():
460 459 head = chunk.header
461 460 log.debug('parsing diff %r', head)
462 461
463 462 raw_diff = chunk.raw
464 463 limited_diff = False
465 464 exceeds_limit = False
466 465
467 466 op = None
468 467 stats = {
469 468 'added': 0,
470 469 'deleted': 0,
471 470 'binary': False,
472 471 'old_mode': None,
473 472 'new_mode': None,
474 473 'ops': {},
475 474 }
476 475 if head['old_mode']:
477 476 stats['old_mode'] = head['old_mode']
478 477 if head['new_mode']:
479 478 stats['new_mode'] = head['new_mode']
480 479 if head['b_mode']:
481 480 stats['new_mode'] = head['b_mode']
482 481
483 482 # delete file
484 483 if head['deleted_file_mode']:
485 484 op = OPS.DEL
486 485 stats['binary'] = True
487 486 stats['ops'][DEL_FILENODE] = 'deleted file'
488 487
489 488 # new file
490 489 elif head['new_file_mode']:
491 490 op = OPS.ADD
492 491 stats['binary'] = True
493 492 stats['old_mode'] = None
494 493 stats['new_mode'] = head['new_file_mode']
495 494 stats['ops'][NEW_FILENODE] = 'new file %s' % head['new_file_mode']
496 495
497 496 # modify operation, can be copy, rename or chmod
498 497 else:
499 498 # CHMOD
500 499 if head['new_mode'] and head['old_mode']:
501 500 op = OPS.MOD
502 501 stats['binary'] = True
503 502 stats['ops'][CHMOD_FILENODE] = (
504 503 'modified file chmod %s => %s' % (
505 504 head['old_mode'], head['new_mode']))
506 505
507 506 # RENAME
508 507 if head['rename_from'] != head['rename_to']:
509 508 op = OPS.MOD
510 509 stats['binary'] = True
511 510 stats['renamed'] = (head['rename_from'], head['rename_to'])
512 511 stats['ops'][RENAMED_FILENODE] = (
513 512 'file renamed from %s to %s' % (
514 513 head['rename_from'], head['rename_to']))
515 514 # COPY
516 515 if head.get('copy_from') and head.get('copy_to'):
517 516 op = OPS.MOD
518 517 stats['binary'] = True
519 518 stats['copied'] = (head['copy_from'], head['copy_to'])
520 519 stats['ops'][COPIED_FILENODE] = (
521 520 'file copied from %s to %s' % (
522 521 head['copy_from'], head['copy_to']))
523 522
524 523 # If our new parsed headers didn't match anything fall back to
525 524 # old-style detection
526 525 if op is None:
527 526 if not head['a_file'] and head['b_file']:
528 527 op = OPS.ADD
529 528 stats['binary'] = True
530 529 stats['new_file'] = True
531 530 stats['ops'][NEW_FILENODE] = 'new file'
532 531
533 532 elif head['a_file'] and not head['b_file']:
534 533 op = OPS.DEL
535 534 stats['binary'] = True
536 535 stats['ops'][DEL_FILENODE] = 'deleted file'
537 536
538 537 # it's neither ADD nor DELETE
539 538 if op is None:
540 539 op = OPS.MOD
541 540 stats['binary'] = True
542 541 stats['ops'][MOD_FILENODE] = 'modified file'
543 542
544 543 # a real non-binary diff
545 544 if head['a_file'] or head['b_file']:
546 545 # simulate splitlines, so we keep the line end part
547 546 diff = self.diff_splitter(chunk.diff)
548 547
549 548 # add each file's size to the running diff size
550 549 raw_chunk_size = len(raw_diff)
551 550
552 551 exceeds_limit = raw_chunk_size > self.file_limit
553 552 self.cur_diff_size += raw_chunk_size
554 553
555 554 try:
556 555 # Check each file instead of the whole diff.
557 556 # Diff will hide big files but still show small ones.
558 557 # From the tests big files are fairly safe to be parsed
559 558 # but the browser is the bottleneck.
560 559 if not self.show_full_diff and exceeds_limit:
561 560 log.debug('File `%s` exceeds current file_limit of %s',
562 561 safe_unicode(head['b_path']), self.file_limit)
563 562 raise DiffLimitExceeded(
564 563 'File Limit %s Exceeded' % self.file_limit)
565 564
566 565 self._check_large_diff()
567 566
568 567 raw_diff, chunks, _stats = self._new_parse_lines(diff)
569 568 stats['binary'] = False
570 569 stats['added'] = _stats[0]
571 570 stats['deleted'] = _stats[1]
572 571 # explicit mark that it's a modified file
573 572 if op == OPS.MOD:
574 573 stats['ops'][MOD_FILENODE] = 'modified file'
575 574
576 575 except DiffLimitExceeded:
577 576 diff_container = lambda _diff: \
578 577 LimitedDiffContainer(
579 578 self.diff_limit, self.cur_diff_size, _diff)
580 579
581 580 limited_diff = True
582 581 chunks = []
583 582
584 583 else: # GIT format binary patch, or possibly empty diff
585 584 if head['bin_patch']:
587 586 # the operation is already extracted, but we simply mark
588 587 # it as a diff we won't show for binary files
588 587 stats['ops'][BIN_FILENODE] = 'binary diff hidden'
589 588 chunks = []
590 589
591 590 # Hide content of deleted node by setting empty chunks
592 591 if chunks and not self.show_full_diff and op == OPS.DEL:
593 592 # if not full diff mode show deleted file contents
594 593 # TODO: anderson: if the view is not too big, there is no way
595 594 # to see the content of the file
596 595 chunks = []
597 596
598 597 chunks.insert(
599 598 0, [{'old_lineno': '',
600 599 'new_lineno': '',
601 600 'action': Action.CONTEXT,
602 601 'line': msg,
603 602 } for _op, msg in stats['ops'].items()
604 603 if _op not in [MOD_FILENODE]])
605 604
606 605 original_filename = safe_unicode(head['a_path'])
607 606 _files.append({
608 607 'original_filename': original_filename,
609 608 'filename': safe_unicode(head['b_path']),
610 609 'old_revision': head['a_blob_id'],
611 610 'new_revision': head['b_blob_id'],
612 611 'chunks': chunks,
613 612 'raw_diff': safe_unicode(raw_diff),
614 613 'operation': op,
615 614 'stats': stats,
616 615 'exceeds_limit': exceeds_limit,
617 616 'is_limited_diff': limited_diff,
618 617 })
619 618
620 619 sorter = lambda info: {OPS.ADD: 0, OPS.MOD: 1,
621 620 OPS.DEL: 2}.get(info['operation'])
622 621
623 622 return diff_container(sorted(_files, key=sorter))
624 623
625 624 # FIXME: NEWDIFFS: dan: this gets replaced by _new_parse_lines
626 625 def _parse_lines(self, diff_iter):
627 626 """
628 627 Parse the diff and return data for the template.
629 628 """
630 629
631 630 stats = [0, 0]
632 631 chunks = []
633 632 raw_diff = []
634 633
635 634 try:
636 635 line = next(diff_iter)
637 636
638 637 while line:
639 638 raw_diff.append(line)
640 639 lines = []
641 640 chunks.append(lines)
642 641
643 642 match = self._chunk_re.match(line)
644 643
645 644 if not match:
646 645 break
647 646
648 647 gr = match.groups()
649 648 (old_line, old_end,
650 649 new_line, new_end) = [int(x or 1) for x in gr[:-1]]
651 650 old_line -= 1
652 651 new_line -= 1
653 652
654 653 context = len(gr) == 5
655 654 old_end += old_line
656 655 new_end += new_line
657 656
658 657 if context:
659 658 # skip context only if it's the first line
660 659 if int(gr[0]) > 1:
661 660 lines.append({
662 661 'old_lineno': '...',
663 662 'new_lineno': '...',
664 663 'action': Action.CONTEXT,
665 664 'line': line,
666 665 })
667 666
668 667 line = next(diff_iter)
669 668
670 669 while old_line < old_end or new_line < new_end:
671 670 command = ' '
672 671 if line:
673 672 command = line[0]
674 673
675 674 affects_old = affects_new = False
676 675
677 676 # ignore those if we don't expect them
678 677 if command in '#@':
679 678 continue
680 679 elif command == '+':
681 680 affects_new = True
682 681 action = Action.ADD
683 682 stats[0] += 1
684 683 elif command == '-':
685 684 affects_old = True
686 685 action = Action.DELETE
687 686 stats[1] += 1
688 687 else:
689 688 affects_old = affects_new = True
690 689 action = Action.UNMODIFIED
691 690
692 691 if not self._newline_marker.match(line):
693 692 old_line += affects_old
694 693 new_line += affects_new
695 694 lines.append({
696 695 'old_lineno': affects_old and old_line or '',
697 696 'new_lineno': affects_new and new_line or '',
698 697 'action': action,
699 698 'line': self._clean_line(line, command)
700 699 })
701 700 raw_diff.append(line)
702 701
703 702 line = next(diff_iter)
704 703
705 704 if self._newline_marker.match(line):
706 705 # we need to append to lines, since this is not
707 706 # counted in the line specs of diff
708 707 lines.append({
709 708 'old_lineno': '...',
710 709 'new_lineno': '...',
711 710 'action': Action.CONTEXT,
712 711 'line': self._clean_line(line, command)
713 712 })
714 713
715 714 except StopIteration:
716 715 pass
717 716 return ''.join(raw_diff), chunks, stats
718 717
719 718 # FIXME: NEWDIFFS: dan: this replaces _parse_lines
720 719 def _new_parse_lines(self, diff_iter):
721 720 """
722 721 Parse the diff and return data for the template.
723 722 """
724 723
725 724 stats = [0, 0]
726 725 chunks = []
727 726 raw_diff = []
728 727
729 728 try:
730 729 line = next(diff_iter)
731 730
732 731 while line:
733 732 raw_diff.append(line)
734 733 # match hunk header, e.g. '@@ -0,0 +1 @@\n'
735 734 match = self._chunk_re.match(line)
736 735
737 736 if not match:
738 737 break
739 738
740 739 gr = match.groups()
741 740 (old_line, old_end,
742 741 new_line, new_end) = [int(x or 1) for x in gr[:-1]]
743 742
744 743 lines = []
745 744 hunk = {
746 745 'section_header': gr[-1],
747 746 'source_start': old_line,
748 747 'source_length': old_end,
749 748 'target_start': new_line,
750 749 'target_length': new_end,
751 750 'lines': lines,
752 751 }
753 752 chunks.append(hunk)
754 753
755 754 old_line -= 1
756 755 new_line -= 1
757 756
758 757 context = len(gr) == 5
759 758 old_end += old_line
760 759 new_end += new_line
761 760
762 761 line = next(diff_iter)
763 762
764 763 while old_line < old_end or new_line < new_end:
765 764 command = ' '
766 765 if line:
767 766 command = line[0]
768 767
769 768 affects_old = affects_new = False
770 769
771 770 # ignore those if we don't expect them
772 771 if command in '#@':
773 772 continue
774 773 elif command == '+':
775 774 affects_new = True
776 775 action = Action.ADD
777 776 stats[0] += 1
778 777 elif command == '-':
779 778 affects_old = True
780 779 action = Action.DELETE
781 780 stats[1] += 1
782 781 else:
783 782 affects_old = affects_new = True
784 783 action = Action.UNMODIFIED
785 784
786 785 if not self._newline_marker.match(line):
787 786 old_line += affects_old
788 787 new_line += affects_new
789 788 lines.append({
790 789 'old_lineno': affects_old and old_line or '',
791 790 'new_lineno': affects_new and new_line or '',
792 791 'action': action,
793 792 'line': self._clean_line(line, command)
794 793 })
795 794 raw_diff.append(line)
796 795
797 796 line = next(diff_iter)
798 797
799 798 if self._newline_marker.match(line):
800 799 # we need to append to lines, since this is not
801 800 # counted in the line specs of diff
802 801 if affects_old:
803 802 action = Action.OLD_NO_NL
804 803 elif affects_new:
805 804 action = Action.NEW_NO_NL
806 805 else:
807 806 raise Exception('invalid context for no newline')
808 807
809 808 lines.append({
810 809 'old_lineno': None,
811 810 'new_lineno': None,
812 811 'action': action,
813 812 'line': self._clean_line(line, command)
814 813 })
815 814
816 815 except StopIteration:
817 816 pass
818 817
819 818 return ''.join(raw_diff), chunks, stats
820 819
821 820 def _safe_id(self, idstring):
822 821 """Make a string safe for including in an id attribute.
823 822
824 823 The HTML spec says that id attributes 'must begin with
825 824 a letter ([A-Za-z]) and may be followed by any number
826 825 of letters, digits ([0-9]), hyphens ("-"), underscores
827 826 ("_"), colons (":"), and periods (".")'. These regexps
828 827 are slightly over-zealous, in that they remove colons
829 828 and periods unnecessarily.
830 829
831 830 Whitespace is transformed into underscores, and then
832 831 anything which is not a hyphen or a character that
833 832 matches \w (alphanumerics and underscore) is removed.
834 833
835 834 """
836 835 # Transform all whitespace to underscore
837 836 idstring = re.sub(r'\s', "_", '%s' % idstring)
838 837 # Remove everything that is not a hyphen or a member of \w
839 838 idstring = re.sub(r'(?!-)\W', "", idstring).lower()
840 839 return idstring
841 840
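For illustration, values computed from the two substitutions above (the file names are made up):

    # self._safe_id('My File (v2).txt')   -> 'my_file_v2txt'
    # self._safe_id('docs/readme-old.md') -> 'docsreadme-oldmd'
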
842 841 @classmethod
843 842 def diff_splitter(cls, string):
844 843 """
845 844 Diff split that emulates .splitlines() but works only on \n
846 845 """
847 846 if not string:
848 847 return
849 848 elif string == '\n':
850 yield u'\n'
849 yield '\n'
851 850 else:
852 851
853 852 has_newline = string.endswith('\n')
854 853 elements = string.split('\n')
855 854 if has_newline:
857 856 # skip the last element as it's an empty string left by the trailing newline
857 856 elements = elements[:-1]
858 857
859 858 len_elements = len(elements)
860 859
861 860 for cnt, line in enumerate(elements, start=1):
862 861 last_line = cnt == len_elements
863 862 if last_line and not has_newline:
864 863 yield safe_unicode(line)
865 864 else:
866 865 yield safe_unicode(line) + '\n'
867 866
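A sketch of the splitting behaviour (contrast with str.splitlines(), which also splits on '\r' and drops line endings):

    list(DiffProcessor.diff_splitter('a\nb'))    # ['a\n', 'b']
    list(DiffProcessor.diff_splitter('a\nb\n'))  # ['a\n', 'b\n']
    list(DiffProcessor.diff_splitter(''))        # []
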
868 867 def prepare(self, inline_diff=True):
869 868 """
870 869 Prepare the passed udiff for HTML rendering.
871 870
872 871 :return: A list of dicts with diff information.
873 872 """
874 873 parsed = self._parser(inline_diff=inline_diff)
875 874 self.parsed = True
876 875 self.parsed_diff = parsed
877 876 return parsed
878 877
879 878 def as_raw(self, diff_lines=None):
880 879 """
881 880 Returns raw diff as a byte string
882 881 """
883 882 return self._diff.raw
884 883
885 884 def as_html(self, table_class='code-difftable', line_class='line',
886 885 old_lineno_class='lineno old', new_lineno_class='lineno new',
887 886 code_class='code', enable_comments=False, parsed_lines=None):
888 887 """
889 888 Return given diff as html table with customized css classes
890 889 """
891 890 # TODO(marcink): not sure how to pass in translator
892 891 # here in an efficient way, leave the _ for proper gettext extraction
893 892 _ = lambda s: s
894 893
895 894 def _link_to_if(condition, label, url):
896 895 """
897 896 Generates a link if the condition is met, or just the label if not.
898 897 """
899 898
900 899 if condition:
901 900 return '''<a href="%(url)s" class="tooltip"
902 901 title="%(title)s">%(label)s</a>''' % {
903 902 'title': _('Click to select line'),
904 903 'url': url,
905 904 'label': label
906 905 }
907 906 else:
908 907 return label
909 908 if not self.parsed:
910 909 self.prepare()
911 910
912 911 diff_lines = self.parsed_diff
913 912 if parsed_lines:
914 913 diff_lines = parsed_lines
915 914
916 915 _html_empty = True
917 916 _html = []
918 917 _html.append('''<table class="%(table_class)s">\n''' % {
919 918 'table_class': table_class
920 919 })
921 920
922 921 for diff in diff_lines:
923 922 for line in diff['chunks']:
924 923 _html_empty = False
925 924 for change in line:
926 925 _html.append('''<tr class="%(lc)s %(action)s">\n''' % {
927 926 'lc': line_class,
928 927 'action': change['action']
929 928 })
930 929 anchor_old_id = ''
931 930 anchor_new_id = ''
932 931 anchor_old = "%(filename)s_o%(oldline_no)s" % {
933 932 'filename': self._safe_id(diff['filename']),
934 933 'oldline_no': change['old_lineno']
935 934 }
936 935 anchor_new = "%(filename)s_n%(newline_no)s" % {
937 936 'filename': self._safe_id(diff['filename']),
938 937 'newline_no': change['new_lineno']
939 938 }
940 939 cond_old = (change['old_lineno'] != '...' and
941 940 change['old_lineno'])
942 941 cond_new = (change['new_lineno'] != '...' and
943 942 change['new_lineno'])
944 943 if cond_old:
945 944 anchor_old_id = 'id="%s"' % anchor_old
946 945 if cond_new:
947 946 anchor_new_id = 'id="%s"' % anchor_new
948 947
949 948 if change['action'] != Action.CONTEXT:
950 949 anchor_link = True
951 950 else:
952 951 anchor_link = False
953 952
954 953 ###########################################################
955 954 # COMMENT ICONS
956 955 ###########################################################
957 956 _html.append('''\t<td class="add-comment-line"><span class="add-comment-content">''')
958 957
959 958 if enable_comments and change['action'] != Action.CONTEXT:
960 959 _html.append('''<a href="#"><span class="icon-comment-add"></span></a>''')
961 960
962 961 _html.append('''</span></td><td class="comment-toggle tooltip" title="Toggle Comment Thread"><i class="icon-comment"></i></td>\n''')
963 962
964 963 ###########################################################
965 964 # OLD LINE NUMBER
966 965 ###########################################################
967 966 _html.append('''\t<td %(a_id)s class="%(olc)s">''' % {
968 967 'a_id': anchor_old_id,
969 968 'olc': old_lineno_class
970 969 })
971 970
972 971 _html.append('''%(link)s''' % {
973 972 'link': _link_to_if(anchor_link, change['old_lineno'],
974 973 '#%s' % anchor_old)
975 974 })
976 975 _html.append('''</td>\n''')
977 976 ###########################################################
978 977 # NEW LINE NUMBER
979 978 ###########################################################
980 979
981 980 _html.append('''\t<td %(a_id)s class="%(nlc)s">''' % {
982 981 'a_id': anchor_new_id,
983 982 'nlc': new_lineno_class
984 983 })
985 984
986 985 _html.append('''%(link)s''' % {
987 986 'link': _link_to_if(anchor_link, change['new_lineno'],
988 987 '#%s' % anchor_new)
989 988 })
990 989 _html.append('''</td>\n''')
991 990 ###########################################################
992 991 # CODE
993 992 ###########################################################
994 993 code_classes = [code_class]
995 994 if (not enable_comments or
996 995 change['action'] == Action.CONTEXT):
997 996 code_classes.append('no-comment')
998 997 _html.append('\t<td class="%s">' % ' '.join(code_classes))
999 998 _html.append('''\n\t\t<pre>%(code)s</pre>\n''' % {
1000 999 'code': change['line']
1001 1000 })
1002 1001
1003 1002 _html.append('''\t</td>''')
1004 1003 _html.append('''\n</tr>\n''')
1005 1004 _html.append('''</table>''')
1006 1005 if _html_empty:
1007 1006 return None
1008 1007 return ''.join(_html)
1009 1008
1010 1009 def stat(self):
1011 1010 """
1012 1011 Returns tuple of added, and removed lines for this instance
1013 1012 """
1014 1013 return self.adds, self.removes
1015 1014
1016 1015 def get_context_of_line(
1017 1016 self, path, diff_line=None, context_before=3, context_after=3):
1018 1017 """
1019 1018 Returns the context lines for the specified diff line.
1020 1019
1021 1020 :type diff_line: :class:`DiffLineNumber`
1022 1021 """
1023 1022 assert self.parsed, "DiffProcessor is not initialized."
1024 1023
1025 1024 if None not in diff_line:
1026 1025 raise ValueError(
1027 1026 "Cannot specify both line numbers: {}".format(diff_line))
1028 1027
1029 1028 file_diff = self._get_file_diff(path)
1030 1029 chunk, idx = self._find_chunk_line_index(file_diff, diff_line)
1031 1030
1032 1031 first_line_to_include = max(idx - context_before, 0)
1033 1032 first_line_after_context = idx + context_after + 1
1034 1033 context_lines = chunk[first_line_to_include:first_line_after_context]
1035 1034
1036 1035 line_contents = [
1037 1036 _context_line(line) for line in context_lines
1038 1037 if _is_diff_content(line)]
1039 1038 # TODO: johbo: Interim fixup, the diff chunks drop the final newline.
1040 1039 # Once they are fixed, we can drop this line here.
1041 1040 if line_contents:
1042 1041 line_contents[-1] = (
1043 1042 line_contents[-1][0], line_contents[-1][1].rstrip('\n') + '\n')
1044 1043 return line_contents
1045 1044
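A hedged usage sketch (`processor` is a prepared DiffProcessor; the path and line number are hypothetical, and exactly one side of the DiffLineNumber must be None):

    processor.prepare()
    context = processor.get_context_of_line(
        'setup.py', DiffLineNumber(old=10, new=None),
        context_before=3, context_after=3)
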
1046 1045 def find_context(self, path, context, offset=0):
1047 1046 """
1048 1047 Finds the given `context` inside of the diff.
1049 1048
1050 1049 Use the parameter `offset` to specify which offset the target line has
1051 1050 inside of the given `context`. This way the correct diff line will be
1052 1051 returned.
1053 1052
1054 1053 :param offset: Shall be used to specify the offset of the main line
1055 1054 within the given `context`.
1056 1055 """
1057 1056 if offset < 0 or offset >= len(context):
1058 1057 raise ValueError(
1059 1058 "Only positive values up to the length of the context "
1060 1059 "minus one are allowed.")
1061 1060
1062 1061 matches = []
1063 1062 file_diff = self._get_file_diff(path)
1064 1063
1065 1064 for chunk in file_diff['chunks']:
1066 1065 context_iter = iter(context)
1067 1066 for line_idx, line in enumerate(chunk):
1068 1067 try:
1069 1068 if _context_line(line) == next(context_iter):
1070 1069 continue
1071 1070 except StopIteration:
1072 1071 matches.append((line_idx, chunk))
1073 1072 context_iter = iter(context)
1074 1073
1075 1074 # Increment position and trigger StopIteration
1076 1075 # if we had a match at the end
1077 1076 line_idx += 1
1078 1077 try:
1079 1078 next(context_iter)
1080 1079 except StopIteration:
1081 1080 matches.append((line_idx, chunk))
1082 1081
1083 1082 effective_offset = len(context) - offset
1084 1083 found_at_diff_lines = [
1085 1084 _line_to_diff_line_number(chunk[idx - effective_offset])
1086 1085 for idx, chunk in matches]
1087 1086
1088 1087 return found_at_diff_lines
1089 1088
1090 1089 def _get_file_diff(self, path):
1091 1090 for file_diff in self.parsed_diff:
1092 1091 if file_diff['filename'] == path:
1093 1092 break
1094 1093 else:
1095 1094 raise FileNotInDiffException("File {} not in diff".format(path))
1096 1095 return file_diff
1097 1096
1098 1097 def _find_chunk_line_index(self, file_diff, diff_line):
1099 1098 for chunk in file_diff['chunks']:
1100 1099 for idx, line in enumerate(chunk):
1101 1100 if line['old_lineno'] == diff_line.old:
1102 1101 return chunk, idx
1103 1102 if line['new_lineno'] == diff_line.new:
1104 1103 return chunk, idx
1105 1104 raise LineNotInDiffException(
1106 1105 "The line {} is not part of the diff.".format(diff_line))
1107 1106
1108 1107
1109 1108 def _is_diff_content(line):
1110 1109 return line['action'] in (
1111 1110 Action.UNMODIFIED, Action.ADD, Action.DELETE)
1112 1111
1113 1112
1114 1113 def _context_line(line):
1115 1114 return (line['action'], line['line'])
1116 1115
1117 1116
1118 1117 DiffLineNumber = collections.namedtuple('DiffLineNumber', ['old', 'new'])
1119 1118
1120 1119
1121 1120 def _line_to_diff_line_number(line):
1122 1121 new_line_no = line['new_lineno'] or None
1123 1122 old_line_no = line['old_lineno'] or None
1124 1123 return DiffLineNumber(old=old_line_no, new=new_line_no)
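
For example, empty line numbers normalize to None:

    _line_to_diff_line_number({'old_lineno': '', 'new_lineno': 42})
    # -> DiffLineNumber(old=None, new=42)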
1125 1124
1126 1125
1127 1126 class FileNotInDiffException(Exception):
1128 1127 """
1129 1128 Raised when the context for a missing file is requested.
1130 1129
1131 1130 If you request the context for a line in a file which is not part of the
1132 1131 given diff, then this exception is raised.
1133 1132 """
1134 1133
1135 1134
1136 1135 class LineNotInDiffException(Exception):
1137 1136 """
1138 1137 Raised when the context for a missing line is requested.
1139 1138
1140 1139 If you request the context for a line in a file and this line is not
1141 1140 part of the given diff, then this exception is raised.
1142 1141 """
1143 1142
1144 1143
1145 1144 class DiffLimitExceeded(Exception):
1146 1145 pass
1147 1146
1148 1147
1149 1148 # NOTE(marcink): if diffs.mako changes, this probably
1150 1149 # needs a bump to the next version
1151 1150 CURRENT_DIFF_VERSION = 'v5'
1152 1151
1153 1152
1154 1153 def _cleanup_cache_file(cached_diff_file):
1155 1154 # clean up the file so we don't keep a "damaged" cache entry
1156 1155 try:
1157 1156 os.remove(cached_diff_file)
1158 1157 except Exception:
1159 1158 log.exception('Failed to cleanup path %s', cached_diff_file)
1160 1159
1161 1160
1162 1161 def _get_compression_mode(cached_diff_file):
1163 1162 mode = 'bz2'
1164 1163 if 'mode:plain' in cached_diff_file:
1165 1164 mode = 'plain'
1166 1165 elif 'mode:gzip' in cached_diff_file:
1167 1166 mode = 'gzip'
1168 1167 return mode
1169 1168
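The compression mode is encoded in the cache file name itself; for illustration (the file names are made up):

    _get_compression_mode('diff_v5_abc_mode:gzip')   # -> 'gzip'
    _get_compression_mode('diff_v5_abc_mode:plain')  # -> 'plain'
    _get_compression_mode('diff_v5_abc')             # -> 'bz2' (default)
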
1170 1169
1171 1170 def cache_diff(cached_diff_file, diff, commits):
1172 1171 compression_mode = _get_compression_mode(cached_diff_file)
1173 1172
1174 1173 struct = {
1175 1174 'version': CURRENT_DIFF_VERSION,
1176 1175 'diff': diff,
1177 1176 'commits': commits
1178 1177 }
1179 1178
1180 1179 start = time.time()
1181 1180 try:
1182 1181 if compression_mode == 'plain':
1183 1182 with open(cached_diff_file, 'wb') as f:
1184 1183 pickle.dump(struct, f)
1185 1184 elif compression_mode == 'gzip':
1186 1185 with gzip.GzipFile(cached_diff_file, 'wb') as f:
1187 1186 pickle.dump(struct, f)
1188 1187 else:
1189 1188 with bz2.BZ2File(cached_diff_file, 'wb') as f:
1190 1189 pickle.dump(struct, f)
1191 1190 except Exception:
1192 1191 log.warning('Failed to save cache', exc_info=True)
1193 1192 _cleanup_cache_file(cached_diff_file)
1194 1193
1195 1194 log.debug('Saved diff cache under %s in %.4fs', cached_diff_file, time.time() - start)
1196 1195
1197 1196
1198 1197 def load_cached_diff(cached_diff_file):
1199 1198 compression_mode = _get_compression_mode(cached_diff_file)
1200 1199
1201 1200 default_struct = {
1202 1201 'version': CURRENT_DIFF_VERSION,
1203 1202 'diff': None,
1204 1203 'commits': None
1205 1204 }
1206 1205
1207 1206 has_cache = os.path.isfile(cached_diff_file)
1208 1207 if not has_cache:
1209 1208 log.debug('Diff cache file %s does not exist', cached_diff_file)
1210 1209 return default_struct
1211 1210
1212 1211 data = None
1213 1212
1214 1213 start = time.time()
1215 1214 try:
1216 1215 if compression_mode == 'plain':
1217 1216 with open(cached_diff_file, 'rb') as f:
1218 1217 data = pickle.load(f)
1219 1218 elif compression_mode == 'gzip':
1220 1219 with gzip.GzipFile(cached_diff_file, 'rb') as f:
1221 1220 data = pickle.load(f)
1222 1221 else:
1223 1222 with bz2.BZ2File(cached_diff_file, 'rb') as f:
1224 1223 data = pickle.load(f)
1225 1224 except Exception:
1226 1225 log.warning('Failed to read diff cache file', exc_info=True)
1227 1226
1228 1227 if not data:
1229 1228 data = default_struct
1230 1229
1231 1230 if not isinstance(data, dict):
1232 1231 # old version of data ?
1233 1232 data = default_struct
1234 1233
1235 1234 # check version
1236 1235 if data.get('version') != CURRENT_DIFF_VERSION:
1237 1236 # purge cache
1238 1237 _cleanup_cache_file(cached_diff_file)
1239 1238 return default_struct
1240 1239
1241 1240 log.debug('Loaded diff cache from %s in %.4fs', cached_diff_file, time.time() - start)
1242 1241
1243 1242 return data
1244 1243
1245 1244
1246 1245 def generate_diff_cache_key(*args):
1247 1246 """
1248 1247 Helper to generate a cache key using arguments
1249 1248 """
1250 1249 def arg_mapper(input_param):
1251 1250 input_param = safe_str(input_param)
1252 1251 # we cannot allow '/' in arguments since it would allow
1253 1252 # subdirectory usage
1254 1253 input_param = input_param.replace('/', '_')
1255 1254 return input_param or None # prevent empty string arguments
1256 1255
1257 1256 return '_'.join([
1258 1257 '{}' for i in range(len(args))]).format(*map(arg_mapper, args))
1259 1258
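For example (the argument values are made up):

    generate_diff_cache_key('repo', 'abc123', 'def456', 'mode:gzip')
    # -> 'repo_abc123_def456_mode:gzip'
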
1260 1259
1261 1260 def diff_cache_exist(cache_storage, *args):
1262 1261 """
1263 1262 Based on the generated arguments, check and return a cache file path
1264 1263 """
1265 1264 args = list(args) + ['mode:gzip']
1266 1265 cache_key = generate_diff_cache_key(*args)
1267 1266 cache_file_path = os.path.join(cache_storage, cache_key)
1268 1267 # prevent path traversal attacks via params that contain e.g. '../../'
1269 1268 if not os.path.abspath(cache_file_path).startswith(cache_storage):
1270 1269 raise ValueError('Final path must be within {}'.format(cache_storage))
1271 1270
1272 1271 return cache_file_path
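
A hedged sketch of the full cache round-trip built from the helpers above (the storage path and arguments are hypothetical, and `parsed_diff` is assumed to be an already-parsed diff structure):

    cache_file = diff_cache_exist('/var/cache/diffs', 'repo', 'abc123', 'def456')
    cache_diff(cache_file, diff=parsed_diff, commits=None)
    cached = load_cached_diff(cache_file)  # {'version': ..., 'diff': ..., 'commits': ...}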
@@ -1,444 +1,444 b''
1 1 # Copyright (c) Django Software Foundation and individual contributors.
2 2 # All rights reserved.
3 3 #
4 4 # Redistribution and use in source and binary forms, with or without modification,
5 5 # are permitted provided that the following conditions are met:
6 6 #
7 7 # 1. Redistributions of source code must retain the above copyright notice,
8 8 # this list of conditions and the following disclaimer.
9 9 #
10 10 # 2. Redistributions in binary form must reproduce the above copyright
11 11 # notice, this list of conditions and the following disclaimer in the
12 12 # documentation and/or other materials provided with the distribution.
13 13 #
14 14 # 3. Neither the name of Django nor the names of its contributors may be used
15 15 # to endorse or promote products derived from this software without
16 16 # specific prior written permission.
17 17 #
18 18 # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
19 19 # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
20 20 # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21 21 # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
22 22 # ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
23 23 # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
24 24 # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
25 25 # ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
26 26 # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
27 27 # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 28
29 29 """
30 30 For definitions of the different versions of RSS, see:
31 31 http://web.archive.org/web/20110718035220/http://diveintomark.org/archives/2004/02/04/incompatible-rss
32 32 """
33 33
34 34
35 35 import datetime
36 from io import StringIO
36 import io
37 37
38 38 import pytz
39 39 from six.moves.urllib import parse as urlparse
40 40
41 41 from rhodecode.lib.feedgenerator import datetime_safe
42 42 from rhodecode.lib.feedgenerator.utils import SimplerXMLGenerator, iri_to_uri, force_text
43 43
44 44
45 45 #### The following code comes from ``django.utils.feedgenerator`` ####
46 46
47 47
48 48 def rfc2822_date(date):
49 49 # We can't use strftime() because it produces locale-dependent results, so
50 50 # we have to map English month and day names manually
51 51 months = ('Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun', 'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec',)
52 52 days = ('Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun')
53 53 # Support datetime objects older than 1900
54 54 date = datetime_safe.new_datetime(date)
55 55 # We do this ourselves to be timezone aware, email.Utils is not tz aware.
56 56 dow = days[date.weekday()]
57 57 month = months[date.month - 1]
58 58 time_str = date.strftime('%s, %%d %s %%Y %%H:%%M:%%S ' % (dow, month))
59 59
60 60 offset = date.utcoffset()
61 61 # Historically, this function assumes that naive datetimes are in UTC.
62 62 if offset is None:
63 63 return time_str + '-0000'
64 64 else:
65 65 timezone = (offset.days * 24 * 60) + (offset.seconds // 60)
66 66 hour, minute = divmod(timezone, 60)
67 67 return time_str + '%+03d%02d' % (hour, minute)
68 68
69 69
70 70 def rfc3339_date(date):
71 71 # Support datetime objects older than 1900
72 72 date = datetime_safe.new_datetime(date)
73 73 time_str = date.strftime('%Y-%m-%dT%H:%M:%S')
74 74
75 75 offset = date.utcoffset()
76 76 # Historically, this function assumes that naive datetimes are in UTC.
77 77 if offset is None:
78 78 return time_str + 'Z'
79 79 else:
80 80 timezone = (offset.days * 24 * 60) + (offset.seconds // 60)
81 81 hour, minute = divmod(timezone, 60)
82 82 return time_str + '%+03d:%02d' % (hour, minute)
83 83
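For example, a naive datetime is rendered as UTC (the timestamp is made up; `datetime` is already imported at module level):

    rfc3339_date(datetime.datetime(2020, 1, 2, 3, 4, 5))
    # -> '2020-01-02T03:04:05Z'
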
84 84
85 85 def get_tag_uri(url, date):
86 86 """
87 87 Creates a TagURI.
88 88
89 89 See http://web.archive.org/web/20110514113830/http://diveintomark.org/archives/2004/05/28/howto-atom-id
90 90 """
91 91 bits = urlparse.urlparse(url)
92 92 d = ''
93 93 if date is not None:
94 94 d = ',%s' % datetime_safe.new_datetime(date).strftime('%Y-%m-%d')
95 95 return 'tag:%s%s:%s/%s' % (bits.hostname, d, bits.path, bits.fragment)
96 96
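For example (the URL and date are made up):

    get_tag_uri('http://example.com/feed/1#frag', datetime.datetime(2020, 1, 2))
    # -> 'tag:example.com,2020-01-02:/feed/1/frag'
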
97 97
98 98 class SyndicationFeed(object):
99 99 """Base class for all syndication feeds. Subclasses should provide write()"""
100 100
101 101 def __init__(self, title, link, description, language=None, author_email=None,
102 102 author_name=None, author_link=None, subtitle=None, categories=None,
103 103 feed_url=None, feed_copyright=None, feed_guid=None, ttl=None, **kwargs):
104 104 def to_unicode(s):
105 105 return force_text(s, strings_only=True)
106 106 if categories:
107 107 categories = [force_text(c) for c in categories]
108 108 if ttl is not None:
109 109 # Force ints to text
110 110 ttl = force_text(ttl)
111 111 self.feed = {
112 112 'title': to_unicode(title),
113 113 'link': iri_to_uri(link),
114 114 'description': to_unicode(description),
115 115 'language': to_unicode(language),
116 116 'author_email': to_unicode(author_email),
117 117 'author_name': to_unicode(author_name),
118 118 'author_link': iri_to_uri(author_link),
119 119 'subtitle': to_unicode(subtitle),
120 120 'categories': categories or (),
121 121 'feed_url': iri_to_uri(feed_url),
122 122 'feed_copyright': to_unicode(feed_copyright),
123 123 'id': feed_guid or link,
124 124 'ttl': ttl,
125 125 }
126 126 self.feed.update(kwargs)
127 127 self.items = []
128 128
129 129 def add_item(self, title, link, description, author_email=None,
130 130 author_name=None, author_link=None, pubdate=None, comments=None,
131 131 unique_id=None, unique_id_is_permalink=None, enclosure=None,
132 132 categories=(), item_copyright=None, ttl=None, updateddate=None,
133 133 enclosures=None, **kwargs):
134 134 """
135 135 Adds an item to the feed. All args are expected to be Python Unicode
136 136 objects except pubdate and updateddate, which are datetime.datetime
137 137 objects, and enclosures, which is an iterable of instances of the
138 138 Enclosure class.
139 139 """
140 140 def to_unicode(s):
141 141 return force_text(s, strings_only=True)
142 142 if categories:
143 143 categories = [to_unicode(c) for c in categories]
144 144 if ttl is not None:
145 145 # Force ints to text
146 146 ttl = force_text(ttl)
147 147 if enclosure is None:
148 148 enclosures = [] if enclosures is None else enclosures
149 149
150 150 item = {
151 151 'title': to_unicode(title),
152 152 'link': iri_to_uri(link),
153 153 'description': to_unicode(description),
154 154 'author_email': to_unicode(author_email),
155 155 'author_name': to_unicode(author_name),
156 156 'author_link': iri_to_uri(author_link),
157 157 'pubdate': pubdate,
158 158 'updateddate': updateddate,
159 159 'comments': to_unicode(comments),
160 160 'unique_id': to_unicode(unique_id),
161 161 'unique_id_is_permalink': unique_id_is_permalink,
162 162 'enclosures': enclosures,
163 163 'categories': categories or (),
164 164 'item_copyright': to_unicode(item_copyright),
165 165 'ttl': ttl,
166 166 }
167 167 item.update(kwargs)
168 168 self.items.append(item)
169 169
170 170 def num_items(self):
171 171 return len(self.items)
172 172
173 173 def root_attributes(self):
174 174 """
175 175 Return extra attributes to place on the root (i.e. feed/channel) element.
176 176 Called from write().
177 177 """
178 178 return {}
179 179
180 180 def add_root_elements(self, handler):
181 181 """
182 182 Add elements in the root (i.e. feed/channel) element. Called
183 183 from write().
184 184 """
185 185 pass
186 186
187 187 def item_attributes(self, item):
188 188 """
189 189 Return extra attributes to place on each item (i.e. item/entry) element.
190 190 """
191 191 return {}
192 192
193 193 def add_item_elements(self, handler, item):
194 194 """
195 195 Add elements on each item (i.e. item/entry) element.
196 196 """
197 197 pass
198 198
199 199 def write(self, outfile, encoding):
200 200 """
201 201 Outputs the feed in the given encoding to outfile, which is a file-like
202 202 object. Subclasses should override this.
203 203 """
204 204 raise NotImplementedError('subclasses of SyndicationFeed must provide a write() method')
205 205
206 206 def writeString(self, encoding):
207 207 """
208 208 Returns the feed in the given encoding as a string.
209 209 """
210 s = StringIO()
210 s = io.StringIO()
211 211 self.write(s, encoding)
212 212 return s.getvalue()
213 213
214 214 def latest_post_date(self):
215 215 """
216 216 Returns the latest item's pubdate or updateddate. If no items
217 217 have either of these attributes this returns the current UTC date/time.
218 218 """
219 219 latest_date = None
220 220 date_keys = ('updateddate', 'pubdate')
221 221
222 222 for item in self.items:
223 223 for date_key in date_keys:
224 224 item_date = item.get(date_key)
225 225 if item_date:
226 226 if latest_date is None or item_date > latest_date:
227 227 latest_date = item_date
228 228
229 229 # datetime.now(tz=utc) is slower, as documented in django.utils.timezone.now
230 230 return latest_date or datetime.datetime.utcnow().replace(tzinfo=pytz.utc)
231 231
232 232
233 233 class Enclosure(object):
234 234 """Represents an RSS enclosure"""
235 235 def __init__(self, url, length, mime_type):
236 236 """All args are expected to be Python Unicode objects"""
237 237 self.length, self.mime_type = length, mime_type
238 238 self.url = iri_to_uri(url)
239 239
240 240
241 241 class RssFeed(SyndicationFeed):
242 242 content_type = 'application/rss+xml; charset=utf-8'
243 243
244 244 def write(self, outfile, encoding):
245 245 handler = SimplerXMLGenerator(outfile, encoding)
246 246 handler.startDocument()
247 247 handler.startElement("rss", self.rss_attributes())
248 248 handler.startElement("channel", self.root_attributes())
249 249 self.add_root_elements(handler)
250 250 self.write_items(handler)
251 251 self.endChannelElement(handler)
252 252 handler.endElement("rss")
253 253
254 254 def rss_attributes(self):
255 255 return {"version": self._version,
256 256 "xmlns:atom": "http://www.w3.org/2005/Atom"}
257 257
258 258 def write_items(self, handler):
259 259 for item in self.items:
260 260 handler.startElement('item', self.item_attributes(item))
261 261 self.add_item_elements(handler, item)
262 262 handler.endElement("item")
263 263
264 264 def add_root_elements(self, handler):
265 265 handler.addQuickElement("title", self.feed['title'])
266 266 handler.addQuickElement("link", self.feed['link'])
267 267 handler.addQuickElement("description", self.feed['description'])
268 268 if self.feed['feed_url'] is not None:
269 269 handler.addQuickElement("atom:link", None, {"rel": "self", "href": self.feed['feed_url']})
270 270 if self.feed['language'] is not None:
271 271 handler.addQuickElement("language", self.feed['language'])
272 272 for cat in self.feed['categories']:
273 273 handler.addQuickElement("category", cat)
274 274 if self.feed['feed_copyright'] is not None:
275 275 handler.addQuickElement("copyright", self.feed['feed_copyright'])
276 276 handler.addQuickElement("lastBuildDate", rfc2822_date(self.latest_post_date()))
277 277 if self.feed['ttl'] is not None:
278 278 handler.addQuickElement("ttl", self.feed['ttl'])
279 279
280 280 def endChannelElement(self, handler):
281 281 handler.endElement("channel")
282 282
283 283
284 284 class RssUserland091Feed(RssFeed):
285 285 _version = "0.91"
286 286
287 287 def add_item_elements(self, handler, item):
288 288 handler.addQuickElement("title", item['title'])
289 289 handler.addQuickElement("link", item['link'])
290 290 if item['description'] is not None:
291 291 handler.addQuickElement("description", item['description'])
292 292
293 293
294 294 class Rss201rev2Feed(RssFeed):
295 295 # Spec: http://blogs.law.harvard.edu/tech/rss
296 296 _version = "2.0"
297 297
298 298 def add_item_elements(self, handler, item):
299 299 handler.addQuickElement("title", item['title'])
300 300 handler.addQuickElement("link", item['link'])
301 301 if item['description'] is not None:
302 302 handler.addQuickElement("description", item['description'])
303 303
304 304 # Author information.
305 305 if item["author_name"] and item["author_email"]:
306 306 handler.addQuickElement("author", "%s (%s)" % (item['author_email'], item['author_name']))
307 307 elif item["author_email"]:
308 308 handler.addQuickElement("author", item["author_email"])
309 309 elif item["author_name"]:
310 310 handler.addQuickElement(
311 311 "dc:creator", item["author_name"], {"xmlns:dc": "http://purl.org/dc/elements/1.1/"}
312 312 )
313 313
314 314 if item['pubdate'] is not None:
315 315 handler.addQuickElement("pubDate", rfc2822_date(item['pubdate']))
316 316 if item['comments'] is not None:
317 317 handler.addQuickElement("comments", item['comments'])
318 318 if item['unique_id'] is not None:
319 319 guid_attrs = {}
320 320 if isinstance(item.get('unique_id_is_permalink'), bool):
321 321 guid_attrs['isPermaLink'] = str(item['unique_id_is_permalink']).lower()
322 322 handler.addQuickElement("guid", item['unique_id'], guid_attrs)
323 323 if item['ttl'] is not None:
324 324 handler.addQuickElement("ttl", item['ttl'])
325 325
326 326 # Enclosure.
327 327 if item['enclosures']:
328 328 enclosures = list(item['enclosures'])
329 329 if len(enclosures) > 1:
330 330 raise ValueError(
331 331 "RSS feed items may only have one enclosure, see "
332 332 "http://www.rssboard.org/rss-profile#element-channel-item-enclosure"
333 333 )
334 334 enclosure = enclosures[0]
335 335 handler.addQuickElement('enclosure', '', {
336 336 'url': enclosure.url,
337 337 'length': enclosure.length,
338 338 'type': enclosure.mime_type,
339 339 })
340 340
341 341 # Categories.
342 342 for cat in item['categories']:
343 343 handler.addQuickElement("category", cat)
344 344
345 345
346 346 class Atom1Feed(SyndicationFeed):
347 347 # Spec: https://tools.ietf.org/html/rfc4287
348 348 content_type = 'application/atom+xml; charset=utf-8'
349 349 ns = "http://www.w3.org/2005/Atom"
350 350
351 351 def write(self, outfile, encoding):
352 352 handler = SimplerXMLGenerator(outfile, encoding)
353 353 handler.startDocument()
354 354 handler.startElement('feed', self.root_attributes())
355 355 self.add_root_elements(handler)
356 356 self.write_items(handler)
357 357 handler.endElement("feed")
358 358
359 359 def root_attributes(self):
360 360 if self.feed['language'] is not None:
361 361 return {"xmlns": self.ns, "xml:lang": self.feed['language']}
362 362 else:
363 363 return {"xmlns": self.ns}
364 364
365 365 def add_root_elements(self, handler):
366 366 handler.addQuickElement("title", self.feed['title'])
367 367 handler.addQuickElement("link", "", {"rel": "alternate", "href": self.feed['link']})
368 368 if self.feed['feed_url'] is not None:
369 369 handler.addQuickElement("link", "", {"rel": "self", "href": self.feed['feed_url']})
370 370 handler.addQuickElement("id", self.feed['id'])
371 371 handler.addQuickElement("updated", rfc3339_date(self.latest_post_date()))
372 372 if self.feed['author_name'] is not None:
373 373 handler.startElement("author", {})
374 374 handler.addQuickElement("name", self.feed['author_name'])
375 375 if self.feed['author_email'] is not None:
376 376 handler.addQuickElement("email", self.feed['author_email'])
377 377 if self.feed['author_link'] is not None:
378 378 handler.addQuickElement("uri", self.feed['author_link'])
379 379 handler.endElement("author")
380 380 if self.feed['subtitle'] is not None:
381 381 handler.addQuickElement("subtitle", self.feed['subtitle'])
382 382 for cat in self.feed['categories']:
383 383 handler.addQuickElement("category", "", {"term": cat})
384 384 if self.feed['feed_copyright'] is not None:
385 385 handler.addQuickElement("rights", self.feed['feed_copyright'])
386 386
387 387 def write_items(self, handler):
388 388 for item in self.items:
389 389 handler.startElement("entry", self.item_attributes(item))
390 390 self.add_item_elements(handler, item)
391 391 handler.endElement("entry")
392 392
393 393 def add_item_elements(self, handler, item):
394 394 handler.addQuickElement("title", item['title'])
395 395 handler.addQuickElement("link", "", {"href": item['link'], "rel": "alternate"})
396 396
397 397 if item['pubdate'] is not None:
398 398 handler.addQuickElement('published', rfc3339_date(item['pubdate']))
399 399
400 400 if item['updateddate'] is not None:
401 401 handler.addQuickElement('updated', rfc3339_date(item['updateddate']))
402 402
403 403 # Author information.
404 404 if item['author_name'] is not None:
405 405 handler.startElement("author", {})
406 406 handler.addQuickElement("name", item['author_name'])
407 407 if item['author_email'] is not None:
408 408 handler.addQuickElement("email", item['author_email'])
409 409 if item['author_link'] is not None:
410 410 handler.addQuickElement("uri", item['author_link'])
411 411 handler.endElement("author")
412 412
413 413 # Unique ID.
414 414 if item['unique_id'] is not None:
415 415 unique_id = item['unique_id']
416 416 else:
417 417 unique_id = get_tag_uri(item['link'], item['pubdate'])
418 418 handler.addQuickElement("id", unique_id)
419 419
420 420 # Summary.
421 421 if item['description'] is not None:
422 422 handler.addQuickElement("summary", item['description'], {"type": "html"})
423 423
424 424 # Enclosures.
425 425 for enclosure in item['enclosures']:
426 426 handler.addQuickElement('link', '', {
427 427 'rel': 'enclosure',
428 428 'href': enclosure.url,
429 429 'length': enclosure.length,
430 430 'type': enclosure.mime_type,
431 431 })
432 432
433 433 # Categories.
434 434 for cat in item['categories']:
435 435 handler.addQuickElement("category", "", {"term": cat})
436 436
437 437 # Rights.
438 438 if item['item_copyright'] is not None:
439 439 handler.addQuickElement("rights", item['item_copyright'])
440 440
441 441
442 442 # This isolates the decision of what the system default is, so calling code can
443 443 # do "feedgenerator.DefaultFeed" instead of "feedgenerator.Rss201rev2Feed".
444 444 DefaultFeed = Rss201rev2Feed
\ No newline at end of file
@@ -1,538 +1,538 b''
1 1 # -*- coding: utf-8 -*-
2 2
3 3 # Copyright (C) 2013-2020 RhodeCode GmbH
4 4 #
5 5 # This program is free software: you can redistribute it and/or modify
6 6 # it under the terms of the GNU Affero General Public License, version 3
7 7 # (only), as published by the Free Software Foundation.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU Affero General Public License
15 15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 16 #
17 17 # This program is dual-licensed. If you wish to learn more about the
18 18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20 20
21 21
22 22 """
23 23 Set of hooks run by RhodeCode Enterprise
24 24 """
25 25
26 26 import os
27 27 import logging
28 28
29 29 import rhodecode
30 30 from rhodecode import events
31 31 from rhodecode.lib import helpers as h
32 32 from rhodecode.lib import audit_logger
33 33 from rhodecode.lib.utils2 import safe_str, user_agent_normalizer
34 34 from rhodecode.lib.exceptions import (
35 35 HTTPLockedRC, HTTPBranchProtected, UserCreationError)
36 36 from rhodecode.model.db import Repository, User
37 37 from rhodecode.lib.statsd_client import StatsdClient
38 38
39 39 log = logging.getLogger(__name__)
40 40
41 41
42 42 class HookResponse(object):
43 43 def __init__(self, status, output):
44 44 self.status = status
45 45 self.output = output
46 46
47 47 def __add__(self, other):
48 48 other_status = getattr(other, 'status', 0)
49 49 new_status = max(self.status, other_status)
50 50 other_output = getattr(other, 'output', '')
51 51 new_output = self.output + other_output
52 52
53 53 return HookResponse(new_status, new_output)
54 54
55 55 def __bool__(self):
56 56 return self.status == 0
57 57
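
Because `__add__` keeps the maximum status and concatenates outputs, chaining hook results never loses a failure. A small self-contained illustration using the class as defined above:

    # Combining hook results: the worst (highest) status wins and the
    # textual outputs are concatenated in order.
    ok = HookResponse(0, 'lock check passed\n')
    fail = HookResponse(1, 'rcextensions rejected the push\n')
    combined = ok + fail
    assert combined.status == 1
    assert combined.output == 'lock check passed\nrcextensions rejected the push\n'
    assert not combined  # truthy only when status == 0
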
58 58
59 59 def is_shadow_repo(extras):
60 60 """
61 61 Returns ``True`` if this is an action executed against a shadow repository.
62 62 """
63 63 return extras['is_shadow_repo']
64 64
65 65
66 66 def _get_scm_size(alias, root_path):
67 67
68 68 if not alias.startswith('.'):
69 69 alias += '.'
70 70
71 71 size_scm, size_root = 0, 0
72 72 for path, unused_dirs, files in os.walk(safe_str(root_path)):
73 73 if path.find(alias) != -1:
74 74 for f in files:
75 75 try:
76 76 size_scm += os.path.getsize(os.path.join(path, f))
77 77 except OSError:
78 78 pass
79 79 else:
80 80 for f in files:
81 81 try:
82 82 size_root += os.path.getsize(os.path.join(path, f))
83 83 except OSError:
84 84 pass
85 85
86 86 size_scm_f = h.format_byte_size_binary(size_scm)
87 87 size_root_f = h.format_byte_size_binary(size_root)
88 88 size_total_f = h.format_byte_size_binary(size_root + size_scm)
89 89
90 90 return size_scm_f, size_root_f, size_total_f
91 91
92 92
93 93 # actual hooks called by Mercurial internally, and GIT by our Python Hooks
94 94 def repo_size(extras):
95 95 """Present size of repository after push."""
96 96 repo = Repository.get_by_repo_name(extras.repository)
97 vcs_part = safe_str(u'.%s' % repo.repo_type)
97 vcs_part = safe_str('.%s' % repo.repo_type)
98 98 size_vcs, size_root, size_total = _get_scm_size(vcs_part,
99 99 repo.repo_full_path)
100 100 msg = ('Repository `%s` size summary %s:%s repo:%s total:%s\n'
101 101 % (repo.repo_name, vcs_part, size_vcs, size_root, size_total))
102 102 return HookResponse(0, msg)
103 103
104 104
105 105 def pre_push(extras):
106 106 """
107 107 Hook executed before pushing code.
108 108
109 109 It bans pushing when the repository is locked.
110 110 """
111 111
112 112 user = User.get_by_username(extras.username)
113 113 output = ''
114 114 if extras.locked_by[0] and user.user_id != int(extras.locked_by[0]):
115 115 locked_by = User.get(extras.locked_by[0]).username
116 116 reason = extras.locked_by[2]
117 117 # this exception is interpreted in git/hg middlewares and based
118 118 # on that the proper return code is served to the client
119 119 _http_ret = HTTPLockedRC(
120 120 _locked_by_explanation(extras.repository, locked_by, reason))
121 121 if str(_http_ret.code).startswith('2'):
122 122 # 2xx Codes don't raise exceptions
123 123 output = _http_ret.title
124 124 else:
125 125 raise _http_ret
126 126
127 127 hook_response = ''
128 128 if not is_shadow_repo(extras):
129 129 if extras.commit_ids and extras.check_branch_perms:
130 130
131 131 auth_user = user.AuthUser()
132 132 repo = Repository.get_by_repo_name(extras.repository)
133 133 affected_branches = []
134 134 if repo.repo_type == 'hg':
135 135 for entry in extras.commit_ids:
136 136 if entry['type'] == 'branch':
137 137 is_forced = bool(entry['multiple_heads'])
138 138 affected_branches.append([entry['name'], is_forced])
139 139 elif repo.repo_type == 'git':
140 140 for entry in extras.commit_ids:
141 141 if entry['type'] == 'heads':
142 142 is_forced = bool(entry['pruned_sha'])
143 143 affected_branches.append([entry['name'], is_forced])
144 144
145 145 for branch_name, is_forced in affected_branches:
146 146
147 147 rule, branch_perm = auth_user.get_rule_and_branch_permission(
148 148 extras.repository, branch_name)
149 149 if not branch_perm:
150 150 # no branch permission found for this branch, just keep checking
151 151 continue
152 152
153 153 if branch_perm == 'branch.push_force':
154 154 continue
155 155 elif branch_perm == 'branch.push' and is_forced is False:
156 156 continue
157 157 elif branch_perm == 'branch.push' and is_forced is True:
158 158 halt_message = 'Branch `{}` changes rejected by rule {}. ' \
159 159 'FORCE PUSH FORBIDDEN.'.format(branch_name, rule)
160 160 else:
161 161 halt_message = 'Branch `{}` changes rejected by rule {}.'.format(
162 162 branch_name, rule)
163 163
164 164 if halt_message:
165 165 _http_ret = HTTPBranchProtected(halt_message)
166 166 raise _http_ret
167 167
168 168 # Propagate to external components. This is done after checking the
169 169 # lock, for consistent behavior.
170 170 hook_response = pre_push_extension(
171 171 repo_store_path=Repository.base_path(), **extras)
172 172 events.trigger(events.RepoPrePushEvent(
173 173 repo_name=extras.repository, extras=extras))
174 174
175 175 return HookResponse(0, output) + hook_response
176 176
177 177
178 178 def pre_pull(extras):
179 179 """
180 180 Hook executed before pulling the code.
181 181
182 182 It bans pulling when the repository is locked.
183 183 """
184 184
185 185 output = ''
186 186 if extras.locked_by[0]:
187 187 locked_by = User.get(extras.locked_by[0]).username
188 188 reason = extras.locked_by[2]
189 189 # this exception is interpreted in git/hg middlewares and based
190 190 # on that the proper return code is served to the client
191 191 _http_ret = HTTPLockedRC(
192 192 _locked_by_explanation(extras.repository, locked_by, reason))
193 193 if str(_http_ret.code).startswith('2'):
194 194 # 2xx Codes don't raise exceptions
195 195 output = _http_ret.title
196 196 else:
197 197 raise _http_ret
198 198
199 199 # Propagate to external components. This is done after checking the
200 200 # lock, for consistent behavior.
201 201 hook_response = ''
202 202 if not is_shadow_repo(extras):
203 203 extras.hook_type = extras.hook_type or 'pre_pull'
204 204 hook_response = pre_pull_extension(
205 205 repo_store_path=Repository.base_path(), **extras)
206 206 events.trigger(events.RepoPrePullEvent(
207 207 repo_name=extras.repository, extras=extras))
208 208
209 209 return HookResponse(0, output) + hook_response
210 210
211 211
212 212 def post_pull(extras):
213 213 """Hook executed after client pulls the code."""
214 214
215 215 audit_user = audit_logger.UserWrap(
216 216 username=extras.username,
217 217 ip_addr=extras.ip)
218 218 repo = audit_logger.RepoWrap(repo_name=extras.repository)
219 219 audit_logger.store(
220 220 'user.pull', action_data={'user_agent': extras.user_agent},
221 221 user=audit_user, repo=repo, commit=True)
222 222
223 223 statsd = StatsdClient.statsd
224 224 if statsd:
225 225 statsd.incr('rhodecode_pull_total', tags=[
226 226 'user-agent:{}'.format(user_agent_normalizer(extras.user_agent)),
227 227 ])
228 228 output = ''
229 229 # make_lock is a tri-state: False, True, None. We only make a lock on True
230 230 if extras.make_lock is True and not is_shadow_repo(extras):
231 231 user = User.get_by_username(extras.username)
232 232 Repository.lock(Repository.get_by_repo_name(extras.repository),
233 233 user.user_id,
234 234 lock_reason=Repository.LOCK_PULL)
235 235 msg = 'Made lock on repo `%s`' % (extras.repository,)
236 236 output += msg
237 237
238 238 if extras.locked_by[0]:
239 239 locked_by = User.get(extras.locked_by[0]).username
240 240 reason = extras.locked_by[2]
241 241 _http_ret = HTTPLockedRC(
242 242 _locked_by_explanation(extras.repository, locked_by, reason))
243 243 if str(_http_ret.code).startswith('2'):
244 244 # 2xx Codes don't raise exceptions
245 245 output += _http_ret.title
246 246
247 247 # Propagate to external components.
248 248 hook_response = ''
249 249 if not is_shadow_repo(extras):
250 250 extras.hook_type = extras.hook_type or 'post_pull'
251 251 hook_response = post_pull_extension(
252 252 repo_store_path=Repository.base_path(), **extras)
253 253 events.trigger(events.RepoPullEvent(
254 254 repo_name=extras.repository, extras=extras))
255 255
256 256 return HookResponse(0, output) + hook_response
257 257
258 258
259 259 def post_push(extras):
260 260 """Hook executed after user pushes to the repository."""
261 261 commit_ids = extras.commit_ids
262 262
263 263 # log the push call
264 264 audit_user = audit_logger.UserWrap(
265 265 username=extras.username, ip_addr=extras.ip)
266 266 repo = audit_logger.RepoWrap(repo_name=extras.repository)
267 267 audit_logger.store(
268 268 'user.push', action_data={
269 269 'user_agent': extras.user_agent,
270 270 'commit_ids': commit_ids[:400]},
271 271 user=audit_user, repo=repo, commit=True)
272 272
273 273 statsd = StatsdClient.statsd
274 274 if statsd:
275 275 statsd.incr('rhodecode_push_total', tags=[
276 276 'user-agent:{}'.format(user_agent_normalizer(extras.user_agent)),
277 277 ])
278 278
279 279 # Propagate to external components.
280 280 output = ''
281 281 # make_lock is a tri-state: False, True, None. We only release the lock on False
282 282 if extras.make_lock is False and not is_shadow_repo(extras):
283 283 Repository.unlock(Repository.get_by_repo_name(extras.repository))
284 284 msg = 'Released lock on repo `{}`\n'.format(safe_str(extras.repository))
285 285 output += msg
286 286
287 287 if extras.locked_by[0]:
288 288 locked_by = User.get(extras.locked_by[0]).username
289 289 reason = extras.locked_by[2]
290 290 _http_ret = HTTPLockedRC(
291 291 _locked_by_explanation(extras.repository, locked_by, reason))
292 292 # TODO: johbo: if not?
293 293 if str(_http_ret.code).startswith('2'):
294 294 # 2xx Codes don't raise exceptions
295 295 output += _http_ret.title
296 296
297 297 if extras.new_refs:
298 298 tmpl = '{}/{}/pull-request/new?{{ref_type}}={{ref_name}}'.format(
299 299 safe_str(extras.server_url), safe_str(extras.repository))
300 300
301 301 for branch_name in extras.new_refs['branches']:
302 302 output += 'RhodeCode: open pull request link: {}\n'.format(
303 303 tmpl.format(ref_type='branch', ref_name=safe_str(branch_name)))
304 304
305 305 for book_name in extras.new_refs['bookmarks']:
306 306 output += 'RhodeCode: open pull request link: {}\n'.format(
307 307 tmpl.format(ref_type='bookmark', ref_name=safe_str(book_name)))
308 308
309 309 hook_response = ''
310 310 if not is_shadow_repo(extras):
311 311 hook_response = post_push_extension(
312 312 repo_store_path=Repository.base_path(),
313 313 **extras)
314 314 events.trigger(events.RepoPushEvent(
315 315 repo_name=extras.repository, pushed_commit_ids=commit_ids, extras=extras))
316 316
317 317 output += 'RhodeCode: push completed\n'
318 318 return HookResponse(0, output) + hook_response
319 319
320 320
321 321 def _locked_by_explanation(repo_name, user_name, reason):
322 322 message = (
323 323 'Repository `%s` locked by user `%s`. Reason:`%s`'
324 324 % (repo_name, user_name, reason))
325 325 return message
326 326
327 327
328 328 def check_allowed_create_user(user_dict, created_by, **kwargs):
329 329 # pre create hooks
330 330 if pre_create_user.is_active():
331 331 hook_result = pre_create_user(created_by=created_by, **user_dict)
332 332 allowed = hook_result.status == 0
333 333 if not allowed:
334 334 reason = hook_result.output
335 335 raise UserCreationError(reason)
336 336
337 337
338 338 class ExtensionCallback(object):
339 339 """
340 340 Forwards a given call to rcextensions, sanitizes keyword arguments.
341 341
342 342 Checks if there is an extension active for that hook. If there is,
343 343 it will forward all `kwargs_keys` keyword arguments to the
344 344 extension callback.
345 345 """
346 346
347 347 def __init__(self, hook_name, kwargs_keys):
348 348 self._hook_name = hook_name
349 349 self._kwargs_keys = set(kwargs_keys)
350 350
351 351 def __call__(self, *args, **kwargs):
352 352 log.debug('Calling extension callback for `%s`', self._hook_name)
353 353 callback = self._get_callback()
354 354 if not callback:
355 355 log.debug('extension callback `%s` not found, skipping...', self._hook_name)
356 356 return
357 357
358 358 kwargs_to_pass = {}
359 359 for key in self._kwargs_keys:
360 360 try:
361 361 kwargs_to_pass[key] = kwargs[key]
362 362 except KeyError:
363 363 log.error('Failed to fetch %s key from given kwargs. '
364 364 'Expected keys: %s', key, self._kwargs_keys)
365 365 raise
366 366
367 367 # backward compat for the removed api_key of old hooks. This way it works
368 368 # with older rcextensions that require api_key to be present
369 369 if self._hook_name in ['CREATE_USER_HOOK', 'DELETE_USER_HOOK']:
370 370 kwargs_to_pass['api_key'] = '_DEPRECATED_'
371 371 return callback(**kwargs_to_pass)
372 372
373 373 def is_active(self):
374 374 return hasattr(rhodecode.EXTENSIONS, self._hook_name)
375 375
376 376 def _get_callback(self):
377 377 return getattr(rhodecode.EXTENSIONS, self._hook_name, None)
378 378
379 379
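
To make the dispatch above concrete, here is a hedged sketch of how a callback would be resolved and invoked in a test. The fake extensions object and its hook are illustrative; real hooks live in rcextensions and are looked up on `rhodecode.EXTENSIONS`:

    # Illustrative test double for rhodecode.EXTENSIONS exposing one hook.
    class FakeExtensions(object):
        @staticmethod
        def PRE_PULL_HOOK(**kwargs):
            return 'pre-pull for %s' % kwargs['repository']

    rhodecode.EXTENSIONS = FakeExtensions  # patched only for this sketch

    cb = ExtensionCallback(
        hook_name='PRE_PULL_HOOK',
        kwargs_keys=('repository', 'username'))
    assert cb.is_active()
    # only whitelisted kwargs are forwarded; extra ones are dropped,
    # and missing whitelisted ones raise KeyError.
    result = cb(repository='some/repo', username='alice', ignored='x')
    assert result == 'pre-pull for some/repo'
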
380 380 pre_pull_extension = ExtensionCallback(
381 381 hook_name='PRE_PULL_HOOK',
382 382 kwargs_keys=(
383 383 'server_url', 'config', 'scm', 'username', 'ip', 'action',
384 384 'repository', 'hook_type', 'user_agent', 'repo_store_path',))
385 385
386 386
387 387 post_pull_extension = ExtensionCallback(
388 388 hook_name='PULL_HOOK',
389 389 kwargs_keys=(
390 390 'server_url', 'config', 'scm', 'username', 'ip', 'action',
391 391 'repository', 'hook_type', 'user_agent', 'repo_store_path',))
392 392
393 393
394 394 pre_push_extension = ExtensionCallback(
395 395 hook_name='PRE_PUSH_HOOK',
396 396 kwargs_keys=(
397 397 'server_url', 'config', 'scm', 'username', 'ip', 'action',
398 398 'repository', 'repo_store_path', 'commit_ids', 'hook_type', 'user_agent',))
399 399
400 400
401 401 post_push_extension = ExtensionCallback(
402 402 hook_name='PUSH_HOOK',
403 403 kwargs_keys=(
404 404 'server_url', 'config', 'scm', 'username', 'ip', 'action',
405 405 'repository', 'repo_store_path', 'commit_ids', 'hook_type', 'user_agent',))
406 406
407 407
408 408 pre_create_user = ExtensionCallback(
409 409 hook_name='PRE_CREATE_USER_HOOK',
410 410 kwargs_keys=(
411 411 'username', 'password', 'email', 'firstname', 'lastname', 'active',
412 412 'admin', 'created_by'))
413 413
414 414
415 415 create_pull_request = ExtensionCallback(
416 416 hook_name='CREATE_PULL_REQUEST',
417 417 kwargs_keys=(
418 418 'server_url', 'config', 'scm', 'username', 'ip', 'action',
419 419 'repository', 'pull_request_id', 'url', 'title', 'description',
420 420 'status', 'created_on', 'updated_on', 'commit_ids', 'review_status',
421 421 'mergeable', 'source', 'target', 'author', 'reviewers'))
422 422
423 423
424 424 merge_pull_request = ExtensionCallback(
425 425 hook_name='MERGE_PULL_REQUEST',
426 426 kwargs_keys=(
427 427 'server_url', 'config', 'scm', 'username', 'ip', 'action',
428 428 'repository', 'pull_request_id', 'url', 'title', 'description',
429 429 'status', 'created_on', 'updated_on', 'commit_ids', 'review_status',
430 430 'mergeable', 'source', 'target', 'author', 'reviewers'))
431 431
432 432
433 433 close_pull_request = ExtensionCallback(
434 434 hook_name='CLOSE_PULL_REQUEST',
435 435 kwargs_keys=(
436 436 'server_url', 'config', 'scm', 'username', 'ip', 'action',
437 437 'repository', 'pull_request_id', 'url', 'title', 'description',
438 438 'status', 'created_on', 'updated_on', 'commit_ids', 'review_status',
439 439 'mergeable', 'source', 'target', 'author', 'reviewers'))
440 440
441 441
442 442 review_pull_request = ExtensionCallback(
443 443 hook_name='REVIEW_PULL_REQUEST',
444 444 kwargs_keys=(
445 445 'server_url', 'config', 'scm', 'username', 'ip', 'action',
446 446 'repository', 'pull_request_id', 'url', 'title', 'description',
447 447 'status', 'created_on', 'updated_on', 'commit_ids', 'review_status',
448 448 'mergeable', 'source', 'target', 'author', 'reviewers'))
449 449
450 450
451 451 comment_pull_request = ExtensionCallback(
452 452 hook_name='COMMENT_PULL_REQUEST',
453 453 kwargs_keys=(
454 454 'server_url', 'config', 'scm', 'username', 'ip', 'action',
455 455 'repository', 'pull_request_id', 'url', 'title', 'description',
456 456 'status', 'comment', 'created_on', 'updated_on', 'commit_ids', 'review_status',
457 457 'mergeable', 'source', 'target', 'author', 'reviewers'))
458 458
459 459
460 460 comment_edit_pull_request = ExtensionCallback(
461 461 hook_name='COMMENT_EDIT_PULL_REQUEST',
462 462 kwargs_keys=(
463 463 'server_url', 'config', 'scm', 'username', 'ip', 'action',
464 464 'repository', 'pull_request_id', 'url', 'title', 'description',
465 465 'status', 'comment', 'created_on', 'updated_on', 'commit_ids', 'review_status',
466 466 'mergeable', 'source', 'target', 'author', 'reviewers'))
467 467
468 468
469 469 update_pull_request = ExtensionCallback(
470 470 hook_name='UPDATE_PULL_REQUEST',
471 471 kwargs_keys=(
472 472 'server_url', 'config', 'scm', 'username', 'ip', 'action',
473 473 'repository', 'pull_request_id', 'url', 'title', 'description',
474 474 'status', 'created_on', 'updated_on', 'commit_ids', 'review_status',
475 475 'mergeable', 'source', 'target', 'author', 'reviewers'))
476 476
477 477
478 478 create_user = ExtensionCallback(
479 479 hook_name='CREATE_USER_HOOK',
480 480 kwargs_keys=(
481 481 'username', 'full_name_or_username', 'full_contact', 'user_id',
482 482 'name', 'firstname', 'short_contact', 'admin', 'lastname',
483 483 'ip_addresses', 'extern_type', 'extern_name',
484 484 'email', 'api_keys', 'last_login',
485 485 'full_name', 'active', 'password', 'emails',
486 486 'inherit_default_permissions', 'created_by', 'created_on'))
487 487
488 488
489 489 delete_user = ExtensionCallback(
490 490 hook_name='DELETE_USER_HOOK',
491 491 kwargs_keys=(
492 492 'username', 'full_name_or_username', 'full_contact', 'user_id',
493 493 'name', 'firstname', 'short_contact', 'admin', 'lastname',
494 494 'ip_addresses',
495 495 'email', 'last_login',
496 496 'full_name', 'active', 'password', 'emails',
497 497 'inherit_default_permissions', 'deleted_by'))
498 498
499 499
500 500 create_repository = ExtensionCallback(
501 501 hook_name='CREATE_REPO_HOOK',
502 502 kwargs_keys=(
503 503 'repo_name', 'repo_type', 'description', 'private', 'created_on',
504 504 'enable_downloads', 'repo_id', 'user_id', 'enable_statistics',
505 505 'clone_uri', 'fork_id', 'group_id', 'created_by'))
506 506
507 507
508 508 delete_repository = ExtensionCallback(
509 509 hook_name='DELETE_REPO_HOOK',
510 510 kwargs_keys=(
511 511 'repo_name', 'repo_type', 'description', 'private', 'created_on',
512 512 'enable_downloads', 'repo_id', 'user_id', 'enable_statistics',
513 513 'clone_uri', 'fork_id', 'group_id', 'deleted_by', 'deleted_on'))
514 514
515 515
516 516 comment_commit_repository = ExtensionCallback(
517 517 hook_name='COMMENT_COMMIT_REPO_HOOK',
518 518 kwargs_keys=(
519 519 'repo_name', 'repo_type', 'description', 'private', 'created_on',
520 520 'enable_downloads', 'repo_id', 'user_id', 'enable_statistics',
521 521 'clone_uri', 'fork_id', 'group_id',
522 522 'repository', 'created_by', 'comment', 'commit'))
523 523
524 524 comment_edit_commit_repository = ExtensionCallback(
525 525 hook_name='COMMENT_EDIT_COMMIT_REPO_HOOK',
526 526 kwargs_keys=(
527 527 'repo_name', 'repo_type', 'description', 'private', 'created_on',
528 528 'enable_downloads', 'repo_id', 'user_id', 'enable_statistics',
529 529 'clone_uri', 'fork_id', 'group_id',
530 530 'repository', 'created_by', 'comment', 'commit'))
531 531
532 532
533 533 create_repository_group = ExtensionCallback(
534 534 hook_name='CREATE_REPO_GROUP_HOOK',
535 535 kwargs_keys=(
536 536 'group_name', 'group_parent_id', 'group_description',
537 537 'group_id', 'user_id', 'created_by', 'created_on',
538 538 'enable_locking'))
@@ -1,187 +1,187 b''
1 1 # -*- coding: utf-8 -*-
2 2
3 3 # Copyright (C) 2010-2020 RhodeCode GmbH
4 4 #
5 5 # This program is free software: you can redistribute it and/or modify
6 6 # it under the terms of the GNU Affero General Public License, version 3
7 7 # (only), as published by the Free Software Foundation.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU Affero General Public License
15 15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 16 #
17 17 # This program is dual-licensed. If you wish to learn more about the
18 18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20 20
21 21 import sys
22 22 import logging
23 23
24 24
25 BLACK, RED, GREEN, YELLOW, BLUE, MAGENTA, CYAN, WHITE = range(30, 38)
25 BLACK, RED, GREEN, YELLOW, BLUE, MAGENTA, CYAN, WHITE = list(range(30, 38))
26 26
27 27 # Sequences
28 28 RESET_SEQ = "\033[0m"
29 29 COLOR_SEQ = "\033[0;%dm"
30 30 BOLD_SEQ = "\033[1m"
31 31
32 32 COLORS = {
33 33 'CRITICAL': MAGENTA,
34 34 'ERROR': RED,
35 35 'WARNING': CYAN,
36 36 'INFO': GREEN,
37 37 'DEBUG': BLUE,
38 38 'SQL': YELLOW
39 39 }
40 40
41 41
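
The constants map directly onto ANSI SGR escape codes (30-37 select foreground colors). A tiny sketch of how the sequences compose, independent of the formatter classes below:

    # 'ERROR' maps to RED (31); COLOR_SEQ opens the color, RESET_SEQ closes it.
    level = 'ERROR'
    colored = (COLOR_SEQ % COLORS[level]) + 'something failed' + RESET_SEQ
    print(colored)  # renders in red on any ANSI-capable terminal
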
42 42 def _inject_req_id(record, with_prefix=True):
43 43 from pyramid.threadlocal import get_current_request
44 44 dummy = '00000000-0000-0000-0000-000000000000'
45 45 req_id = None
46 46
47 47 req = get_current_request()
48 48 if req:
49 49 req_id = getattr(req, 'req_id', None)
50 50 if with_prefix:
51 51 req_id = 'req_id:%-36s' % (req_id or dummy)
52 52 else:
53 53 req_id = (req_id or dummy)
54 54 record.req_id = req_id
55 55
56 56
57 57 def _add_log_to_debug_bucket(formatted_record):
58 58 from pyramid.threadlocal import get_current_request
59 59 req = get_current_request()
60 60 if req:
61 61 req.req_id_bucket.append(formatted_record)
62 62
63 63
64 64 def one_space_trim(s):
65 65 if s.find(" ") == -1:
66 66 return s
67 67 else:
68 68 s = s.replace(' ', ' ')
69 69 return one_space_trim(s)
70 70
71 71
72 72 def format_sql(sql):
73 73 sql = sql.replace('\n', '')
74 74 sql = one_space_trim(sql)
75 75 sql = sql\
76 76 .replace(',', ',\n\t')\
77 77 .replace('SELECT', '\n\tSELECT \n\t')\
78 78 .replace('UPDATE', '\n\tUPDATE \n\t')\
79 79 .replace('DELETE', '\n\tDELETE \n\t')\
80 80 .replace('FROM', '\n\tFROM')\
81 81 .replace('ORDER BY', '\n\tORDER BY')\
82 82 .replace('LIMIT', '\n\tLIMIT')\
83 83 .replace('WHERE', '\n\tWHERE')\
84 84 .replace('AND', '\n\tAND')\
85 85 .replace('LEFT', '\n\tLEFT')\
86 86 .replace('INNER', '\n\tINNER')\
87 87 .replace('INSERT', '\n\tINSERT')\
88 88 .replace('DELETE', '\n\tDELETE')
89 89 return sql
90 90
91 91
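
A quick illustration of the whitespace collapsing and reindentation the two helpers above perform (the output shape is approximate, since the replacements are purely textual):

    raw = "SELECT id,  name FROM users WHERE active = 1 ORDER BY id"
    print(format_sql(raw))
    # roughly:
    #     SELECT
    #         id,
    #          name
    #     FROM users
    #     WHERE active = 1
    #     ORDER BY id
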
92 92 class ExceptionAwareFormatter(logging.Formatter):
93 93 """
94 94 Extended logging formatter which prints out remote tracebacks.
95 95 """
96 96
97 97 def formatException(self, ei):
98 98 ex_type, ex_value, ex_tb = ei
99 99
100 100 local_tb = logging.Formatter.formatException(self, ei)
101 101 if hasattr(ex_value, '_vcs_server_traceback'):
102 102
103 103 def formatRemoteTraceback(remote_tb_lines):
104 104 result = ["\n +--- This exception occurred remotely on VCSServer - Remote traceback:\n\n"]
105 105 result.append(remote_tb_lines)
106 106 result.append("\n +--- End of remote traceback\n")
107 107 return result
108 108
109 109 try:
110 110 if ex_type is not None and ex_value is None and ex_tb is None:
111 111 # possible old (3.x) call syntax where caller is only
112 112 # providing exception object
113 113 if type(ex_type) is not type:
114 114 raise TypeError(
115 115 "invalid argument: ex_type should be an exception "
116 116 "type, or just supply no arguments at all")
117 117 if ex_type is None and ex_tb is None:
118 118 ex_type, ex_value, ex_tb = sys.exc_info()
119 119
120 120 remote_tb = getattr(ex_value, "_vcs_server_traceback", None)
121 121
122 122 if remote_tb:
123 123 remote_tb = formatRemoteTraceback(remote_tb)
124 124 return local_tb + ''.join(remote_tb)
125 125 finally:
126 126 # clean up cycle to traceback, to allow proper GC
127 127 del ex_type, ex_value, ex_tb
128 128
129 129 return local_tb
130 130
131 131
132 132 class RequestTrackingFormatter(ExceptionAwareFormatter):
133 133 def format(self, record):
134 134 _inject_req_id(record)
135 135 def_record = logging.Formatter.format(self, record)
136 136 _add_log_to_debug_bucket(def_record)
137 137 return def_record
138 138
139 139
140 140 class ColorFormatter(ExceptionAwareFormatter):
141 141
142 142 def format(self, record):
143 143 """
144 144 Changes record's levelname to use with COLORS enum
145 145 """
146 146 def_record = super(ColorFormatter, self).format(record)
147 147
148 148 levelname = record.levelname
149 149 start = COLOR_SEQ % (COLORS[levelname])
150 150 end = RESET_SEQ
151 151
152 152 colored_record = ''.join([start, def_record, end])
153 153 return colored_record
154 154
155 155
156 156 class ColorRequestTrackingFormatter(RequestTrackingFormatter):
157 157
158 158 def format(self, record):
159 159 """
160 160 Changes record's levelname to use with COLORS enum
161 161 """
162 162 def_record = super(ColorRequestTrackingFormatter, self).format(record)
163 163
164 164 levelname = record.levelname
165 165 start = COLOR_SEQ % (COLORS[levelname])
166 166 end = RESET_SEQ
167 167
168 168 colored_record = ''.join([start, def_record, end])
169 169 return colored_record
170 170
171 171
172 172 class ColorFormatterSql(logging.Formatter):
173 173
174 174 def format(self, record):
175 175 """
176 176 Changes record's levelname to use with COLORS enum
177 177 """
178 178
179 179 start = COLOR_SEQ % (COLORS['SQL'])
180 180 def_record = format_sql(logging.Formatter.format(self, record))
181 181 end = RESET_SEQ
182 182
183 183 colored_record = ''.join([start, def_record, end])
184 184 return colored_record
185 185
186 186 # marcink: needs to stay with this name for backward .ini compatibility
187 187 Pyro4AwareFormatter = ExceptionAwareFormatter
@@ -1,580 +1,580 b''
1 1 # -*- coding: utf-8 -*-
2 2
3 3 # Copyright (C) 2011-2020 RhodeCode GmbH
4 4 #
5 5 # This program is free software: you can redistribute it and/or modify
6 6 # it under the terms of the GNU Affero General Public License, version 3
7 7 # (only), as published by the Free Software Foundation.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU Affero General Public License
15 15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 16 #
17 17 # This program is dual-licensed. If you wish to learn more about the
18 18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20 20
21 21
22 22 """
23 23 Renderer for markup languages with ability to parse using rst or markdown
24 24 """
25 25
26 26 import re
27 27 import os
28 28 import lxml
29 29 import logging
30 30 import urllib.parse
31 31 import bleach
32 32
33 33 from mako.lookup import TemplateLookup
34 34 from mako.template import Template as MakoTemplate
35 35
36 36 from docutils.core import publish_parts
37 37 from docutils.parsers.rst import directives
38 38 from docutils import writers
39 39 from docutils.writers import html4css1
40 40 import markdown
41 41
42 42 from rhodecode.lib.markdown_ext import GithubFlavoredMarkdownExtension
43 43 from rhodecode.lib.utils2 import (safe_unicode, md5_safe, MENTIONS_REGEX)
44 44
45 45 log = logging.getLogger(__name__)
46 46
47 47 # default renderer used to generate automated comments
48 48 DEFAULT_COMMENTS_RENDERER = 'rst'
49 49
50 50 try:
51 51 from lxml.html import fromstring
52 52 from lxml.html import tostring
53 53 except ImportError:
54 54 log.exception('Failed to import lxml')
55 55 fromstring = None
56 56 tostring = None
57 57
58 58
59 59 class CustomHTMLTranslator(writers.html4css1.HTMLTranslator):
60 60 """
61 61 Custom HTML Translator used for sandboxing potential
62 62 JS injections in ref links
63 63 """
64 64 def visit_literal_block(self, node):
65 65 self.body.append(self.starttag(node, 'pre', CLASS='codehilite literal-block'))
66 66
67 67 def visit_reference(self, node):
68 68 if 'refuri' in node.attributes:
69 69 refuri = node['refuri']
70 70 if ':' in refuri:
71 71 prefix, link = refuri.lstrip().split(':', 1)
72 72 prefix = prefix or ''
73 73
74 74 if prefix.lower() == 'javascript':
75 75 # we don't allow javascript type of refs...
76 76 node['refuri'] = 'javascript:alert("SandBoxedJavascript")'
77 77
78 78 # old style class requires this...
79 79 return html4css1.HTMLTranslator.visit_reference(self, node)
80 80
81 81
82 82 class RhodeCodeWriter(writers.html4css1.Writer):
83 83 def __init__(self):
84 84 writers.Writer.__init__(self)
85 85 self.translator_class = CustomHTMLTranslator
86 86
87 87
88 88 def relative_links(html_source, server_paths):
89 89 if not html_source:
90 90 return html_source
91 91
92 92 if not (fromstring and tostring):  # bail out if the lxml import failed
93 93 return html_source
94 94
95 95 try:
96 96 doc = lxml.html.fromstring(html_source)
97 97 except Exception:
98 98 return html_source
99 99
100 100 for el in doc.cssselect('img, video'):
101 101 src = el.attrib.get('src')
102 102 if src:
103 103 el.attrib['src'] = relative_path(src, server_paths['raw'])
104 104
105 105 for el in doc.cssselect('a:not(.gfm)'):
106 106 src = el.attrib.get('href')
107 107 if src:
108 108 raw_mode = el.attrib['href'].endswith('?raw=1')
109 109 if raw_mode:
110 110 el.attrib['href'] = relative_path(src, server_paths['raw'])
111 111 else:
112 112 el.attrib['href'] = relative_path(src, server_paths['standard'])
113 113
114 114 return lxml.html.tostring(doc)
115 115
116 116
117 117 def relative_path(path, request_path, is_repo_file=None):
118 118 """
119 119 relative link support: path is a relative path, and request_path is the
120 120 current server path (not absolute)
121 121
122 122 e.g.
123 123
124 124 path = '../logo.png'
125 125 request_path= '/repo/files/path/file.md'
126 126 produces: '/repo/files/logo.png'
127 127 """
128 128 # TODO(marcink): unicode/str support ?
129 129 # maybe=> safe_unicode(urllib.quote(safe_str(final_path), '/:'))
130 130
131 131 def dummy_check(p):
132 132 return True # assume default is a valid file path
133 133
134 134 is_repo_file = is_repo_file or dummy_check
135 135 if not path:
136 136 return request_path
137 137
138 138 path = safe_unicode(path)
139 139 request_path = safe_unicode(request_path)
140 140
141 if path.startswith((u'data:', u'javascript:', u'#', u':')):
141 if path.startswith(('data:', 'javascript:', '#', ':')):
142 142 # skip data, anchor, invalid links
143 143 return path
144 144
145 145 is_absolute = bool(urllib.parse.urlparse(path).netloc)
146 146 if is_absolute:
147 147 return path
148 148
149 149 if not request_path:
150 150 return path
151 151
152 if path.startswith(u'/'):
152 if path.startswith('/'):
153 153 path = path[1:]
154 154
155 if path.startswith(u'./'):
155 if path.startswith('./'):
156 156 path = path[2:]
157 157
158 158 parts = request_path.split('/')
159 159 # compute how deep we need to traverse the request_path
160 160 depth = 0
161 161
162 162 if is_repo_file(request_path):
163 163 # if request path is a VALID file, we use a relative path with
164 164 # one level up
165 165 depth += 1
166 166
167 while path.startswith(u'../'):
167 while path.startswith('../'):
168 168 depth += 1
169 169 path = path[3:]
170 170
171 171 if depth > 0:
172 172 parts = parts[:-depth]
173 173
174 174 parts.append(path)
175 final_path = u'/'.join(parts).lstrip(u'/')
175 final_path = '/'.join(parts).lstrip('/')
176 176
177 return u'/' + final_path
177 return '/' + final_path
178 178
179 179
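
Concrete behavior of `relative_path`, matching the docstring example above (a sketch; by default `is_repo_file` treats `request_path` as a valid file, which strips one extra level):

    # one '../' plus the implicit file-level strip removes two path parts
    assert relative_path('../logo.png', '/repo/files/path/file.md') == '/repo/files/logo.png'
    # absolute URLs and anchor/data links pass through untouched
    assert relative_path('https://cdn.example.com/x.png', '/repo/files/file.md') == 'https://cdn.example.com/x.png'
    assert relative_path('#section', '/repo/files/file.md') == '#section'
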
180 180 _cached_markdown_renderer = None
181 181
182 182
183 183 def get_markdown_renderer(extensions, output_format):
184 184 global _cached_markdown_renderer
185 185
186 186 if _cached_markdown_renderer is None:
187 187 _cached_markdown_renderer = markdown.Markdown(
188 188 extensions=extensions,
189 189 enable_attributes=False, output_format=output_format)
190 190 return _cached_markdown_renderer
191 191
192 192
193 193 _cached_markdown_renderer_flavored = None
194 194
195 195
196 196 def get_markdown_renderer_flavored(extensions, output_format):
197 197 global _cached_markdown_renderer_flavored
198 198
199 199 if _cached_markdown_renderer_flavored is None:
200 200 _cached_markdown_renderer_flavored = markdown.Markdown(
201 201 extensions=extensions + [GithubFlavoredMarkdownExtension()],
202 202 enable_attributes=False, output_format=output_format)
203 203 return _cached_markdown_renderer_flavored
204 204
205 205
206 206 class MarkupRenderer(object):
207 207 RESTRUCTUREDTEXT_DISALLOWED_DIRECTIVES = ['include', 'meta', 'raw']
208 208
209 209 MARKDOWN_PAT = re.compile(r'\.(md|mkdn?|mdown|markdown)$', re.IGNORECASE)
210 210 RST_PAT = re.compile(r'\.re?st$', re.IGNORECASE)
211 211 JUPYTER_PAT = re.compile(r'\.(ipynb)$', re.IGNORECASE)
212 212 PLAIN_PAT = re.compile(r'^readme$', re.IGNORECASE)
213 213
214 214 URL_PAT = re.compile(r'(http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]'
215 215 r'|[!*\(\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+)')
216 216
217 217 MENTION_PAT = re.compile(MENTIONS_REGEX)
218 218
219 219 extensions = ['markdown.extensions.codehilite', 'markdown.extensions.extra',
220 220 'markdown.extensions.def_list', 'markdown.extensions.sane_lists']
221 221
222 222 output_format = 'html4'
223 223
224 224 # extension together with weights. Lower is first means we control how
225 225 # extensions are attached to readme names with those.
226 226 PLAIN_EXTS = [
227 227 # prefer no extension
228 228 ('', 0), # special case that renders READMES names without extension
229 229 ('.text', 2), ('.TEXT', 2),
230 230 ('.txt', 3), ('.TXT', 3)
231 231 ]
232 232
233 233 RST_EXTS = [
234 234 ('.rst', 1), ('.rest', 1),
235 235 ('.RST', 2), ('.REST', 2)
236 236 ]
237 237
238 238 MARKDOWN_EXTS = [
239 239 ('.md', 1), ('.MD', 1),
240 240 ('.mkdn', 2), ('.MKDN', 2),
241 241 ('.mdown', 3), ('.MDOWN', 3),
242 242 ('.markdown', 4), ('.MARKDOWN', 4)
243 243 ]
244 244
245 245 def _detect_renderer(self, source, filename=None):
246 246 """
247 247 runs detection of what renderer should be used for generating html
248 248 from a markup language
249 249
250 250 filename can also explicitly be a renderer name
251 251
252 252 :param source:
253 253 :param filename:
254 254 """
255 255
256 256 if MarkupRenderer.MARKDOWN_PAT.findall(filename):
257 257 detected_renderer = 'markdown'
258 258 elif MarkupRenderer.RST_PAT.findall(filename):
259 259 detected_renderer = 'rst'
260 260 elif MarkupRenderer.JUPYTER_PAT.findall(filename):
261 261 detected_renderer = 'jupyter'
262 262 elif MarkupRenderer.PLAIN_PAT.findall(filename):
263 263 detected_renderer = 'plain'
264 264 else:
265 265 detected_renderer = 'plain'
266 266
267 267 return getattr(MarkupRenderer, detected_renderer)
268 268
269 269 @classmethod
270 270 def bleach_clean(cls, text):
271 271 from .bleach_whitelist import markdown_attrs, markdown_tags
272 272 allowed_tags = markdown_tags
273 273 allowed_attrs = markdown_attrs
274 274
275 275 try:
276 276 return bleach.clean(text, tags=allowed_tags, attributes=allowed_attrs)
277 277 except Exception:
278 278 return 'UNPARSEABLE TEXT'
279 279
280 280 @classmethod
281 281 def renderer_from_filename(cls, filename, exclude):
282 282 """
283 283 Detect renderer markdown/rst from filename and optionally use exclude
284 284 list to remove some options. This is mostly used in helpers.
285 285 Returns None when no renderer can be detected.
286 286 """
287 287 def _filter(elements):
288 288 if isinstance(exclude, (list, tuple)):
289 289 return [x for x in elements if x not in exclude]
290 290 return elements
291 291
292 292 if filename.endswith(
293 293 tuple(_filter([x[0] for x in cls.MARKDOWN_EXTS if x[0]]))):
294 294 return 'markdown'
295 295 if filename.endswith(tuple(_filter([x[0] for x in cls.RST_EXTS if x[0]]))):
296 296 return 'rst'
297 297
298 298 return None
299 299
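
For example, the detection above behaves like this (a minimal sketch):

    assert MarkupRenderer.renderer_from_filename('README.md', exclude=None) == 'markdown'
    assert MarkupRenderer.renderer_from_filename('docs/index.rst', exclude=None) == 'rst'
    # excluding an extension suppresses detection for exactly that spelling
    assert MarkupRenderer.renderer_from_filename('README.md', exclude=['.md']) is None
    assert MarkupRenderer.renderer_from_filename('notes.txt', exclude=None) is None
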
300 300 def render(self, source, filename=None):
301 301 """
302 302 Renders a given source using the detected renderer.
303 303 Renderers are detected based on file extension or mimetype.
304 304 As a last resort it does simple html, replacing new lines with <br/>
305 305
306 306 :param filename:
307 307 :param source:
308 308 """
309 309
310 310 renderer = self._detect_renderer(source, filename)
311 311 readme_data = renderer(source)
312 312 return readme_data
313 313
314 314 @classmethod
315 315 def _flavored_markdown(cls, text):
316 316 """
317 317 Github style flavored markdown
318 318
319 319 :param text:
320 320 """
321 321
322 322 # Extract pre blocks.
323 323 extractions = {}
324 324
325 325 def pre_extraction_callback(matchobj):
326 326 digest = md5_safe(matchobj.group(0))
327 327 extractions[digest] = matchobj.group(0)
328 328 return "{gfm-extraction-%s}" % digest
329 329 pattern = re.compile(r'<pre>.*?</pre>', re.MULTILINE | re.DOTALL)
330 330 text = re.sub(pattern, pre_extraction_callback, text)
331 331
332 332 # Prevent foo_bar_baz from ending up with an italic word in the middle.
333 333 def italic_callback(matchobj):
334 334 s = matchobj.group(0)
335 335 if list(s).count('_') >= 2:
336 336 return s.replace('_', r'\_')
337 337 return s
338 338 text = re.sub(r'^(?! {4}|\t)\w+_\w+_\w[\w_]*', italic_callback, text)
339 339
340 340 # Insert pre block extractions.
341 341 def pre_insert_callback(matchobj):
342 342 return '\n\n' + extractions[matchobj.group(1)]
343 343 text = re.sub(r'\{gfm-extraction-([0-9a-f]{32})\}',
344 344 pre_insert_callback, text)
345 345
346 346 return text
347 347
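
The italic guard above only escapes words containing two or more underscores, so GitHub-style identifiers survive rendering; for instance:

    # two underscores -> escaped so markdown does not italicize the middle part
    assert MarkupRenderer._flavored_markdown('foo_bar_baz') == r'foo\_bar\_baz'
    # a single underscore is left alone
    assert MarkupRenderer._flavored_markdown('foo_bar') == 'foo_bar'
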
348 348 @classmethod
349 349 def urlify_text(cls, text):
350 350 def url_func(match_obj):
351 351 url_full = match_obj.groups()[0]
352 352 return '<a href="%(url)s">%(url)s</a>' % ({'url': url_full})
353 353
354 354 return cls.URL_PAT.sub(url_func, text)
355 355
356 356 @classmethod
357 357 def convert_mentions(cls, text, mode):
358 358 mention_pat = cls.MENTION_PAT
359 359
360 360 def wrapp(match_obj):
361 361 uname = match_obj.groups()[0]
362 362 hovercard_url = "pyroutes.url('hovercard_username', {'username': '%s'});" % uname
363 363
364 364 if mode == 'markdown':
365 365 tmpl = '<strong class="tooltip-hovercard" data-hovercard-alt="{uname}" data-hovercard-url="{hovercard_url}">@{uname}</strong>'
366 366 elif mode == 'rst':
367 367 tmpl = ' **@{uname}** '
368 368 else:
369 369 raise ValueError('mode must be rst or markdown')
370 370
371 371 return tmpl.format(**{'uname': uname,
372 372 'hovercard_url': hovercard_url})
373 373
374 374 return mention_pat.sub(wrapp, text).strip()
375 375
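
For instance, run over a comment body the conversion wraps each mention before the markup renderer sees it (a sketch, assuming MENTIONS_REGEX matches the plain @username form; the exact `<strong>` attributes are the ones in the template above):

    text = 'please review @alice'
    md = MarkupRenderer.convert_mentions(text, mode='markdown')
    # -> 'please review <strong class="tooltip-hovercard" ...>@alice</strong>'
    rst = MarkupRenderer.convert_mentions(text, mode='rst')
    # -> 'please review  **@alice**'  (outer whitespace then stripped)
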
376 376 @classmethod
377 377 def plain(cls, source, universal_newline=True, leading_newline=True):
378 378 source = safe_unicode(source)
379 379 if universal_newline:
380 380 newline = '\n'
381 381 source = newline.join(source.splitlines())
382 382
383 383 rendered_source = cls.urlify_text(source)
384 384 source = ''
385 385 if leading_newline:
386 386 source += '<br />'
387 387 source += rendered_source.replace("\n", '<br />')
388 388
389 389 rendered = cls.bleach_clean(source)
390 390 return rendered
391 391
392 392 @classmethod
393 393 def markdown(cls, source, safe=True, flavored=True, mentions=False,
394 394 clean_html=True):
395 395 """
396 396 returns markdown rendered code cleaned by the bleach library
397 397 """
398 398
399 399 if flavored:
400 400 markdown_renderer = get_markdown_renderer_flavored(
401 401 cls.extensions, cls.output_format)
402 402 else:
403 403 markdown_renderer = get_markdown_renderer(
404 404 cls.extensions, cls.output_format)
405 405
406 406 if mentions:
407 407 mention_hl = cls.convert_mentions(source, mode='markdown')
408 408 # we extracted mentions; render with this using mentions=False
409 409 return cls.markdown(mention_hl, safe=safe, flavored=flavored,
410 410 mentions=False)
411 411
412 412 source = safe_unicode(source)
413 413
414 414 try:
415 415 if flavored:
416 416 source = cls._flavored_markdown(source)
417 417 rendered = markdown_renderer.convert(source)
418 418 except Exception:
419 419 log.exception('Error when rendering Markdown')
420 420 if safe:
421 421 log.debug('Fallback to render in plain mode')
422 422 rendered = cls.plain(source)
423 423 else:
424 424 raise
425 425
426 426 if clean_html:
427 427 rendered = cls.bleach_clean(rendered)
428 428 return rendered
429 429
430 430 @classmethod
431 431 def rst(cls, source, safe=True, mentions=False, clean_html=False):
432 432 if mentions:
433 433 mention_hl = cls.convert_mentions(source, mode='rst')
434 434 # we extracted mentions; render with this using mentions=False
435 435 return cls.rst(mention_hl, safe=safe, mentions=False)
436 436
437 437 source = safe_unicode(source)
438 438 try:
439 439 docutils_settings = dict(
440 440 [(alias, None) for alias in
441 441 cls.RESTRUCTUREDTEXT_DISALLOWED_DIRECTIVES])
442 442
443 443 docutils_settings.update({
444 444 'input_encoding': 'unicode',
445 445 'report_level': 4,
446 446 'syntax_highlight': 'short',
447 447 })
448 448
449 449 for k, v in docutils_settings.items():
450 450 directives.register_directive(k, v)
451 451
452 452 parts = publish_parts(source=source,
453 453 writer=RhodeCodeWriter(),
454 454 settings_overrides=docutils_settings)
455 455 rendered = parts["fragment"]
456 456 if clean_html:
457 457 rendered = cls.bleach_clean(rendered)
458 458 return parts['html_title'] + rendered
459 459 except Exception:
460 460 log.exception('Error when rendering RST')
461 461 if safe:
462 462 log.debug('Fallback to render in plain mode')
463 463 return cls.plain(source)
464 464 else:
465 465 raise
466 466
467 467 @classmethod
468 468 def jupyter(cls, source, safe=True):
469 469 from rhodecode.lib import helpers
470 470
471 471 from traitlets.config import Config
472 472 import nbformat
473 473 from nbconvert import HTMLExporter
474 474 from nbconvert.preprocessors import Preprocessor
475 475
476 476 class CustomHTMLExporter(HTMLExporter):
477 477 def _template_file_default(self):
478 478 return 'basic'
479 479
480 480 class Sandbox(Preprocessor):
481 481
482 482 def preprocess(self, nb, resources):
483 483 sandbox_text = 'SandBoxed(IPython.core.display.Javascript object)'
484 484 for cell in nb['cells']:
485 485 if not safe:
486 486 continue
487 487
488 488 if 'outputs' in cell:
489 489 for cell_output in cell['outputs']:
490 490 if 'data' in cell_output:
491 491 if 'application/javascript' in cell_output['data']:
492 492 cell_output['data']['text/plain'] = sandbox_text
493 493 cell_output['data'].pop('application/javascript', None)
494 494
495 495 if 'source' in cell and cell['cell_type'] == 'markdown':
496 496 # sanitize similar like in markdown
497 497 cell['source'] = cls.bleach_clean(cell['source'])
498 498
499 499 return nb, resources
500 500
501 501 def _sanitize_resources(input_resources):
502 502 """
503 503 Skip/sanitize some of the CSS generated and included by Jupyter
504 504 so it doesn't mess up the UI as much
505 505 """
506 506
507 507 # TODO(marcink): probably we should replace this with a whole custom
508 508 # CSS set that doesn't clash, but jupyter-generated html has some
509 509 # special markers, so it requires a Custom HTML exporter template with
510 510 # _default_template_path_default to achieve that
511 511
512 512 # strip the reset CSS
513 513 input_resources[0] = input_resources[0][input_resources[0].find('/*! Source'):]
514 514 return input_resources
515 515
516 516 def as_html(notebook):
517 517 conf = Config()
518 518 conf.CustomHTMLExporter.preprocessors = [Sandbox]
519 519 html_exporter = CustomHTMLExporter(config=conf)
520 520
521 521 (body, resources) = html_exporter.from_notebook_node(notebook)
522 522 header = '<!-- ## IPYTHON NOTEBOOK RENDERING ## -->'
523 523 js = MakoTemplate(r'''
524 524 <!-- MathJax configuration -->
525 525 <script type="text/x-mathjax-config">
526 526 MathJax.Hub.Config({
527 527 jax: ["input/TeX","output/HTML-CSS", "output/PreviewHTML"],
528 528 extensions: ["tex2jax.js","MathMenu.js","MathZoom.js", "fast-preview.js", "AssistiveMML.js", "[Contrib]/a11y/accessibility-menu.js"],
529 529 TeX: {
530 530 extensions: ["AMSmath.js","AMSsymbols.js","noErrors.js","noUndefined.js"]
531 531 },
532 532 tex2jax: {
533 533 inlineMath: [ ['$','$'], ["\\(","\\)"] ],
534 534 displayMath: [ ['$$','$$'], ["\\[","\\]"] ],
535 535 processEscapes: true,
536 536 processEnvironments: true
537 537 },
538 538 // Center justify equations in code and markdown cells. Elsewhere
539 539 // we use CSS to left justify single line equations in code cells.
540 540 displayAlign: 'center',
541 541 "HTML-CSS": {
542 542 styles: {'.MathJax_Display': {"margin": 0}},
543 543 linebreaks: { automatic: true },
544 544 availableFonts: ["STIX", "TeX"]
545 545 },
546 546 showMathMenu: false
547 547 });
548 548 </script>
549 549 <!-- End of MathJax configuration -->
550 550 <script src="${h.asset('js/src/math_jax/MathJax.js')}"></script>
551 551 ''').render(h=helpers)
552 552
553 553 css = MakoTemplate(r'''
554 554 <link rel="stylesheet" type="text/css" href="${h.asset('css/style-ipython.css', ver=ver)}" media="screen"/>
555 555 ''').render(h=helpers, ver='ver1')
556 556
557 557 body = '\n'.join([header, css, js, body])
558 558 return body, resources
559 559
560 560 notebook = nbformat.reads(source, as_version=4)
561 561 (body, resources) = as_html(notebook)
562 562 return body
563 563
564 564
565 565 class RstTemplateRenderer(object):
566 566
567 567 def __init__(self):
568 568 base = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
569 569 rst_template_dirs = [os.path.join(base, 'templates', 'rst_templates')]
570 570 self.template_store = TemplateLookup(
571 571 directories=rst_template_dirs,
572 572 input_encoding='utf-8',
573 573 imports=['from rhodecode.lib import helpers as h'])
574 574
575 575 def _get_template(self, templatename):
576 576 return self.template_store.get_template(templatename)
577 577
578 578 def render(self, template_name, **kwargs):
579 579 template = self._get_template(template_name)
580 580 return template.render(**kwargs)
@@ -1,156 +1,156 b''
1 1 # -*- coding: utf-8 -*-
2 2
3 3 # Copyright (C) 2010-2020 RhodeCode GmbH
4 4 #
5 5 # This program is free software: you can redistribute it and/or modify
6 6 # it under the terms of the GNU Affero General Public License, version 3
7 7 # (only), as published by the Free Software Foundation.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU Affero General Public License
15 15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 16 #
17 17 # This program is dual-licensed. If you wish to learn more about the
18 18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20 20
21 21 """
22 22 SimpleGit middleware for handling git protocol requests (push/clone etc.)
23 23 It's implemented with a basic auth function.
24 24 """
25 25 import os
26 26 import re
27 27 import logging
28 28 import urllib.parse
29 29
30 30 import rhodecode
31 31 from rhodecode.lib import utils
32 32 from rhodecode.lib import utils2
33 33 from rhodecode.lib.middleware import simplevcs
34 34
35 35 log = logging.getLogger(__name__)
36 36
37 37
38 38 GIT_PROTO_PAT = re.compile(
39 39 r'^/(.+)/(info/refs|info/lfs/(.+)|git-upload-pack|git-receive-pack)')
40 40 GIT_LFS_PROTO_PAT = re.compile(r'^/(.+)/(info/lfs/(.+))')
41 41
42 42
43 43 def default_lfs_store():
44 44 """
45 45 Default lfs store location; it's consistent with Mercurial's largefiles
46 46 store, which is in .cache/largefiles
47 47 """
48 48 from rhodecode.lib.vcs.backends.git import lfs_store
49 49 user_home = os.path.expanduser("~")
50 50 return lfs_store(user_home)
51 51
52 52
53 53 class SimpleGit(simplevcs.SimpleVCS):
54 54
55 55 SCM = 'git'
56 56
57 57 def _get_repository_name(self, environ):
58 58 """
59 59 Gets repository name out of PATH_INFO header
60 60
61 61 :param environ: environ where PATH_INFO is stored
62 62 """
63 63 repo_name = GIT_PROTO_PAT.match(environ['PATH_INFO']).group(1)
64 64 # for GIT LFS, and bare format strip .git suffix from names
65 65 if repo_name.endswith('.git'):
66 66 repo_name = repo_name[:-4]
67 67 return repo_name
68 68
69 69 def _get_lfs_action(self, path, request_method):
70 70 """
71 71 Return an action based on the LFS request type.
72 72 Those routes are handled inside the vcsserver app.
73 73
74 74 batch -> POST to /info/lfs/objects/batch => PUSH/PULL
75 75 batch is based on the `operation` type,
76 76 which could be download or upload, but those are only
77 77 instructions on what to fetch, so we always return pull
78 78
79 79 download -> GET to /info/lfs/{oid} => PULL
80 80 upload -> PUT to /info/lfs/{oid} => PUSH
81 81
82 82 verification -> POST to /info/lfs/verify => PULL
83 83
84 84 """
85 85
86 86 match_obj = GIT_LFS_PROTO_PAT.match(path)
87 87 _parts = match_obj.groups()
88 88 repo_name, path, operation = _parts
89 89 log.debug(
90 90 'LFS: detecting operation based on following '
91 91 'data: %s, req_method:%s', _parts, request_method)
92 92
93 93 if operation == 'verify':
94 94 return 'pull'
95 95 elif operation == 'objects/batch':
96 96 # batch sends back instructions for the API to download/upload;
97 97 # we report it as pull
98 98 if request_method == 'POST':
99 99 return 'pull'
100 100
101 101 elif operation:
102 102 # probably an OID; upload is a PUT, download a GET
103 103 if request_method == 'GET':
104 104 return 'pull'
105 105 else:
106 106 return 'push'
107 107
108 108 # if default not found require push, as action
109 109 return 'push'
110 110
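
Putting the routing rules above together, a few representative LFS requests and the action they resolve to (paths illustrative, `<oid>` is a placeholder):

    # (PATH_INFO, REQUEST_METHOD) -> action, per _get_lfs_action above
    lfs_examples = [
        ('/repo/info/lfs/objects/batch', 'POST', 'pull'),  # batch: instructions only
        ('/repo/info/lfs/verify',        'POST', 'pull'),
        ('/repo/info/lfs/<oid>',         'GET',  'pull'),  # object download
        ('/repo/info/lfs/<oid>',         'PUT',  'push'),  # object upload
    ]
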
111 111 _ACTION_MAPPING = {
112 112 'git-receive-pack': 'push',
113 113 'git-upload-pack': 'pull',
114 114 }
115 115
116 116 def _get_action(self, environ):
117 117 """
118 118 Maps git request commands into a pull or push command.
119 119 In case of unknown/unexpected data, it returns 'pull' to be safe.
120 120
121 121 :param environ:
122 122 """
123 123 path = environ['PATH_INFO']
124 124
125 125 if path.endswith('/info/refs'):
126 query = urllib.parse.urlparse.parse_qs(environ['QUERY_STRING'])
126 query = urllib.parse.parse_qs(environ['QUERY_STRING'])
127 127 service_cmd = query.get('service', [''])[0]
128 128 return self._ACTION_MAPPING.get(service_cmd, 'pull')
129 129
130 130 elif GIT_LFS_PROTO_PAT.match(environ['PATH_INFO']):
131 131 return self._get_lfs_action(
132 132 environ['PATH_INFO'], environ['REQUEST_METHOD'])
133 133
134 134 elif path.endswith('/git-receive-pack'):
135 135 return 'push'
136 136 elif path.endswith('/git-upload-pack'):
137 137 return 'pull'
138 138
139 139 return 'pull'
140 140
141 141 def _create_wsgi_app(self, repo_path, repo_name, config):
142 142 return self.scm_app.create_git_wsgi_app(
143 143 repo_path, repo_name, config)
144 144
145 145 def _create_config(self, extras, repo_name, scheme='http'):
146 146 extras['git_update_server_info'] = utils2.str2bool(
147 147 rhodecode.CONFIG.get('git_update_server_info'))
148 148
149 149 config = utils.make_db_config(repo=repo_name)
150 150 custom_store = config.get('vcs_git_lfs', 'store_location')
151 151
152 152 extras['git_lfs_enabled'] = utils2.str2bool(
153 153 config.get('vcs_git_lfs', 'enabled'))
154 154 extras['git_lfs_store_path'] = custom_store or default_lfs_store()
155 155 extras['git_lfs_http_scheme'] = scheme
156 156 return extras
@@ -1,160 +1,159 b''
1 1 # -*- coding: utf-8 -*-
2 2
3 3 # Copyright (C) 2010-2020 RhodeCode GmbH
4 4 #
5 5 # This program is free software: you can redistribute it and/or modify
6 6 # it under the terms of the GNU Affero General Public License, version 3
7 7 # (only), as published by the Free Software Foundation.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU Affero General Public License
15 15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 16 #
17 17 # This program is dual-licensed. If you wish to learn more about the
18 18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20 20
21 21 """
22 22 SimpleHG middleware for handling mercurial protocol requests
23 23 (push/clone etc.). It's implemented with a basic auth function.
24 24 """
25 25
26 26 import logging
27 27 import urllib.parse
28 28 import urllib.request, urllib.parse, urllib.error
29 29
30 30 from rhodecode.lib import utils
31 31 from rhodecode.lib.ext_json import json
32 32 from rhodecode.lib.middleware import simplevcs
33 33
34 34 log = logging.getLogger(__name__)
35 35
36 36
37 37 class SimpleHg(simplevcs.SimpleVCS):
38 38
39 39 SCM = 'hg'
40 40
41 41 def _get_repository_name(self, environ):
42 42 """
43 43 Gets repository name out of PATH_INFO header
44 44
45 45 :param environ: environ where PATH_INFO is stored
46 46 """
47 47 repo_name = environ['PATH_INFO']
48 48 if repo_name and repo_name.startswith('/'):
49 49 # remove only the first leading /
50 50 repo_name = repo_name[1:]
51 51 return repo_name.rstrip('/')
52 52
53 53 _ACTION_MAPPING = {
54 54 'changegroup': 'pull',
55 55 'changegroupsubset': 'pull',
56 56 'getbundle': 'pull',
57 57 'stream_out': 'pull',
58 58 'listkeys': 'pull',
59 59 'between': 'pull',
60 60 'branchmap': 'pull',
61 61 'branches': 'pull',
62 62 'clonebundles': 'pull',
63 63 'capabilities': 'pull',
64 64 'debugwireargs': 'pull',
65 65 'heads': 'pull',
66 66 'lookup': 'pull',
67 67 'hello': 'pull',
68 68 'known': 'pull',
69 69
70 70 # largefiles
71 71 'putlfile': 'push',
72 72 'getlfile': 'pull',
73 73 'statlfile': 'pull',
74 74 'lheads': 'pull',
75 75
76 76 # evolve
77 77 'evoext_obshashrange_v1': 'pull',
78 78 'evoext_obshash': 'pull',
79 79 'evoext_obshash1': 'pull',
80 80
81 81 'unbundle': 'push',
82 82 'pushkey': 'push',
83 83 }
84 84
85 85 @classmethod
86 86 def _get_xarg_headers(cls, environ):
87 87 i = 1
88 88 chunks = [] # gather chunks stored in multiple 'hgarg_N'
89 89 while True:
90 90 head = environ.get('HTTP_X_HGARG_{}'.format(i))
91 91 if not head:
92 92 break
93 93 i += 1
94 94 chunks.append(urllib.parse.unquote_plus(head))
95 95 full_arg = ''.join(chunks)
96 96 pref = 'cmds='
97 97 if full_arg.startswith(pref):
98 98 # strip the cmds= header defining our batch commands
99 99 full_arg = full_arg[len(pref):]
100 100 cmds = full_arg.split(';')
101 101 return cmds
102 102
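
Mercurial splits long argument strings across numbered X-HgArg headers; a sketch of the environ this method reassembles (header values illustrative, percent-encoded as the client sends them):

    environ = {
        'HTTP_X_HGARG_1': 'cmds%3Dheads%20%3B',
        'HTTP_X_HGARG_2': 'known%20nodes%3Dabc',
    }
    # unquoted and concatenated: 'cmds=heads ;known nodes=abc'
    # after stripping the 'cmds=' prefix and splitting on ';':
    assert SimpleHg._get_xarg_headers(environ) == ['heads ', 'known nodes=abc']
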
103 103 @classmethod
104 104 def _get_batch_cmd(cls, environ):
105 105 """
106 106 Handle commands sent via the batch command. Those are ';'-separated
107 107 commands that the server needs to execute. We need to extract those,
108 108 and map them to our _ACTION_MAPPING to get all push/pull commands
109 109 specified in the batch
110 110 """
111 111 default = 'push'
112 112 batch_cmds = []
113 113 try:
114 114 cmds = cls._get_xarg_headers(environ)
115 115 for pair in cmds:
116 116 parts = pair.split(' ', 1)
117 117 if len(parts) != 2:
118 118 continue
119 119 # entry should be in a format `key ARGS`
120 120 cmd, args = parts
121 121 action = cls._ACTION_MAPPING.get(cmd, default)
122 122 batch_cmds.append(action)
123 123 except Exception:
124 124 log.exception('Failed to extract batch commands operations')
125 125
126 126 # in case we failed (e.g. malformed data), assume it's a PUSH sub-command
127 127 # for safety
128 128 return batch_cmds or [default]
129 129
130 130 def _get_action(self, environ):
131 131 """
132 132 Maps mercurial request commands into a pull or push command.
133 133 In case of unknown/unexpected data, it returns 'push' to be safe.
134 134
135 135 :param environ:
136 136 """
137 137 default = 'push'
138 query = urllib.parse.urlparse.parse_qs(environ['QUERY_STRING'],
139 keep_blank_values=True)
138 query = urllib.parse.parse_qs(environ['QUERY_STRING'], keep_blank_values=True)
140 139
141 140 if 'cmd' in query:
142 141 cmd = query['cmd'][0]
143 142 if cmd == 'batch':
144 143 cmds = self._get_batch_cmd(environ)
145 144 if 'push' in cmds:
146 145 return 'push'
147 146 else:
148 147 return 'pull'
149 148 return self._ACTION_MAPPING.get(cmd, default)
150 149
151 150 return default
152 151
153 152 def _create_wsgi_app(self, repo_path, repo_name, config):
154 153 return self.scm_app.create_hg_wsgi_app(repo_path, repo_name, config)
155 154
156 155 def _create_config(self, extras, repo_name, scheme='http'):
157 156 config = utils.make_db_config(repo=repo_name)
158 157 config.set('rhodecode', 'RC_SCM_DATA', json.dumps(extras))
159 158
160 159 return config.serialize()
@@ -1,679 +1,679 b''
1 1 # -*- coding: utf-8 -*-
2 2
3 3 # Copyright (C) 2014-2020 RhodeCode GmbH
4 4 #
5 5 # This program is free software: you can redistribute it and/or modify
6 6 # it under the terms of the GNU Affero General Public License, version 3
7 7 # (only), as published by the Free Software Foundation.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU Affero General Public License
15 15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 16 #
17 17 # This program is dual-licensed. If you wish to learn more about the
18 18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20 20
21 21 """
22 22 SimpleVCS middleware for handling protocol requests (push/clone etc.)
23 23 It's implemented using a basic auth function
24 24 """
25 25
26 26 import os
27 27 import re
28 import io
28 29 import logging
29 30 import importlib
30 31 from functools import wraps
31 from io import StringIO
32 32 from lxml import etree
33 33
34 34 import time
35 35 from paste.httpheaders import REMOTE_USER, AUTH_TYPE
36 36
37 37 from pyramid.httpexceptions import (
38 38 HTTPNotFound, HTTPForbidden, HTTPNotAcceptable, HTTPInternalServerError)
39 39 from zope.cachedescriptors.property import Lazy as LazyProperty
40 40
41 41 import rhodecode
42 42 from rhodecode.authentication.base import authenticate, VCS_TYPE, loadplugin
43 43 from rhodecode.lib import rc_cache
44 44 from rhodecode.lib.auth import AuthUser, HasPermissionAnyMiddleware
45 45 from rhodecode.lib.base import (
46 46 BasicAuth, get_ip_addr, get_user_agent, vcs_operation_context)
47 47 from rhodecode.lib.exceptions import (UserCreationError, NotAllowedToCreateUserError)
48 48 from rhodecode.lib.hooks_daemon import prepare_callback_daemon
49 49 from rhodecode.lib.middleware import appenlight
50 50 from rhodecode.lib.middleware.utils import scm_app_http
51 51 from rhodecode.lib.utils import is_valid_repo, SLUG_RE
52 52 from rhodecode.lib.utils2 import safe_str, fix_PATH, str2bool, safe_unicode
53 53 from rhodecode.lib.vcs.conf import settings as vcs_settings
54 54 from rhodecode.lib.vcs.backends import base
55 55
56 56 from rhodecode.model import meta
57 57 from rhodecode.model.db import User, Repository, PullRequest
58 58 from rhodecode.model.scm import ScmModel
59 59 from rhodecode.model.pull_request import PullRequestModel
60 60 from rhodecode.model.settings import SettingsModel, VcsSettingsModel
61 61
62 62 log = logging.getLogger(__name__)
63 63
64 64
65 65 def extract_svn_txn_id(acl_repo_name, data):
66 66 """
67 67 Helper method for extraction of svn txn_id from submitted XML data during
68 68 POST operations
69 69 """
70 70 try:
71 71 root = etree.fromstring(data)
72 72 pat = re.compile(r'/txn/(?P<txn_id>.*)')
73 73 for el in root:
74 74 if el.tag == '{DAV:}source':
75 75 for sub_el in el:
76 76 if sub_el.tag == '{DAV:}href':
77 77 match = pat.search(sub_el.text)
78 78 if match:
79 79 svn_tx_id = match.groupdict()['txn_id']
80 80 txn_id = rc_cache.utils.compute_key_from_params(
81 81 acl_repo_name, svn_tx_id)
82 82 return txn_id
83 83 except Exception:
84 84 log.exception('Failed to extract txn_id')
85 85
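The txn_id travels inside the DAV XML body of an SVN MERGE request; a hypothetical body showing what the regex and tag walk above match (the trailing compute_key_from_params call just mixes the repo name into the cache key):

    import re
    from lxml import etree

    # hypothetical MERGE body; the txn id '123-abc' is made up for the example
    data = b'''<?xml version="1.0" encoding="utf-8"?>
    <D:merge xmlns:D="DAV:">
      <D:source><D:href>/myrepo/!svn/txn/123-abc</D:href></D:source>
    </D:merge>'''

    root = etree.fromstring(data)
    pat = re.compile(r'/txn/(?P<txn_id>.*)')
    for el in root:
        if el.tag == '{DAV:}source':
            for sub_el in el:
                if sub_el.tag == '{DAV:}href':
                    print(pat.search(sub_el.text).group('txn_id'))  # 123-abc
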
86 86
87 87 def initialize_generator(factory):
88 88 """
89 89 Initializes the returned generator by draining its first element.
90 90
91 91 This can be used to give a generator an initializer, which is the code
92 92 up to the first yield statement. This decorator enforces that the first
93 93 produced element has the value ``"__init__"`` to make its special
94 94 purpose very explicit in the calling code.
95 95 """
96 96
97 97 @wraps(factory)
98 98 def wrapper(*args, **kwargs):
99 99 gen = factory(*args, **kwargs)
100 100 try:
101 101 init = next(gen)
102 102 except StopIteration:
103 103 raise ValueError('Generator must yield at least one element.')
104 104 if init != "__init__":
105 105 raise ValueError('First yielded element must be "__init__".')
106 106 return gen
107 107 return wrapper
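The decorator's contract is easiest to see on a toy generator: calling the factory eagerly runs the body up to the first yield and swallows the sentinel, so iteration starts at the real payload. A minimal usage sketch:

    @initialize_generator
    def numbers():
        print('initializer runs at call time')  # executed inside next(gen)
        yield "__init__"                        # mandatory sentinel, consumed
        yield 1
        yield 2

    gen = numbers()      # prints immediately, before anyone iterates
    print(list(gen))     # [1, 2]
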
108 108
109 109
110 110 class SimpleVCS(object):
111 111 """Common functionality for SCM HTTP handlers."""
112 112
113 113 SCM = 'unknown'
114 114
115 115 acl_repo_name = None
116 116 url_repo_name = None
117 117 vcs_repo_name = None
118 118 rc_extras = {}
119 119
120 120 # We have to handle requests to shadow repositories differently than requests
121 121 # to normal repositories. Therefore we have to distinguish them. To do this
122 122 # we use this regex which will match only on URLs pointing to shadow
123 123 # repositories.
124 124 shadow_repo_re = re.compile(
125 125 '(?P<groups>(?:{slug_pat}/)*)' # repo groups
126 126 '(?P<target>{slug_pat})/' # target repo
127 'pull-request/(?P<pr_id>\d+)/' # pull request
127 'pull-request/(?P<pr_id>\\d+)/' # pull request
128 128 'repository$' # shadow repo
129 129 .format(slug_pat=SLUG_RE.pattern))
130 130
131 131 def __init__(self, config, registry):
132 132 self.registry = registry
133 133 self.config = config
134 134 # re-populated by specialized middleware
135 135 self.repo_vcs_config = base.Config()
136 136
137 137 rc_settings = SettingsModel().get_all_settings(cache=True, from_request=False)
138 138 realm = rc_settings.get('rhodecode_realm') or 'RhodeCode AUTH'
139 139
140 140 # authenticate this VCS request using authfunc
141 141 auth_ret_code_detection = \
142 142 str2bool(self.config.get('auth_ret_code_detection', False))
143 143 self.authenticate = BasicAuth(
144 144 '', authenticate, registry, config.get('auth_ret_code'),
145 145 auth_ret_code_detection, rc_realm=realm)
146 146 self.ip_addr = '0.0.0.0'
147 147
148 148 @LazyProperty
149 149 def global_vcs_config(self):
150 150 try:
151 151 return VcsSettingsModel().get_ui_settings_as_config_obj()
152 152 except Exception:
153 153 return base.Config()
154 154
155 155 @property
156 156 def base_path(self):
157 157 settings_path = self.repo_vcs_config.get(*VcsSettingsModel.PATH_SETTING)
158 158
159 159 if not settings_path:
160 160 settings_path = self.global_vcs_config.get(*VcsSettingsModel.PATH_SETTING)
161 161
162 162 if not settings_path:
163 163 # try; maybe it was passed in explicitly as a config option
164 164 settings_path = self.config.get('base_path')
165 165
166 166 if not settings_path:
167 167 raise ValueError('FATAL: base_path is empty')
168 168 return settings_path
169 169
170 170 def set_repo_names(self, environ):
171 171 """
172 172 This will populate the attributes acl_repo_name, url_repo_name,
173 173 vcs_repo_name and is_shadow_repo. In case of requests to normal (non
174 174 shadow) repositories all names are equal. In case of requests to a
175 175 shadow repository the acl-name points to the target repo of the pull
176 176 request and the vcs-name points to the shadow repo file system path.
177 177 The url-name is always the URL used by the vcs client program.
178 178
179 179 Example in case of a shadow repo:
180 180 acl_repo_name = RepoGroup/MyRepo
181 181 url_repo_name = RepoGroup/MyRepo/pull-request/3/repository
182 182 vcs_repo_name = /repo/base/path/RepoGroup/.__shadow_MyRepo_pr-3
183 183 """
184 184 # First we set the repo name from URL for all attributes. This is the
185 185 # default if handling normal (non shadow) repo requests.
186 186 self.url_repo_name = self._get_repository_name(environ)
187 187 self.acl_repo_name = self.vcs_repo_name = self.url_repo_name
188 188 self.is_shadow_repo = False
189 189
190 190 # Check if this is a request to a shadow repository.
191 191 match = self.shadow_repo_re.match(self.url_repo_name)
192 192 if match:
193 193 match_dict = match.groupdict()
194 194
195 195 # Build acl repo name from regex match.
196 196 acl_repo_name = safe_unicode('{groups}{target}'.format(
197 197 groups=match_dict['groups'] or '',
198 198 target=match_dict['target']))
199 199
200 200 # Retrieve pull request instance by ID from regex match.
201 201 pull_request = PullRequest.get(match_dict['pr_id'])
202 202
203 203 # Only proceed if we got a pull request and if acl repo name from
204 204 # URL equals the target repo name of the pull request.
205 205 if pull_request and (acl_repo_name == pull_request.target_repo.repo_name):
206 206
207 207 # Get file system path to shadow repository.
208 208 workspace_id = PullRequestModel()._workspace_id(pull_request)
209 209 vcs_repo_name = pull_request.target_repo.get_shadow_repository_path(workspace_id)
210 210
211 211 # Store names for later usage.
212 212 self.vcs_repo_name = vcs_repo_name
213 213 self.acl_repo_name = acl_repo_name
214 214 self.is_shadow_repo = True
215 215
216 216 log.debug('Setting all VCS repository names: %s', {
217 217 'acl_repo_name': self.acl_repo_name,
218 218 'url_repo_name': self.url_repo_name,
219 219 'vcs_repo_name': self.vcs_repo_name,
220 220 })
221 221
222 222 @property
223 223 def scm_app(self):
224 224 custom_implementation = self.config['vcs.scm_app_implementation']
225 225 if custom_implementation == 'http':
226 226 log.debug('Using HTTP implementation of scm app.')
227 227 scm_app_impl = scm_app_http
228 228 else:
229 229 log.debug('Using custom implementation of scm_app: "{}"'.format(
230 230 custom_implementation))
231 231 scm_app_impl = importlib.import_module(custom_implementation)
232 232 return scm_app_impl
233 233
234 234 def _get_by_id(self, repo_name):
235 235 """
236 236 Gets a special pattern _<ID> from the clone URL and tries to replace
237 237 it with a repository name, to support non-changeable _<ID> URLs
238 238 """
239 239
240 240 data = repo_name.split('/')
241 241 if len(data) >= 2:
242 242 from rhodecode.model.repo import RepoModel
243 243 by_id_match = RepoModel().get_repo_by_id(repo_name)
244 244 if by_id_match:
245 245 data[1] = by_id_match.repo_name
246 246
247 247 return safe_str('/'.join(data))
248 248
249 249 def _invalidate_cache(self, repo_name):
250 250 """
251 251 Sets cache for this repository, for invalidation on next access
252 252
253 253 :param repo_name: full repo name, also a cache key
254 254 """
255 255 ScmModel().mark_for_invalidation(repo_name)
256 256
257 257 def is_valid_and_existing_repo(self, repo_name, base_path, scm_type):
258 258 db_repo = Repository.get_by_repo_name(repo_name)
259 259 if not db_repo:
260 260 log.debug('Repository `%s` not found inside the database.',
261 261 repo_name)
262 262 return False
263 263
264 264 if db_repo.repo_type != scm_type:
265 265 log.warning(
266 266 'Repository `%s` has incorrect scm_type, expected %s got %s',
267 267 repo_name, scm_type, db_repo.repo_type)
268 268 return False
269 269
270 270 config = db_repo._config
271 271 config.set('extensions', 'largefiles', '')
272 272 return is_valid_repo(
273 273 repo_name, base_path,
274 274 explicit_scm=scm_type, expect_scm=scm_type, config=config)
275 275
276 276 def valid_and_active_user(self, user):
277 277 """
278 278 Checks that the given user is not empty and, if it is an actual
279 279 user object, whether it is active.
280 280
281 281 :param user: user object or None
282 282 :return: boolean
283 283 """
284 284 if user is None:
285 285 return False
286 286
287 287 elif user.active:
288 288 return True
289 289
290 290 return False
291 291
292 292 @property
293 293 def is_shadow_repo_dir(self):
294 294 return os.path.isdir(self.vcs_repo_name)
295 295
296 296 def _check_permission(self, action, user, auth_user, repo_name, ip_addr=None,
297 297 plugin_id='', plugin_cache_active=False, cache_ttl=0):
298 298 """
299 299 Checks permissions using the action (push/pull), user and repository
300 300 name. If plugin_cache_active and cache_ttl are set, it will use the
301 301 plugin which authenticated the user to store the cached permissions
302 302 result for cache_ttl seconds
303 303
304 304 :param action: push or pull action
305 305 :param user: user instance
306 306 :param repo_name: repository name
307 307 """
308 308
309 309 log.debug('AUTH_CACHE_TTL for permissions `%s` active: %s (TTL: %s)',
310 310 plugin_id, plugin_cache_active, cache_ttl)
311 311
312 312 user_id = user.user_id
313 313 cache_namespace_uid = 'cache_user_auth.{}'.format(user_id)
314 314 region = rc_cache.get_or_create_region('cache_perms', cache_namespace_uid)
315 315
316 316 @region.conditional_cache_on_arguments(namespace=cache_namespace_uid,
317 317 expiration_time=cache_ttl,
318 318 condition=plugin_cache_active)
319 319 def compute_perm_vcs(
320 320 cache_name, plugin_id, action, user_id, repo_name, ip_addr):
321 321
322 322 log.debug('auth: calculating permission access now...')
323 323 # check IP
324 324 inherit = user.inherit_default_permissions
325 325 ip_allowed = AuthUser.check_ip_allowed(
326 326 user_id, ip_addr, inherit_from_default=inherit)
327 327 if ip_allowed:
328 328 log.info('Access for IP:%s allowed', ip_addr)
329 329 else:
330 330 return False
331 331
332 332 if action == 'push':
333 333 perms = ('repository.write', 'repository.admin')
334 334 if not HasPermissionAnyMiddleware(*perms)(auth_user, repo_name):
335 335 return False
336 336
337 337 else:
338 338 # any other action need at least read permission
339 339 perms = (
340 340 'repository.read', 'repository.write', 'repository.admin')
341 341 if not HasPermissionAnyMiddleware(*perms)(auth_user, repo_name):
342 342 return False
343 343
344 344 return True
345 345
346 346 start = time.time()
347 347 log.debug('Running plugin `%s` permissions check', plugin_id)
348 348
349 349 # for environ based auth, password can be empty, but then the validation is
350 350 # on the server that fills in the env data needed for authentication
351 351 perm_result = compute_perm_vcs(
352 352 'vcs_permissions', plugin_id, action, user.user_id, repo_name, ip_addr)
353 353
354 354 auth_time = time.time() - start
355 355 log.debug('Permissions for plugin `%s` completed in %.4fs, '
356 356 'expiration time of fetched cache %.1fs.',
357 357 plugin_id, auth_time, cache_ttl)
358 358
359 359 return perm_result
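conditional_cache_on_arguments is RhodeCode's extension of dogpile.cache that adds a condition flag to bypass caching entirely; the underlying idea, memoizing an expensive check keyed by its full argument tuple for a TTL, can be sketched with stock dogpile.cache (the memory backend, 30s TTL and user id 42 below are assumptions for the example):

    from dogpile.cache import make_region

    region = make_region().configure(
        'dogpile.cache.memory', expiration_time=30)  # assumed 30s TTL

    @region.cache_on_arguments(namespace='cache_user_auth.42')
    def compute_perm_vcs(cache_name, plugin_id, action, user_id,
                         repo_name, ip_addr):
        print('cache miss - running the expensive permission checks')
        return True  # stand-in for the real IP + repo permission logic

    args = ('vcs_permissions', 'rc', 'pull', 42, 'RepoGroup/MyRepo', '10.0.0.1')
    compute_perm_vcs(*args)   # miss: prints and caches the result
    compute_perm_vcs(*args)   # hit: served from the region, no print
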
360 360
361 361 def _get_http_scheme(self, environ):
362 362 try:
363 363 return environ['wsgi.url_scheme']
364 364 except Exception:
365 365 log.exception('Failed to read http scheme')
366 366 return 'http'
367 367
368 368 def _check_ssl(self, environ, start_response):
369 369 """
370 370 Checks the SSL-required flag and returns False if SSL is required
371 371 but not present; True otherwise
372 372 """
373 373 org_proto = environ['wsgi._org_proto']
374 374 # if SSL is required but the request came over plain HTTP, it's a bad request!
375 375 require_ssl = str2bool(self.repo_vcs_config.get('web', 'push_ssl'))
376 376 if require_ssl and org_proto == 'http':
377 377 log.debug(
378 378 'Bad request: detected protocol is `%s` and '
379 379 'SSL/HTTPS is required.', org_proto)
380 380 return False
381 381 return True
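A quick illustration of the two inputs the check combines (values hypothetical):

    environ = {'wsgi._org_proto': 'http'}   # what the client actually used
    require_ssl = True                      # web.push_ssl from the repo config
    ok = not (require_ssl and environ['wsgi._org_proto'] == 'http')
    print(ok)  # False -> the caller responds with 406 HTTPNotAcceptable
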
382 382
383 383 def _get_default_cache_ttl(self):
384 384 # take AUTH_CACHE_TTL from the `rhodecode` auth plugin
385 385 plugin = loadplugin('egg:rhodecode-enterprise-ce#rhodecode')
386 386 plugin_settings = plugin.get_settings()
387 387 plugin_cache_active, cache_ttl = plugin.get_ttl_cache(
388 388 plugin_settings) or (False, 0)
389 389 return plugin_cache_active, cache_ttl
390 390
391 391 def __call__(self, environ, start_response):
392 392 try:
393 393 return self._handle_request(environ, start_response)
394 394 except Exception:
395 395 log.exception("Exception while handling request")
396 396 appenlight.track_exception(environ)
397 397 return HTTPInternalServerError()(environ, start_response)
398 398 finally:
399 399 meta.Session.remove()
400 400
401 401 def _handle_request(self, environ, start_response):
402 402 if not self._check_ssl(environ, start_response):
403 403 reason = ('SSL required, while RhodeCode was unable '
404 404 'to detect this as an SSL request')
405 405 log.debug('User not allowed to proceed, %s', reason)
406 406 return HTTPNotAcceptable(reason)(environ, start_response)
407 407
408 408 if not self.url_repo_name:
409 409 log.warning('Repository name is empty: %s', self.url_repo_name)
410 410 # failed to get repo name, we fail now
411 411 return HTTPNotFound()(environ, start_response)
412 412 log.debug('Extracted repo name is %s', self.url_repo_name)
413 413
414 414 ip_addr = get_ip_addr(environ)
415 415 user_agent = get_user_agent(environ)
416 416 username = None
417 417
418 418 # skip passing error to error controller
419 419 environ['pylons.status_code_redirect'] = True
420 420
421 421 # ======================================================================
422 422 # GET ACTION PULL or PUSH
423 423 # ======================================================================
424 424 action = self._get_action(environ)
425 425
426 426 # ======================================================================
427 427 # Check if this is a request to a shadow repository of a pull request.
428 428 # In this case only pull action is allowed.
429 429 # ======================================================================
430 430 if self.is_shadow_repo and action != 'pull':
431 431 reason = 'Only pull action is allowed for shadow repositories.'
432 432 log.debug('User not allowed to proceed, %s', reason)
433 433 return HTTPNotAcceptable(reason)(environ, start_response)
434 434
435 435 # Check if the shadow repo actually exists, in case someone refers
436 436 # to it after it has been deleted because of a successful merge.
437 437 if self.is_shadow_repo and not self.is_shadow_repo_dir:
438 438 log.debug(
439 439 'Shadow repo detected, and shadow repo dir `%s` is missing',
440 440 self.vcs_repo_name)
441 441 return HTTPNotFound()(environ, start_response)
442 442
443 443 # ======================================================================
444 444 # CHECK ANONYMOUS PERMISSION
445 445 # ======================================================================
446 446 detect_force_push = False
447 447 check_branch_perms = False
448 448 if action in ['pull', 'push']:
449 449 user_obj = anonymous_user = User.get_default_user()
450 450 auth_user = user_obj.AuthUser()
451 451 username = anonymous_user.username
452 452 if anonymous_user.active:
453 453 plugin_cache_active, cache_ttl = self._get_default_cache_ttl()
454 454 # ONLY check permissions if the user is activated
455 455 anonymous_perm = self._check_permission(
456 456 action, anonymous_user, auth_user, self.acl_repo_name, ip_addr,
457 457 plugin_id='anonymous_access',
458 458 plugin_cache_active=plugin_cache_active,
459 459 cache_ttl=cache_ttl,
460 460 )
461 461 else:
462 462 anonymous_perm = False
463 463
464 464 if not anonymous_user.active or not anonymous_perm:
465 465 if not anonymous_user.active:
466 466 log.debug('Anonymous access is disabled, running '
467 467 'authentication')
468 468
469 469 if not anonymous_perm:
470 470 log.debug('Not enough credentials to access this '
471 471 'repository as anonymous user')
472 472
473 473 username = None
474 474 # ==============================================================
475 475 # DEFAULT PERM FAILED OR ANONYMOUS ACCESS IS DISABLED SO WE
476 476 # NEED TO AUTHENTICATE AND ASK FOR AUTH USER PERMISSIONS
477 477 # ==============================================================
478 478
479 479 # try to auth based on environ, container auth methods
480 480 log.debug('Running PRE-AUTH for container based authentication')
481 481 pre_auth = authenticate(
482 482 '', '', environ, VCS_TYPE, registry=self.registry,
483 483 acl_repo_name=self.acl_repo_name)
484 484 if pre_auth and pre_auth.get('username'):
485 485 username = pre_auth['username']
486 486 log.debug('PRE-AUTH got %s as username', username)
487 487 if pre_auth:
488 488 log.debug('PRE-AUTH successful from %s',
489 489 pre_auth.get('auth_data', {}).get('_plugin'))
490 490
491 491 # If not authenticated by the container, run basic auth;
492 492 # before that, inject the calling repo_name for special scope checks
493 493 self.authenticate.acl_repo_name = self.acl_repo_name
494 494
495 495 plugin_cache_active, cache_ttl = False, 0
496 496 plugin = None
497 497 if not username:
498 498 self.authenticate.realm = self.authenticate.get_rc_realm()
499 499
500 500 try:
501 501 auth_result = self.authenticate(environ)
502 502 except (UserCreationError, NotAllowedToCreateUserError) as e:
503 503 log.error(e)
504 504 reason = safe_str(e)
505 505 return HTTPNotAcceptable(reason)(environ, start_response)
506 506
507 507 if isinstance(auth_result, dict):
508 508 AUTH_TYPE.update(environ, 'basic')
509 509 REMOTE_USER.update(environ, auth_result['username'])
510 510 username = auth_result['username']
511 511 plugin = auth_result.get('auth_data', {}).get('_plugin')
512 512 log.info(
513 513 'MAIN-AUTH successful for user `%s` from %s plugin',
514 514 username, plugin)
515 515
516 516 plugin_cache_active, cache_ttl = auth_result.get(
517 517 'auth_data', {}).get('_ttl_cache') or (False, 0)
518 518 else:
519 519 return auth_result.wsgi_application(environ, start_response)
520 520
521 521 # ==============================================================
522 522 # CHECK PERMISSIONS FOR THIS REQUEST USING GIVEN USERNAME
523 523 # ==============================================================
524 524 user = User.get_by_username(username)
525 525 if not self.valid_and_active_user(user):
526 526 return HTTPForbidden()(environ, start_response)
527 527 username = user.username
528 528 user_id = user.user_id
529 529
530 530 # check user attributes for password change flag
531 531 user_obj = user
532 532 auth_user = user_obj.AuthUser()
533 533 if user_obj and user_obj.username != User.DEFAULT_USER and \
534 534 user_obj.user_data.get('force_password_change'):
535 535 reason = 'password change required'
536 536 log.debug('User not allowed to authenticate, %s', reason)
537 537 return HTTPNotAcceptable(reason)(environ, start_response)
538 538
539 539 # check permissions for this repository
540 540 perm = self._check_permission(
541 541 action, user, auth_user, self.acl_repo_name, ip_addr,
542 542 plugin, plugin_cache_active, cache_ttl)
543 543 if not perm:
544 544 return HTTPForbidden()(environ, start_response)
545 545 environ['rc_auth_user_id'] = user_id
546 546
547 547 if action == 'push':
548 548 perms = auth_user.get_branch_permissions(self.acl_repo_name)
549 549 if perms:
550 550 check_branch_perms = True
551 551 detect_force_push = True
552 552
553 553 # extras are injected into UI object and later available
554 554 # in hooks executed by RhodeCode
555 555 check_locking = _should_check_locking(environ.get('QUERY_STRING'))
556 556
557 557 extras = vcs_operation_context(
558 558 environ, repo_name=self.acl_repo_name, username=username,
559 559 action=action, scm=self.SCM, check_locking=check_locking,
560 560 is_shadow_repo=self.is_shadow_repo, check_branch_perms=check_branch_perms,
561 561 detect_force_push=detect_force_push
562 562 )
563 563
564 564 # ======================================================================
565 565 # REQUEST HANDLING
566 566 # ======================================================================
567 567 repo_path = os.path.join(
568 568 safe_str(self.base_path), safe_str(self.vcs_repo_name))
569 569 log.debug('Repository path is %s', repo_path)
570 570
571 571 fix_PATH()
572 572
573 573 log.info(
574 574 '%s action on %s repo "%s" by "%s" from %s %s',
575 575 action, self.SCM, safe_str(self.url_repo_name),
576 576 safe_str(username), ip_addr, user_agent)
577 577
578 578 return self._generate_vcs_response(
579 579 environ, start_response, repo_path, extras, action)
580 580
581 581 @initialize_generator
582 582 def _generate_vcs_response(
583 583 self, environ, start_response, repo_path, extras, action):
584 584 """
585 585 Returns a generator for the response content.
586 586
587 587 This method is implemented as a generator, so that it can trigger
588 588 the cache validation after all content sent back to the client. It
589 589 also handles the locking exceptions which will be triggered when
590 590 the first chunk is produced by the underlying WSGI application.
591 591 """
592 592 txn_id = ''
593 593 if 'CONTENT_LENGTH' in environ and environ['REQUEST_METHOD'] == 'MERGE':
594 594 # case for SVN: we want to re-use the callback daemon port,
595 595 # so we use the txn_id; for this we peek at the body and then
596 596 # restore it as wsgi.input
597 597 data = environ['wsgi.input'].read()
598 environ['wsgi.input'] = StringIO(data)
598 environ['wsgi.input'] = io.BytesIO(data)
599 599 txn_id = extract_svn_txn_id(self.acl_repo_name, data)
600 600
601 601 callback_daemon, extras = self._prepare_callback_daemon(
602 602 extras, environ, action, txn_id=txn_id)
603 603 log.debug('HOOKS extras is %s', extras)
604 604
605 605 http_scheme = self._get_http_scheme(environ)
606 606
607 607 config = self._create_config(extras, self.acl_repo_name, scheme=http_scheme)
608 608 app = self._create_wsgi_app(repo_path, self.url_repo_name, config)
609 609 with callback_daemon:
610 610 app.rc_extras = extras
611 611
612 612 try:
613 613 response = app(environ, start_response)
614 614 finally:
615 615 # This statement works together with the decorator
616 616 # "initialize_generator" above. The decorator ensures that
617 617 # we hit the first yield statement before the generator is
618 618 # returned back to the WSGI server. This is needed to
619 619 # ensure that the call to "app" above triggers the
620 620 # needed callback to "start_response" before the
621 621 # generator is actually used.
622 622 yield "__init__"
623 623
624 624 # iter content
625 625 for chunk in response:
626 626 yield chunk
627 627
628 628 try:
629 629 # invalidate cache on push
630 630 if action == 'push':
631 631 self._invalidate_cache(self.url_repo_name)
632 632 finally:
633 633 meta.Session.remove()
634 634
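The body peek near the top of this generator relies on a simple pattern: read wsgi.input once (it yields bytes under WSGI, hence io.BytesIO rather than io.StringIO), then replace it with a seekable buffer so the wrapped application can read the same body again. In isolation, with names as in the method above:

    import io

    body = environ['wsgi.input'].read()               # consumes the stream
    txn_id = extract_svn_txn_id(acl_repo_name, body)  # peek for the SVN txn id
    environ['wsgi.input'] = io.BytesIO(body)          # restore for the app
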
635 635 def _get_repository_name(self, environ):
636 636 """Get repository name out of the environmnent
637 637
638 638 :param environ: WSGI environment
639 639 """
640 640 raise NotImplementedError()
641 641
642 642 def _get_action(self, environ):
643 643 """Map request commands into a pull or push command.
644 644
645 645 :param environ: WSGI environment
646 646 """
647 647 raise NotImplementedError()
648 648
649 649 def _create_wsgi_app(self, repo_path, repo_name, config):
650 650 """Return the WSGI app that will finally handle the request."""
651 651 raise NotImplementedError()
652 652
653 653 def _create_config(self, extras, repo_name, scheme='http'):
654 654 """Create a safe config representation."""
655 655 raise NotImplementedError()
656 656
657 657 def _should_use_callback_daemon(self, extras, environ, action):
658 658 if extras.get('is_shadow_repo'):
659 659 # we don't want to execute hooks, and callback daemon for shadow repos
660 660 return False
661 661 return True
662 662
663 663 def _prepare_callback_daemon(self, extras, environ, action, txn_id=None):
664 664 direct_calls = vcs_settings.HOOKS_DIRECT_CALLS
665 665 if not self._should_use_callback_daemon(extras, environ, action):
666 666 # disable callback daemon for actions that don't require it
667 667 direct_calls = True
668 668
669 669 return prepare_callback_daemon(
670 670 extras, protocol=vcs_settings.HOOKS_PROTOCOL,
671 671 host=vcs_settings.HOOKS_HOST, use_direct_calls=direct_calls, txn_id=txn_id)
672 672
673 673
674 674 def _should_check_locking(query_string):
674 674 # this is kind of hacky, but due to how mercurial handles client-server
675 675 # communication, the server sees bookmark, phase and obsolescence-marker
676 676 # operations on commit as separate transactions, and we don't want to
677 677 # check locking on those
679 679 return query_string not in ['cmd=listkeys']
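Usage is a plain membership test on the raw query string:

    print(_should_check_locking('cmd=listkeys'))  # False - skip lock checks
    print(_should_check_locking('cmd=unbundle'))  # True  - enforce locking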