python3: fixed various code issues...
super-admin
r4973:5e52ba1a default
@@ -1,578 +1,578 b''
1 1 # -*- coding: utf-8 -*-
2 2
3 3 # Copyright (C) 2011-2020 RhodeCode GmbH
4 4 #
5 5 # This program is free software: you can redistribute it and/or modify
6 6 # it under the terms of the GNU Affero General Public License, version 3
7 7 # (only), as published by the Free Software Foundation.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU Affero General Public License
15 15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 16 #
17 17 # This program is dual-licensed. If you wish to learn more about the
18 18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20 20
21 21 import itertools
22 22 import logging
23 23 import sys
24 24 import types
25 25 import fnmatch
26 26
27 27 import decorator
28 28 import venusian
29 29 from collections import OrderedDict
30 30
31 31 from pyramid.exceptions import ConfigurationError
32 32 from pyramid.renderers import render
33 33 from pyramid.response import Response
34 34 from pyramid.httpexceptions import HTTPNotFound
35 35
36 36 from rhodecode.api.exc import (
37 37 JSONRPCBaseError, JSONRPCError, JSONRPCForbidden, JSONRPCValidationError)
38 38 from rhodecode.apps._base import TemplateArgs
39 39 from rhodecode.lib.auth import AuthUser
40 40 from rhodecode.lib.base import get_ip_addr, attach_context_attributes
41 41 from rhodecode.lib.exc_tracking import store_exception
42 42 from rhodecode.lib.ext_json import json
43 43 from rhodecode.lib.utils2 import safe_str
44 44 from rhodecode.lib.plugins.utils import get_plugin_settings
45 45 from rhodecode.model.db import User, UserApiKeys
46 46
47 47 log = logging.getLogger(__name__)
48 48
49 49 DEFAULT_RENDERER = 'jsonrpc_renderer'
50 50 DEFAULT_URL = '/_admin/apiv2'
51 51
52 52
53 53 def find_methods(jsonrpc_methods, pattern):
54 54 matches = OrderedDict()
55 55 if not isinstance(pattern, (list, tuple)):
56 56 pattern = [pattern]
57 57
58 58 for single_pattern in pattern:
59 59 for method_name, method in jsonrpc_methods.items():
60 60 if fnmatch.fnmatch(method_name, single_pattern):
61 61 matches[method_name] = method
62 62 return matches
63 63
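Note: find_methods backs both the "similar methods" suggestion in exception_view and the get_method API call. A minimal standalone sketch of the same fnmatch-based matching (method names are made up for illustration):

import fnmatch
from collections import OrderedDict

methods = OrderedDict([
    ('comment_commit', object()),
    ('comment_pull_request', object()),
    ('get_repo', object()),
])

def find(pattern):
    # accept a single glob or a list of globs, like find_methods above
    patterns = pattern if isinstance(pattern, (list, tuple)) else [pattern]
    return OrderedDict(
        (name, method)
        for p in patterns
        for name, method in methods.items() if fnmatch.fnmatch(name, p))

print(list(find('*comment*')))  # ['comment_commit', 'comment_pull_request']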
64 64
65 65 class ExtJsonRenderer(object):
66 66 """
67 67 Custom renderer that makes use of our ext_json lib
68 68
69 69 """
70 70
71 71 def __init__(self, serializer=json.dumps, **kw):
72 72 """ Any keyword arguments will be passed to the ``serializer``
73 73 function."""
74 74 self.serializer = serializer
75 75 self.kw = kw
76 76
77 77 def __call__(self, info):
78 78 """ Returns a plain JSON-encoded string with content-type
79 79 ``application/json``. The content-type may be overridden by
80 80 setting ``request.response.content_type``."""
81 81
82 82 def _render(value, system):
83 83 request = system.get('request')
84 84 if request is not None:
85 85 response = request.response
86 86 ct = response.content_type
87 87 if ct == response.default_content_type:
88 88 response.content_type = 'application/json'
89 89
90 90 return self.serializer(value, **self.kw)
91 91
92 92 return _render
93 93
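Note: ExtJsonRenderer follows Pyramid's renderer-factory protocol: the factory is registered once under a name and invoked with renderer info; the callable it returns renders (value, system) pairs per request. A minimal sketch of the same shape using plain json (the class name here is made up):

import json

class PlainJsonRenderer(object):
    def __init__(self, serializer=json.dumps, **kw):
        self.serializer = serializer
        self.kw = kw

    def __call__(self, info):
        def _render(value, system):
            # a real renderer would also set response.content_type here
            return self.serializer(value, **self.kw)
        return _render

# registered via config.add_renderer('jsonrpc_renderer', PlainJsonRenderer(indent=4))
render_fn = PlainJsonRenderer(indent=4)(info=None)
print(render_fn({'id': 1, 'result': 'ok', 'error': None}, system={}))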
94 94
95 95 def jsonrpc_response(request, result):
96 96 rpc_id = getattr(request, 'rpc_id', None)
97 97 response = request.response
98 98
99 99 # store content_type before render is called
100 100 ct = response.content_type
101 101
102 102 ret_value = ''
103 103 if rpc_id:
104 104 ret_value = {
105 105 'id': rpc_id,
106 106 'result': result,
107 107 'error': None,
108 108 }
109 109
110 110 # fetch deprecation warnings, and store them inside the result
111 111 deprecation = getattr(request, 'rpc_deprecation', None)
112 112 if deprecation:
113 113 ret_value['DEPRECATION_WARNING'] = deprecation
114 114
115 115 raw_body = render(DEFAULT_RENDERER, ret_value, request=request)
116 116 response.body = safe_str(raw_body, response.charset)
117 117
118 118 if ct == response.default_content_type:
119 119 response.content_type = 'application/json'
120 120
121 121 return response
122 122
123 123
124 124 def jsonrpc_error(request, message, retid=None, code=None, headers=None):
125 125 """
126 126 Generate a Response object with a JSON-RPC error body
127 127
128 128 :param code:
129 129 :param retid:
130 130 :param message:
131 131 """
132 132 err_dict = {'id': retid, 'result': None, 'error': message}
133 133 body = render(DEFAULT_RENDERER, err_dict, request=request).encode('utf-8')
134 134
135 135 return Response(
136 136 body=body,
137 137 status=code,
138 138 content_type='application/json',
139 139 headerlist=headers
140 140 )
141 141
142 142
143 143 def exception_view(exc, request):
144 144 rpc_id = getattr(request, 'rpc_id', None)
145 145
146 146 if isinstance(exc, JSONRPCError):
147 147 fault_message = safe_str(exc.message)
148 148 log.debug('json-rpc error rpc_id:%s "%s"', rpc_id, fault_message)
149 149 elif isinstance(exc, JSONRPCValidationError):
150 150 colander_exc = exc.colander_exception
151 151 # TODO(marcink): think of a nicer way to serialize errors?
152 152 fault_message = colander_exc.asdict()
153 153 log.debug('json-rpc colander error rpc_id:%s "%s"', rpc_id, fault_message)
154 154 elif isinstance(exc, JSONRPCForbidden):
155 155 fault_message = 'Access was denied to this resource.'
156 156 log.warning('json-rpc forbidden call rpc_id:%s "%s"', rpc_id, fault_message)
157 157 elif isinstance(exc, HTTPNotFound):
158 158 method = request.rpc_method
159 159 log.debug('json-rpc method `%s` not found in list of '
160 160 'api calls: %s, rpc_id:%s',
161 161 method, request.registry.jsonrpc_methods.keys(), rpc_id)
162 162
163 163 similar = 'none'
164 164 try:
165 165 similar_patterns = ['*{}*'.format(x) for x in method.split('_')]
166 166 similar_found = find_methods(
167 167 request.registry.jsonrpc_methods, similar_patterns)
168 168 similar = ', '.join(similar_found.keys()) or similar
169 169 except Exception:
170 170 # make the whole above block safe
171 171 pass
172 172
173 173 fault_message = "No such method: {}. Similar methods: {}".format(
174 174 method, similar)
175 175 else:
176 176 fault_message = 'undefined error'
177 177 exc_info = exc.exc_info()
178 178 store_exception(id(exc_info), exc_info, prefix='rhodecode-api')
179 179
180 180 statsd = request.registry.statsd
181 181 if statsd:
182 182 exc_type = "{}.{}".format(exc.__class__.__module__, exc.__class__.__name__)
183 183 statsd.incr('rhodecode_exception_total',
184 184 tags=["exc_source:api", "type:{}".format(exc_type)])
185 185
186 186 return jsonrpc_error(request, fault_message, rpc_id)
187 187
188 188
189 189 def request_view(request):
190 190 """
191 191 Main request handling method. It handles all logic to call a specific
192 192 exposed method
193 193 """
194 194 # cython compatible inspect
195 195 from rhodecode.config.patches import inspect_getargspec
196 196 inspect = inspect_getargspec()
197 197
198 198 # check if we can find this session using api_key, get_by_auth_token
199 199 # search not expired tokens only
200 200 try:
201 201 api_user = User.get_by_auth_token(request.rpc_api_key)
202 202
203 203 if api_user is None:
204 204 return jsonrpc_error(
205 205 request, retid=request.rpc_id, message='Invalid API KEY')
206 206
207 207 if not api_user.active:
208 208 return jsonrpc_error(
209 209 request, retid=request.rpc_id,
210 210 message='Request from this user not allowed')
211 211
212 212 # check if we are allowed to use this IP
213 213 auth_u = AuthUser(
214 214 api_user.user_id, request.rpc_api_key, ip_addr=request.rpc_ip_addr)
215 215 if not auth_u.ip_allowed:
216 216 return jsonrpc_error(
217 217 request, retid=request.rpc_id,
218 218 message='Request from IP:%s not allowed' % (
219 219 request.rpc_ip_addr,))
220 220 else:
221 221 log.info('Access for IP:%s allowed', request.rpc_ip_addr)
222 222
223 223 # register our auth-user
224 224 request.rpc_user = auth_u
225 225 request.environ['rc_auth_user_id'] = auth_u.user_id
226 226
227 227 # now check if token is valid for API
228 228 auth_token = request.rpc_api_key
229 229 token_match = api_user.authenticate_by_token(
230 230 auth_token, roles=[UserApiKeys.ROLE_API])
231 231 invalid_token = not token_match
232 232
233 233 log.debug('Checking if API KEY is valid with proper role')
234 234 if invalid_token:
235 235 return jsonrpc_error(
236 236 request, retid=request.rpc_id,
237 237 message='API KEY invalid, or has a bad role for an API call')
238 238
239 239 except Exception:
240 240 log.exception('Error on API AUTH')
241 241 return jsonrpc_error(
242 242 request, retid=request.rpc_id, message='Invalid API KEY')
243 243
244 244 method = request.rpc_method
245 245 func = request.registry.jsonrpc_methods[method]
246 246
247 247 # now that we have a method, add request._req_params to
248 248 # self.kwargs and dispatch control to WSGIController
249 249 argspec = inspect.getargspec(func)
250 250 arglist = argspec[0]
251 251 defaults = list(map(type, argspec[3] or []))  # list(), since reversed() needs a sequence on py3
252 252 default_empty = type(NotImplemented)  # sentinel; types.NotImplementedType is missing before py3.10
253 253
254 254 # kw arguments required by this method
255 func_kwargs = dict(itertools.izip_longest(
255 func_kwargs = dict(itertools.zip_longest(
256 256 reversed(arglist), reversed(defaults), fillvalue=default_empty))
257 257
258 258 # This attribute will need to be first param of a method that uses
259 259 # api_key, which is translated to instance of user at that name
260 260 user_var = 'apiuser'
261 261 request_var = 'request'
262 262
263 263 for arg in [user_var, request_var]:
264 264 if arg not in arglist:
265 265 return jsonrpc_error(
266 266 request,
267 267 retid=request.rpc_id,
268 268 message='This method [%s] does not support '
269 269 'required parameter `%s`' % (func.__name__, arg))
270 270
271 271 # get our arglist and check if we provided them as args
272 272 for arg, default in func_kwargs.items():
273 273 if arg in [user_var, request_var]:
274 274 # user_var and request_var are pre-hardcoded parameters and we
275 275 # don't need to do any translation
276 276 continue
277 277
278 278 # skip the required param check if its default value is
279 279 # NotImplementedType (default_empty)
280 280 if default == default_empty and arg not in request.rpc_params:
281 281 return jsonrpc_error(
282 282 request,
283 283 retid=request.rpc_id,
284 284 message=('Missing non-optional `%s` arg in JSON DATA' % arg)
285 285 )
286 286
287 287 # sanitize extra passed arguments
288 288 for k in list(request.rpc_params.keys()):  # copy keys; py3 dict views don't support slicing
289 289 if k not in func_kwargs:
290 290 del request.rpc_params[k]
291 291
292 292 call_params = request.rpc_params
293 293 call_params.update({
294 294 'request': request,
295 295 'apiuser': auth_u
296 296 })
297 297
298 298 # register some common functions for usage
299 299 attach_context_attributes(TemplateArgs(), request, request.rpc_user.user_id)
300 300
301 301 statsd = request.registry.statsd
302 302
303 303 try:
304 304 ret_value = func(**call_params)
305 305 resp = jsonrpc_response(request, ret_value)
306 306 if statsd:
307 307 statsd.incr('rhodecode_api_call_success_total')
308 308 return resp
309 309 except JSONRPCBaseError:
310 310 raise
311 311 except Exception:
312 312 log.exception('Unhandled exception occurred on api call: %s', func)
313 313 exc_info = sys.exc_info()
314 314 exc_id, exc_type_name = store_exception(
315 315 id(exc_info), exc_info, prefix='rhodecode-api')
316 316 error_headers = [('RhodeCode-Exception-Id', str(exc_id)),
317 317 ('RhodeCode-Exception-Type', str(exc_type_name))]
318 318 err_resp = jsonrpc_error(
319 319 request, retid=request.rpc_id, message='Internal server error',
320 320 headers=error_headers)
321 321 if statsd:
322 322 statsd.incr('rhodecode_api_call_fail_total')
323 323 return err_resp
324 324
325 325
326 326 def setup_request(request):
327 327 """
328 328 Parse a JSON-RPC request body. It's used inside the predicates method
329 329 to validate and bootstrap requests for usage in rpc calls.
330 330
331 331 We need to raise JSONRPCError here if we want to return errors back to
332 332 the user.
333 333 """
334 334
335 335 log.debug('Executing setup request: %r', request)
336 336 request.rpc_ip_addr = get_ip_addr(request.environ)
337 337 # TODO(marcink): deprecate GET at some point
338 338 if request.method not in ['POST', 'GET']:
339 339 log.debug('unsupported request method "%s"', request.method)
340 340 raise JSONRPCError(
341 341 'unsupported request method "%s". Please use POST' % request.method)
342 342
343 343 if 'CONTENT_LENGTH' not in request.environ:
344 344 log.debug("No Content-Length")
345 345 raise JSONRPCError("Empty body, No Content-Length in request")
346 346
347 347 else:
348 348 length = int(request.environ['CONTENT_LENGTH'] or 0)  # environ values are strings; int() makes the == 0 check below work
349 349 log.debug('Content-Length: %s', length)
350 350
351 351 if length == 0:
352 352 log.debug("Content-Length is 0")
353 353 raise JSONRPCError("Content-Length is 0")
354 354
355 355 raw_body = request.body
356 356 log.debug("Loading JSON body now")
357 357 try:
358 358 json_body = json.loads(raw_body)
359 359 except ValueError as e:
360 360 # catch JSON errors here
361 361 raise JSONRPCError("JSON parse error ERR:%s RAW:%r" % (e, raw_body))
362 362
363 363 request.rpc_id = json_body.get('id')
364 364 request.rpc_method = json_body.get('method')
365 365
366 366 # check required base parameters
367 367 try:
368 368 api_key = json_body.get('api_key')
369 369 if not api_key:
370 370 api_key = json_body.get('auth_token')
371 371
372 372 if not api_key:
373 373 raise KeyError('api_key or auth_token')
374 374
375 375 # TODO(marcink): support passing in token in request header
376 376
377 377 request.rpc_api_key = api_key
378 378 request.rpc_id = json_body['id']
379 379 request.rpc_method = json_body['method']
380 380 request.rpc_params = json_body['args'] \
381 381 if isinstance(json_body['args'], dict) else {}
382 382
383 383 log.debug('method: %s, params: %.10240r', request.rpc_method, request.rpc_params)
384 384 except KeyError as e:
385 385 raise JSONRPCError('Incorrect JSON data. Missing %s' % e)
386 386
387 387 log.debug('setup complete, now handling method:%s rpcid:%s',
388 388 request.rpc_method, request.rpc_id, )
389 389
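Note: a request body that setup_request above will accept looks like the sketch below; the token value is a placeholder, and 'args' must be a dict or it is silently replaced with {}:

payload = {
    "id": 1,
    "auth_token": "<secret_auth_token>",  # the legacy "api_key" field also works
    "method": "get_server_info",
    "args": {},
}
# POSTed as JSON to the route registered under DEFAULT_URL ('/_admin/apiv2')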
390 390
391 391 class RoutePredicate(object):
392 392 def __init__(self, val, config):
393 393 self.val = val
394 394
395 395 def text(self):
396 396 return 'jsonrpc route = %s' % self.val
397 397
398 398 phash = text
399 399
400 400 def __call__(self, info, request):
401 401 if self.val:
402 402 # potentially setup and bootstrap our call
403 403 setup_request(request)
404 404
405 405 # Always return True so that even if it isn't a valid RPC it
406 406 # will fall through to the underlying handlers like notfound_view
407 407 return True
408 408
409 409
410 410 class NotFoundPredicate(object):
411 411 def __init__(self, val, config):
412 412 self.val = val
413 413 self.methods = config.registry.jsonrpc_methods
414 414
415 415 def text(self):
416 416 return 'jsonrpc method not found = {}.'.format(self.val)
417 417
418 418 phash = text
419 419
420 420 def __call__(self, info, request):
421 421 return hasattr(request, 'rpc_method')
422 422
423 423
424 424 class MethodPredicate(object):
425 425 def __init__(self, val, config):
426 426 self.method = val
427 427
428 428 def text(self):
429 429 return 'jsonrpc method = %s' % self.method
430 430
431 431 phash = text
432 432
433 433 def __call__(self, context, request):
434 434 # we need to explicitly return False here, so pyramid doesn't try to
435 435 # execute our view directly. We need our main handler to execute things
436 436 return getattr(request, 'rpc_method') == self.method
437 437
438 438
439 439 def add_jsonrpc_method(config, view, **kwargs):
440 440 # pop the method name
441 441 method = kwargs.pop('method', None)
442 442
443 443 if method is None:
444 444 raise ConfigurationError(
445 445 'Cannot register a JSON-RPC method without specifying the "method"')
446 446
447 447 # we define a custom predicate to enable detection of conflicting methods;
448 448 # these predicates are a kind of "translation" from the decorator variables
449 449 # to internal predicate names
450 450
451 451 kwargs['jsonrpc_method'] = method
452 452
453 453 # register our view into global view store for validation
454 454 config.registry.jsonrpc_methods[method] = view
455 455
456 456 # we're using our main request_view handler here, so each method
457 457 # has a unified handler for itself
458 458 config.add_view(request_view, route_name='apiv2', **kwargs)
459 459
460 460
461 461 class jsonrpc_method(object):
462 462 """
463 463 decorator that works similarly to Pyramid's @view_config decorator,
464 464 but tailored for our JSON-RPC
465 465 """
466 466
467 467 venusian = venusian # for testing injection
468 468
469 469 def __init__(self, method=None, **kwargs):
470 470 self.method = method
471 471 self.kwargs = kwargs
472 472
473 473 def __call__(self, wrapped):
474 474 kwargs = self.kwargs.copy()
475 475 kwargs['method'] = self.method or wrapped.__name__
476 476 depth = kwargs.pop('_depth', 0)
477 477
478 478 def callback(context, name, ob):
479 479 config = context.config.with_package(info.module)
480 480 config.add_jsonrpc_method(view=ob, **kwargs)
481 481
482 482 info = venusian.attach(wrapped, callback, category='pyramid',
483 483 depth=depth + 1)
484 484 if info.scope == 'class':
485 485 # ensure that attr is set if decorating a class method
486 486 kwargs.setdefault('attr', wrapped.__name__)
487 487
488 488 kwargs['_info'] = info.codeinfo # fbo action_method
489 489 return wrapped
490 490
491 491
492 492 class jsonrpc_deprecated_method(object):
493 493 """
494 494 Marks a method as deprecated: adds a log.warning and injects a special
495 495 key into the request variable to mark the method as deprecated.
496 496 Also injects a special docstring that extract_docs will catch to mark
497 497 the method as deprecated.
498 498
499 499 :param use_method: specify which method should be used instead of
500 500 the decorated one
501 501
502 502 Use like::
503 503
504 504 @jsonrpc_method()
505 505 @jsonrpc_deprecated_method(use_method='new_func', deprecated_at_version='3.0.0')
506 506 def old_func(request, apiuser, arg1, arg2):
507 507 ...
508 508 """
509 509
510 510 def __init__(self, use_method, deprecated_at_version):
511 511 self.use_method = use_method
512 512 self.deprecated_at_version = deprecated_at_version
513 513 self.deprecated_msg = ''
514 514
515 515 def __call__(self, func):
516 516 self.deprecated_msg = 'Please use method `{method}` instead.'.format(
517 517 method=self.use_method)
518 518
519 519 docstring = """\n
520 520 .. deprecated:: {version}
521 521
522 522 {deprecation_message}
523 523
524 524 {original_docstring}
525 525 """
526 526 func.__doc__ = docstring.format(
527 527 version=self.deprecated_at_version,
528 528 deprecation_message=self.deprecated_msg,
529 529 original_docstring=func.__doc__)
530 530 return decorator.decorator(self.__wrapper, func)
531 531
532 532 def __wrapper(self, func, *fargs, **fkwargs):
533 533 log.warning('DEPRECATED API CALL on function %s, please '
534 534 'use `%s` instead', func, self.use_method)
535 535 # alter the function docstring to mark it as deprecated; this is picked up
536 536 # by the fabric file that generates the API docs.
537 537 result = func(*fargs, **fkwargs)
538 538
539 539 request = fargs[0]
540 540 request.rpc_deprecation = 'DEPRECATED METHOD ' + self.deprecated_msg
541 541 return result
542 542
543 543
544 544 def add_api_methods(config):
545 545 from rhodecode.api.views import (
546 546 deprecated_api, gist_api, pull_request_api, repo_api, repo_group_api,
547 547 server_api, search_api, testing_api, user_api, user_group_api)
548 548
549 549 config.scan('rhodecode.api.views')
550 550
551 551
552 552 def includeme(config):
553 553 plugin_module = 'rhodecode.api'
554 554 plugin_settings = get_plugin_settings(
555 555 plugin_module, config.registry.settings)
556 556
557 557 if not hasattr(config.registry, 'jsonrpc_methods'):
558 558 config.registry.jsonrpc_methods = OrderedDict()
559 559
560 560 # match filter by given method only
561 561 config.add_view_predicate('jsonrpc_method', MethodPredicate)
562 562 config.add_view_predicate('jsonrpc_method_not_found', NotFoundPredicate)
563 563
564 564 config.add_renderer(DEFAULT_RENDERER, ExtJsonRenderer(
565 565 serializer=json.dumps, indent=4))
566 566 config.add_directive('add_jsonrpc_method', add_jsonrpc_method)
567 567
568 568 config.add_route_predicate(
569 569 'jsonrpc_call', RoutePredicate)
570 570
571 571 config.add_route(
572 572 'apiv2', plugin_settings.get('url', DEFAULT_URL), jsonrpc_call=True)
573 573
574 574 # register some exception handling view
575 575 config.add_view(exception_view, context=JSONRPCBaseError)
576 576 config.add_notfound_view(exception_view, jsonrpc_method_not_found=True)
577 577
578 578 add_api_methods(config)
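Note: the only functional change in this hunk is itertools.izip_longest -> itertools.zip_longest for Python 3. A standalone sketch of the arg/default pairing that request_view builds with it (the API method here is made up; stdlib getfullargspec stands in for the patched getargspec):

import inspect
import itertools

def comment_commit(request, apiuser, commit_id, message, status=None):
    """Hypothetical API method, used only for illustration."""

MISSING = type(NotImplemented)  # sentinel for args without defaults

spec = inspect.getfullargspec(comment_commit)
defaults = list(spec.defaults or [])  # list(), since reversed() needs a sequence on py3
func_kwargs = dict(itertools.zip_longest(
    reversed(spec.args), reversed(defaults), fillvalue=MISSING))

required = sorted(a for a, d in func_kwargs.items() if d is MISSING)
print(required)  # ['apiuser', 'commit_id', 'message', 'request']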
@@ -1,419 +1,419 b''
1 1 # -*- coding: utf-8 -*-
2 2
3 3 # Copyright (C) 2011-2020 RhodeCode GmbH
4 4 #
5 5 # This program is free software: you can redistribute it and/or modify
6 6 # it under the terms of the GNU Affero General Public License, version 3
7 7 # (only), as published by the Free Software Foundation.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU Affero General Public License
15 15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 16 #
17 17 # This program is dual-licensed. If you wish to learn more about the
18 18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20 20
21 21 import logging
22 22 import itertools
23 23 import base64
24 24
25 25 from rhodecode.api import (
26 26 jsonrpc_method, JSONRPCError, JSONRPCForbidden, find_methods)
27 27
28 28 from rhodecode.api.utils import (
29 29 Optional, OAttr, has_superadmin_permission, get_user_or_error)
30 30 from rhodecode.lib.utils import repo2db_mapper
31 31 from rhodecode.lib import system_info
32 32 from rhodecode.lib import user_sessions
33 33 from rhodecode.lib import exc_tracking
34 34 from rhodecode.lib.ext_json import json
35 35 from rhodecode.lib.utils2 import safe_int
36 36 from rhodecode.model.db import UserIpMap
37 37 from rhodecode.model.scm import ScmModel
38 38 from rhodecode.model.settings import VcsSettingsModel
39 39 from rhodecode.apps.file_store import utils
40 40 from rhodecode.apps.file_store.exceptions import FileNotAllowedException, \
41 41 FileOverSizeException
42 42
43 43 log = logging.getLogger(__name__)
44 44
45 45
46 46 @jsonrpc_method()
47 47 def get_server_info(request, apiuser):
48 48 """
49 49 Returns the |RCE| server information.
50 50
51 51 This includes the running version of |RCE| and all installed
52 52 packages. This command takes the following options:
53 53
54 54 :param apiuser: This is filled automatically from the |authtoken|.
55 55 :type apiuser: AuthUser
56 56
57 57 Example output:
58 58
59 59 .. code-block:: bash
60 60
61 61 id : <id_given_in_input>
62 62 result : {
63 63 'modules': [<module name>,...]
64 64 'py_version': <python version>,
65 65 'platform': <platform type>,
66 66 'rhodecode_version': <rhodecode version>
67 67 }
68 68 error : null
69 69 """
70 70
71 71 if not has_superadmin_permission(apiuser):
72 72 raise JSONRPCForbidden()
73 73
74 74 server_info = ScmModel().get_server_info(request.environ)
75 75 # rhodecode-index requires those
76 76
77 77 server_info['index_storage'] = server_info['search']['value']['location']
78 78 server_info['storage'] = server_info['storage']['value']['path']
79 79
80 80 return server_info
81 81
82 82
83 83 @jsonrpc_method()
84 84 def get_repo_store(request, apiuser):
85 85 """
86 86 Returns the |RCE| repository storage information.
87 87
88 88 :param apiuser: This is filled automatically from the |authtoken|.
89 89 :type apiuser: AuthUser
90 90
91 91 Example output:
92 92
93 93 .. code-block:: bash
94 94
95 95 id : <id_given_in_input>
96 96 result : {
97 97 'path': <path_to_repository_storage>
98 98 }
99 99 error : null
103 103 """
104 104
105 105 if not has_superadmin_permission(apiuser):
106 106 raise JSONRPCForbidden()
107 107
108 108 path = VcsSettingsModel().get_repos_location()
109 109 return {"path": path}
110 110
111 111
112 112 @jsonrpc_method()
113 113 def get_ip(request, apiuser, userid=Optional(OAttr('apiuser'))):
114 114 """
115 115 Displays the IP Address as seen from the |RCE| server.
116 116
117 117 * This command displays the IP Address, as well as all the defined IP
118 118 addresses for the specified user. If the ``userid`` is not set, the
119 119 data returned is for the user calling the method.
120 120
121 121 This command can only be run using an |authtoken| with admin rights to
122 122 the specified repository.
123 123
124 124 This command takes the following options:
125 125
126 126 :param apiuser: This is filled automatically from |authtoken|.
127 127 :type apiuser: AuthUser
128 128 :param userid: Sets the userid for which associated IP Address data
129 129 is returned.
130 130 :type userid: Optional(str or int)
131 131
132 132 Example output:
133 133
134 134 .. code-block:: bash
135 135
136 136 id : <id_given_in_input>
137 137 result : {
138 138 "server_ip_addr": "<ip_from_clien>",
139 139 "user_ips": [
140 140 {
141 141 "ip_addr": "<ip_with_mask>",
142 142 "ip_range": ["<start_ip>", "<end_ip>"],
143 143 },
144 144 ...
145 145 ]
146 146 }
147 147
148 148 """
149 149 if not has_superadmin_permission(apiuser):
150 150 raise JSONRPCForbidden()
151 151
152 152 userid = Optional.extract(userid, evaluate_locals=locals())
153 153 userid = getattr(userid, 'user_id', userid)
154 154
155 155 user = get_user_or_error(userid)
156 156 ips = UserIpMap.query().filter(UserIpMap.user == user).all()
157 157 return {
158 158 'server_ip_addr': request.rpc_ip_addr,
159 159 'user_ips': ips
160 160 }
161 161
162 162
163 163 @jsonrpc_method()
164 164 def rescan_repos(request, apiuser, remove_obsolete=Optional(False)):
165 165 """
166 166 Triggers a rescan of the specified repositories.
167 167
168 168 * If the ``remove_obsolete`` option is set, it also deletes repositories
169 169 that are found in the database but not on the file system, i.e. it
170 170 cleans up so-called "zombie" repositories.
171 171
172 172 This command can only be run using an |authtoken| with admin rights to
173 173 the specified repository.
174 174
175 175 This command takes the following options:
176 176
177 177 :param apiuser: This is filled automatically from the |authtoken|.
178 178 :type apiuser: AuthUser
179 179 :param remove_obsolete: Deletes repositories from the database that
180 180 are not found on the filesystem.
181 181 :type remove_obsolete: Optional(``True`` | ``False``)
182 182
183 183 Example output:
184 184
185 185 .. code-block:: bash
186 186
187 187 id : <id_given_in_input>
188 188 result : {
189 189 'added': [<added repository name>,...]
190 190 'removed': [<removed repository name>,...]
191 191 }
192 192 error : null
193 193
194 194 Example error output:
195 195
196 196 .. code-block:: bash
197 197
198 198 id : <id_given_in_input>
199 199 result : null
200 200 error : {
201 201 'Error occurred during rescan repositories action'
202 202 }
203 203
204 204 """
205 205 if not has_superadmin_permission(apiuser):
206 206 raise JSONRPCForbidden()
207 207
208 208 try:
209 209 rm_obsolete = Optional.extract(remove_obsolete)
210 210 added, removed = repo2db_mapper(ScmModel().repo_scan(),
211 211 remove_obsolete=rm_obsolete)
212 212 return {'added': added, 'removed': removed}
213 213 except Exception:
214 214 log.exception('Failed to run repo rescan')
215 215 raise JSONRPCError(
216 216 'Error occurred during rescan repositories action'
217 217 )
218 218
219 219
220 220 @jsonrpc_method()
221 221 def cleanup_sessions(request, apiuser, older_then=Optional(60)):
222 222 """
223 223 Triggers a session cleanup action.
224 224
225 225 If the ``older_then`` option is set, only sessions that haven't been
226 226 accessed in the given number of days will be removed.
227 227
228 228 This command can only be run using an |authtoken| with admin rights to
229 229 the specified repository.
230 230
231 231 This command takes the following options:
232 232
233 233 :param apiuser: This is filled automatically from the |authtoken|.
234 234 :type apiuser: AuthUser
235 235 :param older_then: Deletes sessions that haven't been accessed
236 236 in the given number of days.
237 237 :type older_then: Optional(int)
238 238
239 239 Example output:
240 240
241 241 .. code-block:: bash
242 242
243 243 id : <id_given_in_input>
244 244 result: {
245 245 "backend": "<type of backend>",
246 246 "sessions_removed": <number_of_removed_sessions>
247 247 }
248 248 error : null
249 249
250 250 Example error output:
251 251
252 252 .. code-block:: bash
253 253
254 254 id : <id_given_in_input>
255 255 result : null
256 256 error : {
257 257 'Error occurred during session cleanup'
258 258 }
259 259
260 260 """
261 261 if not has_superadmin_permission(apiuser):
262 262 raise JSONRPCForbidden()
263 263
264 264 older_then = safe_int(Optional.extract(older_then)) or 60
265 265 older_than_seconds = 60 * 60 * 24 * older_then
266 266
267 267 config = system_info.rhodecode_config().get_value()['value']['config']
268 268 session_model = user_sessions.get_session_handler(
269 269 config.get('beaker.session.type', 'memory'))(config)
270 270
271 271 backend = session_model.SESSION_TYPE
272 272 try:
273 273 cleaned = session_model.clean_sessions(
274 274 older_than_seconds=older_than_seconds)
275 275 return {'sessions_removed': cleaned, 'backend': backend}
276 276 except user_sessions.CleanupCommand as msg:
277 277 return {'cleanup_command': msg.message, 'backend': backend}
278 278 except Exception:
279 279 log.exception('Failed session cleanup')
280 280 raise JSONRPCError(
281 281 'Error occurred during session cleanup'
282 282 )
283 283
284 284
285 285 @jsonrpc_method()
286 286 def get_method(request, apiuser, pattern=Optional('*')):
287 287 """
288 288 Returns a list of all available API methods. By default the match pattern
289 289 is "*", but any other pattern can be specified, e.g. *comment* returns all
290 290 methods with "comment" in their name. If just a single method is matched,
291 291 the returned data will also include the method specification
292 292
293 293 This command can only be run using an |authtoken| with admin rights to
294 294 the specified repository.
295 295
296 296 This command takes the following options:
297 297
298 298 :param apiuser: This is filled automatically from the |authtoken|.
299 299 :type apiuser: AuthUser
300 300 :param pattern: pattern to match method names against
301 301 :type pattern: Optional("*")
302 302
303 303 Example output:
304 304
305 305 .. code-block:: bash
306 306
307 307 id : <id_given_in_input>
308 308 "result": [
309 309 "changeset_comment",
310 310 "comment_pull_request",
311 311 "comment_commit"
312 312 ]
313 313 error : null
314 314
315 315 .. code-block:: bash
316 316
317 317 id : <id_given_in_input>
318 318 "result": [
319 319 "comment_commit",
320 320 {
321 321 "apiuser": "<RequiredType>",
322 322 "comment_type": "<Optional:u'note'>",
323 323 "commit_id": "<RequiredType>",
324 324 "message": "<RequiredType>",
325 325 "repoid": "<RequiredType>",
326 326 "request": "<RequiredType>",
327 327 "resolves_comment_id": "<Optional:None>",
328 328 "status": "<Optional:None>",
329 329 "userid": "<Optional:<OptionalAttr:apiuser>>"
330 330 }
331 331 ]
332 332 error : null
333 333 """
334 334 from rhodecode.config.patches import inspect_getargspec
335 335 inspect = inspect_getargspec()
336 336
337 337 if not has_superadmin_permission(apiuser):
338 338 raise JSONRPCForbidden()
339 339
340 340 pattern = Optional.extract(pattern)
341 341
342 342 matches = find_methods(request.registry.jsonrpc_methods, pattern)
343 343
344 344 args_desc = []
345 345 if len(matches) == 1:
346 346 func = matches[list(matches.keys())[0]]  # keys() views aren't indexable on py3
347 347
348 348 argspec = inspect.getargspec(func)
349 349 arglist = argspec[0]
350 350 defaults = list(map(repr, argspec[3] or []))  # list(), since reversed() needs a sequence on py3
351 351
352 352 default_empty = '<RequiredType>'
353 353
354 354 # kw arguments required by this method
355 func_kwargs = dict(itertools.izip_longest(
355 func_kwargs = dict(itertools.zip_longest(
356 356 reversed(arglist), reversed(defaults), fillvalue=default_empty))
357 357 args_desc.append(func_kwargs)
358 358
359 359 return list(matches.keys()) + args_desc  # dict_keys can't be concatenated on py3
360 360
361 361
362 362 @jsonrpc_method()
363 363 def store_exception(request, apiuser, exc_data_json, prefix=Optional('rhodecode')):
364 364 """
365 365 Stores a sent exception inside the built-in exception tracker of the |RCE| server.
366 366
367 367 This command can only be run using an |authtoken| with admin rights to
368 368 the specified repository.
369 369
370 370 This command takes the following options:
371 371
372 372 :param apiuser: This is filled automatically from the |authtoken|.
373 373 :type apiuser: AuthUser
374 374
375 375 :param exc_data_json: JSON data with the exception, e.g.
376 376 {"exc_traceback": "Value `1` is not allowed", "exc_type_name": "ValueError"}
377 377 :type exc_data_json: JSON data
378 378
379 379 :param prefix: prefix for the error type, e.g. 'rhodecode', 'vcsserver', 'rhodecode-tools'
380 380 :type prefix: Optional("rhodecode")
381 381
382 382 Example output:
383 383
384 384 .. code-block:: bash
385 385
386 386 id : <id_given_in_input>
387 387 "result": {
388 388 "exc_id": 139718459226384,
389 389 "exc_url": "http://localhost:8080/_admin/settings/exceptions/139718459226384"
390 390 }
391 391 error : null
392 392 """
393 393 if not has_superadmin_permission(apiuser):
394 394 raise JSONRPCForbidden()
395 395
396 396 prefix = Optional.extract(prefix)
397 397 exc_id = exc_tracking.generate_id()
398 398
399 399 try:
400 400 exc_data = json.loads(exc_data_json)
401 401 except Exception:
402 402 log.error('Failed to parse JSON: %r', exc_data_json)
403 403 raise JSONRPCError('Failed to parse JSON data from exc_data_json field. '
404 404 'Please make sure it contains valid JSON.')
405 405
406 406 try:
407 407 exc_traceback = exc_data['exc_traceback']
408 408 exc_type_name = exc_data['exc_type_name']
409 409 except KeyError as err:
410 410 raise JSONRPCError('Missing exc_traceback, or exc_type_name '
411 411 'in exc_data_json field. Missing: {}'.format(err))
412 412
413 413 exc_tracking._store_exception(
414 414 exc_id=exc_id, exc_traceback=exc_traceback,
415 415 exc_type_name=exc_type_name, prefix=prefix)
416 416
417 417 exc_url = request.route_url(
418 418 'admin_settings_exception_tracker_show', exception_id=exc_id)
419 419 return {'exc_id': exc_id, 'exc_url': exc_url}
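Note: a hedged example of calling one of these methods over HTTP with the requests library; host, port and token are placeholders, and the endpoint path matches DEFAULT_URL from the module above:

import json
import requests

url = 'http://localhost:8080/_admin/apiv2'
payload = {
    'id': 1,
    'auth_token': '<superadmin_auth_token>',
    'method': 'get_method',
    'args': {'pattern': '*comment*'},
}
response = requests.post(url, data=json.dumps(payload),
                         headers={'content-type': 'application/json'})
print(response.json())  # {'id': 1, 'result': [...], 'error': None}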
@@ -1,479 +1,479 b''
1 1 # -*- coding: utf-8 -*-
2 2
3 3 # Copyright (C) 2016-2020 RhodeCode GmbH
4 4 #
5 5 # This program is free software: you can redistribute it and/or modify
6 6 # it under the terms of the GNU Affero General Public License, version 3
7 7 # (only), as published by the Free Software Foundation.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU Affero General Public License
15 15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 16 #
17 17 # This program is dual-licensed. If you wish to learn more about the
18 18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20 20
21 21 import re
22 22 import logging
23 23 import formencode
24 24 import formencode.htmlfill
25 25 import datetime
26 26 from pyramid.interfaces import IRoutesMapper
27 27
28 28 from pyramid.httpexceptions import HTTPFound
29 29 from pyramid.renderers import render
30 30 from pyramid.response import Response
31 31
32 32 from rhodecode.apps._base import BaseAppView, DataGridAppView
33 33 from rhodecode.apps.ssh_support import SshKeyFileChangeEvent
34 34 from rhodecode import events
35 35
36 36 from rhodecode.lib import helpers as h
37 37 from rhodecode.lib.auth import (
38 38 LoginRequired, HasPermissionAllDecorator, CSRFRequired)
39 39 from rhodecode.lib.utils2 import aslist, safe_unicode
40 40 from rhodecode.model.db import (
41 41 or_, coalesce, User, UserIpMap, UserSshKeys)
42 42 from rhodecode.model.forms import (
43 43 ApplicationPermissionsForm, ObjectPermissionsForm, UserPermissionsForm)
44 44 from rhodecode.model.meta import Session
45 45 from rhodecode.model.permission import PermissionModel
46 46 from rhodecode.model.settings import SettingsModel
47 47
48 48
49 49 log = logging.getLogger(__name__)
50 50
51 51
52 52 class AdminPermissionsView(BaseAppView, DataGridAppView):
53 53 def load_default_context(self):
54 54 c = self._get_local_tmpl_context()
55 55 PermissionModel().set_global_permission_choices(
56 56 c, gettext_translator=self.request.translate)
57 57 return c
58 58
59 59 @LoginRequired()
60 60 @HasPermissionAllDecorator('hg.admin')
61 61 def permissions_application(self):
62 62 c = self.load_default_context()
63 63 c.active = 'application'
64 64
65 65 c.user = User.get_default_user(refresh=True)
66 66
67 67 app_settings = c.rc_config
68 68
69 69 defaults = {
70 70 'anonymous': c.user.active,
71 71 'default_register_message': app_settings.get(
72 72 'rhodecode_register_message')
73 73 }
74 74 defaults.update(c.user.get_default_perms())
75 75
76 76 data = render('rhodecode:templates/admin/permissions/permissions.mako',
77 77 self._get_template_context(c), self.request)
78 78 html = formencode.htmlfill.render(
79 79 data,
80 80 defaults=defaults,
81 81 encoding="UTF-8",
82 82 force_defaults=False
83 83 )
84 84 return Response(html)
85 85
86 86 @LoginRequired()
87 87 @HasPermissionAllDecorator('hg.admin')
88 88 @CSRFRequired()
89 89 def permissions_application_update(self):
90 90 _ = self.request.translate
91 91 c = self.load_default_context()
92 92 c.active = 'application'
93 93
94 94 _form = ApplicationPermissionsForm(
95 95 self.request.translate,
96 96 [x[0] for x in c.register_choices],
97 97 [x[0] for x in c.password_reset_choices],
98 98 [x[0] for x in c.extern_activate_choices])()
99 99
100 100 try:
101 101 form_result = _form.to_python(dict(self.request.POST))
102 102 form_result.update({'perm_user_name': User.DEFAULT_USER})
103 103 PermissionModel().update_application_permissions(form_result)
104 104
105 105 settings = [
106 106 ('register_message', 'default_register_message'),
107 107 ]
108 108 for setting, form_key in settings:
109 109 sett = SettingsModel().create_or_update_setting(
110 110 setting, form_result[form_key])
111 111 Session().add(sett)
112 112
113 113 Session().commit()
114 114 h.flash(_('Application permissions updated successfully'),
115 115 category='success')
116 116
117 117 except formencode.Invalid as errors:
118 118 defaults = errors.value
119 119
120 120 data = render(
121 121 'rhodecode:templates/admin/permissions/permissions.mako',
122 122 self._get_template_context(c), self.request)
123 123 html = formencode.htmlfill.render(
124 124 data,
125 125 defaults=defaults,
126 126 errors=errors.error_dict or {},
127 127 prefix_error=False,
128 128 encoding="UTF-8",
129 129 force_defaults=False
130 130 )
131 131 return Response(html)
132 132
133 133 except Exception:
134 134 log.exception("Exception during update of permissions")
135 135 h.flash(_('Error occurred during update of permissions'),
136 136 category='error')
137 137
138 138 affected_user_ids = [User.get_default_user_id()]
139 139 PermissionModel().trigger_permission_flush(affected_user_ids)
140 140
141 141 raise HTTPFound(h.route_path('admin_permissions_application'))
142 142
143 143 @LoginRequired()
144 144 @HasPermissionAllDecorator('hg.admin')
145 145 def permissions_objects(self):
146 146 c = self.load_default_context()
147 147 c.active = 'objects'
148 148
149 149 c.user = User.get_default_user(refresh=True)
150 150 defaults = {}
151 151 defaults.update(c.user.get_default_perms())
152 152
153 153 data = render(
154 154 'rhodecode:templates/admin/permissions/permissions.mako',
155 155 self._get_template_context(c), self.request)
156 156 html = formencode.htmlfill.render(
157 157 data,
158 158 defaults=defaults,
159 159 encoding="UTF-8",
160 160 force_defaults=False
161 161 )
162 162 return Response(html)
163 163
164 164 @LoginRequired()
165 165 @HasPermissionAllDecorator('hg.admin')
166 166 @CSRFRequired()
167 167 def permissions_objects_update(self):
168 168 _ = self.request.translate
169 169 c = self.load_default_context()
170 170 c.active = 'objects'
171 171
172 172 _form = ObjectPermissionsForm(
173 173 self.request.translate,
174 174 [x[0] for x in c.repo_perms_choices],
175 175 [x[0] for x in c.group_perms_choices],
176 176 [x[0] for x in c.user_group_perms_choices],
177 177 )()
178 178
179 179 try:
180 180 form_result = _form.to_python(dict(self.request.POST))
181 181 form_result.update({'perm_user_name': User.DEFAULT_USER})
182 182 PermissionModel().update_object_permissions(form_result)
183 183
184 184 Session().commit()
185 185 h.flash(_('Object permissions updated successfully'),
186 186 category='success')
187 187
188 188 except formencode.Invalid as errors:
189 189 defaults = errors.value
190 190
191 191 data = render(
192 192 'rhodecode:templates/admin/permissions/permissions.mako',
193 193 self._get_template_context(c), self.request)
194 194 html = formencode.htmlfill.render(
195 195 data,
196 196 defaults=defaults,
197 197 errors=errors.error_dict or {},
198 198 prefix_error=False,
199 199 encoding="UTF-8",
200 200 force_defaults=False
201 201 )
202 202 return Response(html)
203 203 except Exception:
204 204 log.exception("Exception during update of permissions")
205 205 h.flash(_('Error occurred during update of permissions'),
206 206 category='error')
207 207
208 208 affected_user_ids = [User.get_default_user_id()]
209 209 PermissionModel().trigger_permission_flush(affected_user_ids)
210 210
211 211 raise HTTPFound(h.route_path('admin_permissions_object'))
212 212
213 213 @LoginRequired()
214 214 @HasPermissionAllDecorator('hg.admin')
215 215 def permissions_branch(self):
216 216 c = self.load_default_context()
217 217 c.active = 'branch'
218 218
219 219 c.user = User.get_default_user(refresh=True)
220 220 defaults = {}
221 221 defaults.update(c.user.get_default_perms())
222 222
223 223 data = render(
224 224 'rhodecode:templates/admin/permissions/permissions.mako',
225 225 self._get_template_context(c), self.request)
226 226 html = formencode.htmlfill.render(
227 227 data,
228 228 defaults=defaults,
229 229 encoding="UTF-8",
230 230 force_defaults=False
231 231 )
232 232 return Response(html)
233 233
234 234 @LoginRequired()
235 235 @HasPermissionAllDecorator('hg.admin')
236 236 def permissions_global(self):
237 237 c = self.load_default_context()
238 238 c.active = 'global'
239 239
240 240 c.user = User.get_default_user(refresh=True)
241 241 defaults = {}
242 242 defaults.update(c.user.get_default_perms())
243 243
244 244 data = render(
245 245 'rhodecode:templates/admin/permissions/permissions.mako',
246 246 self._get_template_context(c), self.request)
247 247 html = formencode.htmlfill.render(
248 248 data,
249 249 defaults=defaults,
250 250 encoding="UTF-8",
251 251 force_defaults=False
252 252 )
253 253 return Response(html)
254 254
255 255 @LoginRequired()
256 256 @HasPermissionAllDecorator('hg.admin')
257 257 @CSRFRequired()
258 258 def permissions_global_update(self):
259 259 _ = self.request.translate
260 260 c = self.load_default_context()
261 261 c.active = 'global'
262 262
263 263 _form = UserPermissionsForm(
264 264 self.request.translate,
265 265 [x[0] for x in c.repo_create_choices],
266 266 [x[0] for x in c.repo_create_on_write_choices],
267 267 [x[0] for x in c.repo_group_create_choices],
268 268 [x[0] for x in c.user_group_create_choices],
269 269 [x[0] for x in c.fork_choices],
270 270 [x[0] for x in c.inherit_default_permission_choices])()
271 271
272 272 try:
273 273 form_result = _form.to_python(dict(self.request.POST))
274 274 form_result.update({'perm_user_name': User.DEFAULT_USER})
275 275 PermissionModel().update_user_permissions(form_result)
276 276
277 277 Session().commit()
278 278 h.flash(_('Global permissions updated successfully'),
279 279 category='success')
280 280
281 281 except formencode.Invalid as errors:
282 282 defaults = errors.value
283 283
284 284 data = render(
285 285 'rhodecode:templates/admin/permissions/permissions.mako',
286 286 self._get_template_context(c), self.request)
287 287 html = formencode.htmlfill.render(
288 288 data,
289 289 defaults=defaults,
290 290 errors=errors.error_dict or {},
291 291 prefix_error=False,
292 292 encoding="UTF-8",
293 293 force_defaults=False
294 294 )
295 295 return Response(html)
296 296 except Exception:
297 297 log.exception("Exception during update of permissions")
298 298 h.flash(_('Error occurred during update of permissions'),
299 299 category='error')
300 300
301 301 affected_user_ids = [User.get_default_user_id()]
302 302 PermissionModel().trigger_permission_flush(affected_user_ids)
303 303
304 304 raise HTTPFound(h.route_path('admin_permissions_global'))
305 305
306 306 @LoginRequired()
307 307 @HasPermissionAllDecorator('hg.admin')
308 308 def permissions_ips(self):
309 309 c = self.load_default_context()
310 310 c.active = 'ips'
311 311
312 312 c.user = User.get_default_user(refresh=True)
313 313 c.user_ip_map = (
314 314 UserIpMap.query().filter(UserIpMap.user == c.user).all())
315 315
316 316 return self._get_template_context(c)
317 317
318 318 @LoginRequired()
319 319 @HasPermissionAllDecorator('hg.admin')
320 320 def permissions_overview(self):
321 321 c = self.load_default_context()
322 322 c.active = 'perms'
323 323
324 324 c.user = User.get_default_user(refresh=True)
325 325 c.perm_user = c.user.AuthUser()
326 326 return self._get_template_context(c)
327 327
328 328 @LoginRequired()
329 329 @HasPermissionAllDecorator('hg.admin')
330 330 def auth_token_access(self):
331 331 from rhodecode import CONFIG
332 332
333 333 c = self.load_default_context()
334 334 c.active = 'auth_token_access'
335 335
336 336 c.user = User.get_default_user(refresh=True)
337 337 c.perm_user = c.user.AuthUser()
338 338
339 339 mapper = self.request.registry.queryUtility(IRoutesMapper)
340 340 c.view_data = []
341 341
342 _argument_prog = re.compile('\{(.*?)\}|:\((.*)\)')
342 _argument_prog = re.compile(r'\{(.*?)\}|:\((.*)\)')
343 343 introspector = self.request.registry.introspector
344 344
345 345 view_intr = {}
346 346 for view_data in introspector.get_category('views'):
347 347 intr = view_data['introspectable']
348 348
349 349 if 'route_name' in intr and intr['attr']:
350 350 view_intr[intr['route_name']] = '{}:{}'.format(
351 351 str(intr['derived_callable'].__name__), intr['attr']
352 352 )
353 353
354 354 c.whitelist_key = 'api_access_controllers_whitelist'
355 355 c.whitelist_file = CONFIG.get('__file__')
356 356 whitelist_views = aslist(
357 357 CONFIG.get(c.whitelist_key), sep=',')
358 358
359 359 for route_info in mapper.get_routes():
360 360 if not route_info.name.startswith('__'):
361 361 routepath = route_info.pattern
362 362
363 363 def replace(matchobj):
364 364 if matchobj.group(1):
365 365 return "{%s}" % matchobj.group(1).split(':')[0]
366 366 else:
367 367 return "{%s}" % matchobj.group(2)
368 368
369 369 routepath = _argument_prog.sub(replace, routepath)
370 370
371 371 if not routepath.startswith('/'):
372 372 routepath = '/' + routepath
373 373
374 374 view_fqn = view_intr.get(route_info.name, 'NOT AVAILABLE')
375 375 active = view_fqn in whitelist_views
376 376 c.view_data.append((route_info.name, view_fqn, routepath, active))
377 377
378 378 c.whitelist_views = whitelist_views
379 379 return self._get_template_context(c)
380 380
381 381 def ssh_enabled(self):
382 382 return self.request.registry.settings.get(
383 383 'ssh.generate_authorized_keyfile')
384 384
385 385 @LoginRequired()
386 386 @HasPermissionAllDecorator('hg.admin')
387 387 def ssh_keys(self):
388 388 c = self.load_default_context()
389 389 c.active = 'ssh_keys'
390 390 c.ssh_enabled = self.ssh_enabled()
391 391 return self._get_template_context(c)
392 392
393 393 @LoginRequired()
394 394 @HasPermissionAllDecorator('hg.admin')
395 395 def ssh_keys_data(self):
396 396 _ = self.request.translate
397 397 self.load_default_context()
398 398 column_map = {
399 399 'fingerprint': 'ssh_key_fingerprint',
400 400 'username': User.username
401 401 }
402 402 draw, start, limit = self._extract_chunk(self.request)
403 403 search_q, order_by, order_dir = self._extract_ordering(
404 404 self.request, column_map=column_map)
405 405
406 406 ssh_keys_data_total_count = UserSshKeys.query()\
407 407 .count()
408 408
409 409 # json generate
410 410 base_q = UserSshKeys.query().join(UserSshKeys.user)
411 411
412 412 if search_q:
413 413 like_expression = u'%{}%'.format(safe_unicode(search_q))
414 414 base_q = base_q.filter(or_(
415 415 User.username.ilike(like_expression),
416 416 UserSshKeys.ssh_key_fingerprint.ilike(like_expression),
417 417 ))
418 418
419 419 users_data_total_filtered_count = base_q.count()
420 420
421 421 sort_col = self._get_order_col(order_by, UserSshKeys)
422 422 if sort_col:
423 423 if order_dir == 'asc':
424 424 # handle null values properly to order by NULL last
425 425 if order_by in ['created_on']:
426 426 sort_col = coalesce(sort_col, datetime.date.max)
427 427 sort_col = sort_col.asc()
428 428 else:
429 429 # handle null values properly to order by NULL last
430 430 if order_by in ['created_on']:
431 431 sort_col = coalesce(sort_col, datetime.date.min)
432 432 sort_col = sort_col.desc()
433 433
434 434 base_q = base_q.order_by(sort_col)
435 435 base_q = base_q.offset(start).limit(limit)
436 436
437 437 ssh_keys = base_q.all()
438 438
439 439 ssh_keys_data = []
440 440 for ssh_key in ssh_keys:
441 441 ssh_keys_data.append({
442 442 "username": h.gravatar_with_user(self.request, ssh_key.user.username),
443 443 "fingerprint": ssh_key.ssh_key_fingerprint,
444 444 "description": ssh_key.description,
445 445 "created_on": h.format_date(ssh_key.created_on),
446 446 "accessed_on": h.format_date(ssh_key.accessed_on),
447 447 "action": h.link_to(
448 448 _('Edit'), h.route_path('edit_user_ssh_keys',
449 449 user_id=ssh_key.user.user_id))
450 450 })
451 451
452 452 data = ({
453 453 'draw': draw,
454 454 'data': ssh_keys_data,
455 455 'recordsTotal': ssh_keys_data_total_count,
456 456 'recordsFiltered': users_data_total_filtered_count,
457 457 })
458 458
459 459 return data
460 460
461 461 @LoginRequired()
462 462 @HasPermissionAllDecorator('hg.admin')
463 463 @CSRFRequired()
464 464 def ssh_keys_update(self):
465 465 _ = self.request.translate
466 466 self.load_default_context()
467 467
468 468 ssh_enabled = self.ssh_enabled()
469 469 key_file = self.request.registry.settings.get(
470 470 'ssh.authorized_keys_file_path')
471 471 if ssh_enabled:
472 472 events.trigger(SshKeyFileChangeEvent(), self.request.registry)
473 473 h.flash(_('Updated SSH keys file: {}').format(key_file),
474 474 category='success')
475 475 else:
476 476 h.flash(_('SSH key support is disabled in .ini file'),
477 477 category='warning')
478 478
479 479 raise HTTPFound(h.route_path('admin_permissions_ssh_keys'))
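Note: besides the raw-string regex fix, auth_token_access normalizes Pyramid route patterns so that '{name:regex}' placeholders lose their regex part. A standalone check of that substitution:

import re

_argument_prog = re.compile(r'\{(.*?)\}|:\((.*)\)')

def replace(matchobj):
    if matchobj.group(1):
        return "{%s}" % matchobj.group(1).split(':')[0]
    else:
        return "{%s}" % matchobj.group(2)

routepath = '/{repo_name:.*?[^/]}/raw-changeset/{commit_id}'
print(_argument_prog.sub(replace, routepath))
# -> '/{repo_name}/raw-changeset/{commit_id}'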
@@ -1,58 +1,57 b''
1 1 # -*- coding: utf-8 -*-
2 2
3 3 # Copyright (C) 2016-2020 RhodeCode GmbH
4 4 #
5 5 # This program is free software: you can redistribute it and/or modify
6 6 # it under the terms of the GNU Affero General Public License, version 3
7 7 # (only), as published by the Free Software Foundation.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU Affero General Public License
15 15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 16 #
17 17 # This program is dual-licensed. If you wish to learn more about the
18 18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20 20
21
21 import io
22 22 import uuid
23 from io import StringIO
24 23 import pathlib2
25 24
26 25
27 26 def get_file_storage(settings):
28 27 from rhodecode.apps.file_store.backends.local_store import LocalFileStorage
29 28 from rhodecode.apps.file_store import config_keys
30 29 store_path = settings.get(config_keys.store_path)
31 30 return LocalFileStorage(base_path=store_path)
32 31
33 32
34 33 def splitext(filename):
35 34 ext = ''.join(pathlib2.Path(filename).suffixes)  # combined suffixes, e.g. '.tar.gz'
36 35 return filename, ext
37 36
38 37
39 38 def uid_filename(filename, randomized=True):
40 39 """
41 40 Generates a randomized or stable (uuid) filename,
42 41 preserving the original extension.
43 42
44 43 :param filename: the original filename
45 44 :param randomized: define if filename should be stable (sha1 based) or randomized
46 45 """
47 46
48 47 _, ext = splitext(filename)
49 48 if randomized:
50 49 uid = uuid.uuid4()
51 50 else:
52 51 hash_key = '{}.{}'.format(filename, 'store')
53 52 uid = uuid.uuid5(uuid.NAMESPACE_URL, hash_key)
54 53 return str(uid) + ext.lower()
55 54
56 55
57 56 def bytes_to_file_obj(bytes_data):
58 return StringIO.StringIO(bytes_data)
57 return io.BytesIO(bytes_data)  # bytes input needs BytesIO; io.StringIO(bytes) raises TypeError on py3
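Note: a standalone usage sketch of the helpers above (stdlib pathlib stands in for pathlib2, and the BytesIO call assumes the bytes input implied by the bytes_data name):

import io
import pathlib
import uuid

def uid_name(filename, randomized=True):
    ext = ''.join(pathlib.Path(filename).suffixes)  # e.g. '.tar.gz'
    if randomized:
        uid = uuid.uuid4()
    else:
        uid = uuid.uuid5(uuid.NAMESPACE_URL, '{}.{}'.format(filename, 'store'))
    return str(uid) + ext.lower()

print(uid_name('archive.tar.gz'))  # random name, extension preserved
assert uid_name('a.txt', randomized=False) == uid_name('a.txt', randomized=False)

# bytes need io.BytesIO; io.StringIO(b'...') raises TypeError on Python 3
print(io.BytesIO(b'some binary payload').read())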
@@ -1,580 +1,580 b''
1 1 # -*- coding: utf-8 -*-
2 2
3 3 # Copyright (C) 2010-2020 RhodeCode GmbH
4 4 #
5 5 # This program is free software: you can redistribute it and/or modify
6 6 # it under the terms of the GNU Affero General Public License, version 3
7 7 # (only), as published by the Free Software Foundation.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU Affero General Public License
15 15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 16 #
17 17 # This program is dual-licensed. If you wish to learn more about the
18 18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20 20
21 21 import urllib.parse
22 22
23 23 import mock
24 24 import pytest
25 25
26 26 from rhodecode.tests import (
27 27 assert_session_flash, HG_REPO, TEST_USER_ADMIN_LOGIN,
28 28 no_newline_id_generator)
29 29 from rhodecode.tests.fixture import Fixture
30 30 from rhodecode.lib.auth import check_password
31 31 from rhodecode.lib import helpers as h
32 32 from rhodecode.model.auth_token import AuthTokenModel
33 33 from rhodecode.model.db import User, Notification, UserApiKeys
34 34 from rhodecode.model.meta import Session
35 35
36 36 fixture = Fixture()
37 37
38 38 whitelist_view = ['RepoCommitsView:repo_commit_raw']
39 39
40 40
41 41 def route_path(name, params=None, **kwargs):
42 42 import urllib.parse  # only parse is used below
43 43 from rhodecode.apps._base import ADMIN_PREFIX
44 44
45 45 base_url = {
46 46 'login': ADMIN_PREFIX + '/login',
47 47 'logout': ADMIN_PREFIX + '/logout',
48 48 'register': ADMIN_PREFIX + '/register',
49 49 'reset_password':
50 50 ADMIN_PREFIX + '/password_reset',
51 51 'reset_password_confirmation':
52 52 ADMIN_PREFIX + '/password_reset_confirmation',
53 53
54 54 'admin_permissions_application':
55 55 ADMIN_PREFIX + '/permissions/application',
56 56 'admin_permissions_application_update':
57 57 ADMIN_PREFIX + '/permissions/application/update',
58 58
59 59 'repo_commit_raw': '/{repo_name}/raw-changeset/{commit_id}'
60 60
61 61 }[name].format(**kwargs)
62 62
63 63 if params:
64 64 base_url = '{}?{}'.format(base_url, urllib.parse.urlencode(params))
65 65 return base_url
66 66
67 67
68 68 @pytest.mark.usefixtures('app')
69 69 class TestLoginController(object):
70 70 destroy_users = set()
71 71
72 72 @classmethod
73 73 def teardown_class(cls):
74 74 fixture.destroy_users(cls.destroy_users)
75 75
76 76 def teardown_method(self, method):
77 77 for n in Notification.query().all():
78 78 Session().delete(n)
79 79
80 80 Session().commit()
81 81 assert Notification.query().all() == []
82 82
83 83 def test_index(self):
84 84 response = self.app.get(route_path('login'))
85 85 assert response.status == '200 OK'
86 86 # Test response...
87 87
88 88 def test_login_admin_ok(self):
89 89 response = self.app.post(route_path('login'),
90 90 {'username': 'test_admin',
91 91 'password': 'test12'}, status=302)
92 92 response = response.follow()
93 93 session = response.get_session_from_response()
94 94 username = session['rhodecode_user'].get('username')
95 95 assert username == 'test_admin'
96 96 response.mustcontain('logout')
97 97
98 98 def test_login_regular_ok(self):
99 99 response = self.app.post(route_path('login'),
100 100 {'username': 'test_regular',
101 101 'password': 'test12'}, status=302)
102 102
103 103 response = response.follow()
104 104 session = response.get_session_from_response()
105 105 username = session['rhodecode_user'].get('username')
106 106 assert username == 'test_regular'
107 107 response.mustcontain('logout')
108 108
109 109 def test_login_regular_forbidden_when_super_admin_restriction(self):
110 110 from rhodecode.authentication.plugins.auth_rhodecode import RhodeCodeAuthPlugin
111 111 with fixture.auth_restriction(self.app._pyramid_registry,
112 112 RhodeCodeAuthPlugin.AUTH_RESTRICTION_SUPER_ADMIN):
113 113 response = self.app.post(route_path('login'),
114 114 {'username': 'test_regular',
115 115 'password': 'test12'})
116 116
117 117 response.mustcontain('invalid user name')
118 118 response.mustcontain('invalid password')
119 119
120 120 def test_login_regular_forbidden_when_scope_restriction(self):
121 121 from rhodecode.authentication.plugins.auth_rhodecode import RhodeCodeAuthPlugin
122 122 with fixture.scope_restriction(self.app._pyramid_registry,
123 123 RhodeCodeAuthPlugin.AUTH_RESTRICTION_SCOPE_VCS):
124 124 response = self.app.post(route_path('login'),
125 125 {'username': 'test_regular',
126 126 'password': 'test12'})
127 127
128 128 response.mustcontain('invalid user name')
129 129 response.mustcontain('invalid password')
130 130
131 131 def test_login_ok_came_from(self):
132 132 test_came_from = '/_admin/users?branch=stable'
133 133 _url = '{}?came_from={}'.format(route_path('login'), test_came_from)
134 134 response = self.app.post(
135 135 _url, {'username': 'test_admin', 'password': 'test12'}, status=302)
136 136
137 137 assert 'branch=stable' in response.location
138 138 response = response.follow()
139 139
140 140 assert response.status == '200 OK'
141 141 response.mustcontain('Users administration')
142 142
143 143 def test_redirect_to_login_with_get_args(self):
144 144 with fixture.anon_access(False):
145 145 kwargs = {'branch': 'stable'}
146 146 response = self.app.get(
147 147 h.route_path('repo_summary', repo_name=HG_REPO, _query=kwargs),
148 148 status=302)
149 149
150 response_query = urllib.parse.urlparse.parse_qsl(response.location)
150 response_query = urllib.parse.parse_qsl(response.location)
151 151 assert 'branch=stable' in response_query[0][1]
152 152
153 153 def test_login_form_with_get_args(self):
154 154 _url = '{}?came_from=/_admin/users,branch=stable'.format(route_path('login'))
155 155 response = self.app.get(_url)
156 156 assert 'branch%3Dstable' in response.form.action
157 157
158 158 @pytest.mark.parametrize("url_came_from", [
159 159 'data:text/html,<script>window.alert("xss")</script>',
160 160 'mailto:test@rhodecode.org',
161 161 'file:///etc/passwd',
162 162 'ftp://some.ftp.server',
163 163 'http://other.domain',
164 164 '/\r\nX-Forwarded-Host: http://example.org',
165 165 ], ids=no_newline_id_generator)
166 166 def test_login_bad_came_froms(self, url_came_from):
167 167 _url = '{}?came_from={}'.format(route_path('login'), url_came_from)
168 168 response = self.app.post(
169 169 _url,
170 170 {'username': 'test_admin', 'password': 'test12'})
171 171 assert response.status == '302 Found'
172 172 response = response.follow()
173 173 assert response.status == '200 OK'
174 174 assert response.request.path == '/'
175 175
176 176 def test_login_short_password(self):
177 177 response = self.app.post(route_path('login'),
178 178 {'username': 'test_admin',
179 179 'password': 'as'})
180 180 assert response.status == '200 OK'
181 181
182 182 response.mustcontain('Enter 3 characters or more')
183 183
184 184 def test_login_wrong_non_ascii_password(self, user_regular):
185 185 response = self.app.post(
186 186 route_path('login'),
187 187 {'username': user_regular.username,
188 188 'password': u'invalid-non-asci\xe4'.encode('utf8')})
189 189
190 190 response.mustcontain('invalid user name')
191 191 response.mustcontain('invalid password')
192 192
193 193 def test_login_with_non_ascii_password(self, user_util):
194 194 password = u'valid-non-ascii\xe4'
195 195 user = user_util.create_user(password=password)
196 196 response = self.app.post(
197 197 route_path('login'),
198 198 {'username': user.username,
199 199 'password': password})
200 200 assert response.status_code == 302
201 201
202 202 def test_login_wrong_username_password(self):
203 203 response = self.app.post(route_path('login'),
204 204 {'username': 'error',
205 205 'password': 'test12'})
206 206
207 207 response.mustcontain('invalid user name')
208 208 response.mustcontain('invalid password')
209 209
210 210 def test_login_admin_ok_password_migration(self, real_crypto_backend):
211 211 from rhodecode.lib import auth
212 212
213 213 # create new user, with sha256 password
214 214 temp_user = 'test_admin_sha256'
215 215 user = fixture.create_user(temp_user)
216 216 user.password = auth._RhodeCodeCryptoSha256().hash_create(
217 217 b'test123')
218 218 Session().add(user)
219 219 Session().commit()
220 220 self.destroy_users.add(temp_user)
221 221 response = self.app.post(route_path('login'),
222 222 {'username': temp_user,
223 223 'password': 'test123'}, status=302)
224 224
225 225 response = response.follow()
226 226 session = response.get_session_from_response()
227 227 username = session['rhodecode_user'].get('username')
228 228 assert username == temp_user
229 229 response.mustcontain('logout')
230 230
231 231 # after log-in and transfer, the password should have been re-hashed with bcrypt
232 232 user = User.get_by_username(temp_user)
233 233 assert user.password.startswith('$')
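The startswith('$') assertion works because bcrypt emits modular-crypt-format hashes with a '$2a$'/'$2b$' prefix, while the legacy sha256 hex digest has none. A standalone sketch of the same check (assuming the bcrypt package that RhodeCode's crypto backend wraps):

    import bcrypt

    hashed = bcrypt.hashpw(b'test123', bcrypt.gensalt())
    assert hashed.startswith(b'$2')  # the '$' prefix marks a migrated hash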
234 234
235 235 # REGISTRATIONS
236 236 def test_register(self):
237 237 response = self.app.get(route_path('register'))
238 238 response.mustcontain('Create an Account')
239 239
240 240 def test_register_err_same_username(self):
241 241 uname = 'test_admin'
242 242 response = self.app.post(
243 243 route_path('register'),
244 244 {
245 245 'username': uname,
246 246 'password': 'test12',
247 247 'password_confirmation': 'test12',
248 248 'email': 'goodmail@domain.com',
249 249 'firstname': 'test',
250 250 'lastname': 'test'
251 251 }
252 252 )
253 253
254 254 assertr = response.assert_response()
255 255 msg = 'Username "%(username)s" already exists'
256 256 msg = msg % {'username': uname}
257 257 assertr.element_contains('#username+.error-message', msg)
258 258
259 259 def test_register_err_same_email(self):
260 260 response = self.app.post(
261 261 route_path('register'),
262 262 {
263 263 'username': 'test_admin_0',
264 264 'password': 'test12',
265 265 'password_confirmation': 'test12',
266 266 'email': 'test_admin@mail.com',
267 267 'firstname': 'test',
268 268 'lastname': 'test'
269 269 }
270 270 )
271 271
272 272 assertr = response.assert_response()
273 273 msg = u'This e-mail address is already taken'
274 274 assertr.element_contains('#email+.error-message', msg)
275 275
276 276 def test_register_err_same_email_case_sensitive(self):
277 277 response = self.app.post(
278 278 route_path('register'),
279 279 {
280 280 'username': 'test_admin_1',
281 281 'password': 'test12',
282 282 'password_confirmation': 'test12',
283 283 'email': 'TesT_Admin@mail.COM',
284 284 'firstname': 'test',
285 285 'lastname': 'test'
286 286 }
287 287 )
288 288 assertr = response.assert_response()
289 289 msg = u'This e-mail address is already taken'
290 290 assertr.element_contains('#email+.error-message', msg)
291 291
292 292 def test_register_err_wrong_data(self):
293 293 response = self.app.post(
294 294 route_path('register'),
295 295 {
296 296 'username': 'xs',
297 297 'password': 'test',
298 298 'password_confirmation': 'test',
299 299 'email': 'goodmailm',
300 300 'firstname': 'test',
301 301 'lastname': 'test'
302 302 }
303 303 )
304 304 assert response.status == '200 OK'
305 305 response.mustcontain('An email address must contain a single @')
306 306 response.mustcontain('Enter a value 6 characters long or more')
307 307
308 308 def test_register_err_username(self):
309 309 response = self.app.post(
310 310 route_path('register'),
311 311 {
312 312 'username': 'error user',
313 313 'password': 'test12',
314 314 'password_confirmation': 'test12',
315 315 'email': 'goodmailm',
316 316 'firstname': 'test',
317 317 'lastname': 'test'
318 318 }
319 319 )
320 320
321 321 response.mustcontain('An email address must contain a single @')
322 322 response.mustcontain(
323 323 'Username may only contain '
324 324 'alphanumeric characters underscores, '
325 325 'periods or dashes and must begin with '
326 326 'alphanumeric character')
327 327
328 328 def test_register_err_case_sensitive(self):
329 329 usr = 'Test_Admin'
330 330 response = self.app.post(
331 331 route_path('register'),
332 332 {
333 333 'username': usr,
334 334 'password': 'test12',
335 335 'password_confirmation': 'test12',
336 336 'email': 'goodmailm',
337 337 'firstname': 'test',
338 338 'lastname': 'test'
339 339 }
340 340 )
341 341
342 342 assertr = response.assert_response()
343 343 msg = u'Username "%(username)s" already exists'
344 344 msg = msg % {'username': usr}
345 345 assertr.element_contains('#username+.error-message', msg)
346 346
347 347 def test_register_special_chars(self):
348 348 response = self.app.post(
349 349 route_path('register'),
350 350 {
351 351 'username': 'xxxaxn',
352 352 'password': 'ąćźżąśśśś',
353 353 'password_confirmation': 'ąćźżąśśśś',
354 354 'email': 'goodmailm@test.plx',
355 355 'firstname': 'test',
356 356 'lastname': 'test'
357 357 }
358 358 )
359 359
360 360 msg = u'Invalid characters (non-ascii) in password'
361 361 response.mustcontain(msg)
362 362
363 363 def test_register_password_mismatch(self):
364 364 response = self.app.post(
365 365 route_path('register'),
366 366 {
367 367 'username': 'xs',
368 368 'password': '123qwe',
369 369 'password_confirmation': 'qwe123',
370 370 'email': 'goodmailm@test.plxa',
371 371 'firstname': 'test',
372 372 'lastname': 'test'
373 373 }
374 374 )
375 375 msg = u'Passwords do not match'
376 376 response.mustcontain(msg)
377 377
378 378 def test_register_ok(self):
379 379 username = 'test_regular4'
380 380 password = 'qweqwe'
381 381 email = 'marcin@test.com'
382 382 name = 'testname'
383 383 lastname = 'testlastname'
384 384
385 385 # this initializes a session
386 386 response = self.app.get(route_path('register'))
387 387 response.mustcontain('Create an Account')
388 388
389 389
390 390 response = self.app.post(
391 391 route_path('register'),
392 392 {
393 393 'username': username,
394 394 'password': password,
395 395 'password_confirmation': password,
396 396 'email': email,
397 397 'firstname': name,
398 398 'lastname': lastname,
399 399 'admin': True
400 400 },
401 401 status=302
402 402 ) # the posted 'admin': True flag should be ignored (asserted below)
403 403
404 404 assert_session_flash(
405 405 response, 'You have successfully registered with RhodeCode. You can log-in now.')
406 406
407 407 ret = Session().query(User).filter(
408 408 User.username == 'test_regular4').one()
409 409 assert ret.username == username
410 410 assert check_password(password, ret.password)
411 411 assert ret.email == email
412 412 assert ret.name == name
413 413 assert ret.lastname == lastname
414 414 assert ret.auth_tokens is not None
415 415 assert not ret.admin
416 416
417 417 def test_forgot_password_wrong_mail(self):
418 418 bad_email = 'marcin@wrongmail.org'
419 419 # this initializes a session
420 420 self.app.get(route_path('reset_password'))
421 421
422 422 response = self.app.post(
423 423 route_path('reset_password'), {'email': bad_email, }
424 424 )
425 425 assert_session_flash(response,
426 426 'If such email exists, a password reset link was sent to it.')
427 427
428 428 def test_forgot_password(self, user_util):
429 429 # this initializes a session
430 430 self.app.get(route_path('reset_password'))
431 431
432 432 user = user_util.create_user()
433 433 user_id = user.user_id
434 434 email = user.email
435 435
436 436 response = self.app.post(route_path('reset_password'), {'email': email, })
437 437
438 438 assert_session_flash(response,
439 439 'If such email exists, a password reset link was sent to it.')
440 440
441 441 # BAD KEY
442 442 confirm_url = '{}?key={}'.format(route_path('reset_password_confirmation'), 'badkey')
443 443 response = self.app.get(confirm_url, status=302)
444 444 assert response.location.endswith(route_path('reset_password'))
445 445 assert_session_flash(response, 'Given reset token is invalid')
446 446
447 447 response.follow() # cleanup flash
448 448
449 449 # GOOD KEY
450 450 key = UserApiKeys.query()\
451 451 .filter(UserApiKeys.user_id == user_id)\
452 452 .filter(UserApiKeys.role == UserApiKeys.ROLE_PASSWORD_RESET)\
453 453 .first()
454 454
455 455 assert key
456 456
457 457 confirm_url = '{}?key={}'.format(route_path('reset_password_confirmation'), key.api_key)
458 458 response = self.app.get(confirm_url)
459 459 assert response.status == '302 Found'
460 460 assert response.location.endswith(route_path('login'))
461 461
462 462 assert_session_flash(
463 463 response,
464 464 'Your password reset was successful, '
465 465 'a new password has been sent to your email')
466 466
467 467 response.follow()
468 468
469 469 def _get_api_whitelist(self, values=None):
470 470 config = {'api_access_controllers_whitelist': values or []}
471 471 return config
472 472
473 473 @pytest.mark.parametrize("test_name, auth_token", [
474 474 ('none', None),
475 475 ('empty_string', ''),
476 476 ('fake_number', '123456'),
477 477 ('proper_auth_token', None)
478 478 ])
479 479 def test_access_not_whitelisted_page_via_auth_token(
480 480 self, test_name, auth_token, user_admin):
481 481
482 482 whitelist = self._get_api_whitelist([])
483 483 with mock.patch.dict('rhodecode.CONFIG', whitelist):
484 484 assert [] == whitelist['api_access_controllers_whitelist']
485 485 if test_name == 'proper_auth_token':
486 486 # fall back to the user's builtin token when none was parametrized
487 487 auth_token = user_admin.api_key
488 488
489 489 with fixture.anon_access(False):
490 490 self.app.get(
491 491 route_path('repo_commit_raw',
492 492 repo_name=HG_REPO, commit_id='tip',
493 493 params=dict(api_key=auth_token)),
494 494 status=302)
495 495
496 496 @pytest.mark.parametrize("test_name, auth_token, code", [
497 497 ('none', None, 302),
498 498 ('empty_string', '', 302),
499 499 ('fake_number', '123456', 302),
500 500 ('proper_auth_token', None, 200)
501 501 ])
502 502 def test_access_whitelisted_page_via_auth_token(
503 503 self, test_name, auth_token, code, user_admin):
504 504
505 505 whitelist = self._get_api_whitelist(whitelist_view)
506 506
507 507 with mock.patch.dict('rhodecode.CONFIG', whitelist):
508 508 assert whitelist_view == whitelist['api_access_controllers_whitelist']
509 509
510 510 if test_name == 'proper_auth_token':
511 511 auth_token = user_admin.api_key
512 512 assert auth_token
513 513
514 514 with fixture.anon_access(False):
515 515 self.app.get(
516 516 route_path('repo_commit_raw',
517 517 repo_name=HG_REPO, commit_id='tip',
518 518 params=dict(api_key=auth_token)),
519 519 status=code)
520 520
521 521 @pytest.mark.parametrize("test_name, auth_token, code", [
522 522 ('proper_auth_token', None, 200),
523 523 ('wrong_auth_token', '123456', 302),
524 524 ])
525 525 def test_access_whitelisted_page_via_auth_token_bound_to_token(
526 526 self, test_name, auth_token, code, user_admin):
527 527
528 528 expected_token = auth_token
529 529 if test_name == 'proper_auth_token':
530 530 auth_token = user_admin.api_key
531 531 expected_token = auth_token
532 532 assert auth_token
533 533
534 534 whitelist = self._get_api_whitelist([
535 535 'RepoCommitsView:repo_commit_raw@{}'.format(expected_token)])
536 536
537 537 with mock.patch.dict('rhodecode.CONFIG', whitelist):
538 538
539 539 with fixture.anon_access(False):
540 540 self.app.get(
541 541 route_path('repo_commit_raw',
542 542 repo_name=HG_REPO, commit_id='tip',
543 543 params=dict(api_key=auth_token)),
544 544 status=code)
545 545
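The 'view@token' suffix used in the test above binds the whitelist exemption to one concrete auth token. A toy matcher showing the idea (hypothetical helper, not RhodeCode's actual parser):

    def whitelist_allows(entry, view_name, token):
        view, _, bound_token = entry.partition('@')
        # an entry without '@token' allows any token for that view
        return view == view_name and (not bound_token or bound_token == token)

    assert whitelist_allows('RepoCommitsView:repo_commit_raw@abc',
                            'RepoCommitsView:repo_commit_raw', 'abc')
    assert not whitelist_allows('RepoCommitsView:repo_commit_raw@abc',
                                'RepoCommitsView:repo_commit_raw', 'xyz')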
546 546 def test_access_page_via_extra_auth_token(self):
547 547 whitelist = self._get_api_whitelist(whitelist_view)
548 548 with mock.patch.dict('rhodecode.CONFIG', whitelist):
549 549 assert whitelist_view == \
550 550 whitelist['api_access_controllers_whitelist']
551 551
552 552 new_auth_token = AuthTokenModel().create(
553 553 TEST_USER_ADMIN_LOGIN, 'test')
554 554 Session().commit()
555 555 with fixture.anon_access(False):
556 556 self.app.get(
557 557 route_path('repo_commit_raw',
558 558 repo_name=HG_REPO, commit_id='tip',
559 559 params=dict(api_key=new_auth_token.api_key)),
560 560 status=200)
561 561
562 562 def test_access_page_via_expired_auth_token(self):
563 563 whitelist = self._get_api_whitelist(whitelist_view)
564 564 with mock.patch.dict('rhodecode.CONFIG', whitelist):
565 565 assert whitelist_view == \
566 566 whitelist['api_access_controllers_whitelist']
567 567
568 568 new_auth_token = AuthTokenModel().create(
569 569 TEST_USER_ADMIN_LOGIN, 'test')
570 570 Session().commit()
571 571 # patch the api key and make it expired
572 572 new_auth_token.expires = 0
573 573 Session().add(new_auth_token)
574 574 Session().commit()
575 575 with fixture.anon_access(False):
576 576 self.app.get(
577 577 route_path('repo_commit_raw',
578 578 repo_name=HG_REPO, commit_id='tip',
579 579 params=dict(api_key=new_auth_token.api_key)),
580 580 status=302)
@@ -1,1877 +1,1877 b''
1 1 # -*- coding: utf-8 -*-
2 2
3 3 # Copyright (C) 2011-2020 RhodeCode GmbH
4 4 #
5 5 # This program is free software: you can redistribute it and/or modify
6 6 # it under the terms of the GNU Affero General Public License, version 3
7 7 # (only), as published by the Free Software Foundation.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU Affero General Public License
15 15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 16 #
17 17 # This program is dual-licensed. If you wish to learn more about the
18 18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20 20
21 21 import logging
22 22 import collections
23 23
24 24 import formencode
25 25 import formencode.htmlfill
26 26 import peppercorn
27 27 from pyramid.httpexceptions import (
28 28 HTTPFound, HTTPNotFound, HTTPForbidden, HTTPBadRequest, HTTPConflict)
29 29
30 30 from pyramid.renderers import render
31 31
32 32 from rhodecode.apps._base import RepoAppView, DataGridAppView
33 33
34 34 from rhodecode.lib import helpers as h, diffs, codeblocks, channelstream
35 35 from rhodecode.lib.base import vcs_operation_context
36 36 from rhodecode.lib.diffs import load_cached_diff, cache_diff, diff_cache_exist
37 37 from rhodecode.lib.exceptions import CommentVersionMismatch
38 38 from rhodecode.lib.ext_json import json
39 39 from rhodecode.lib.auth import (
40 40 LoginRequired, HasRepoPermissionAny, HasRepoPermissionAnyDecorator,
41 41 NotAnonymous, CSRFRequired)
42 42 from rhodecode.lib.utils2 import str2bool, safe_str, safe_unicode, safe_int, aslist, retry
43 43 from rhodecode.lib.vcs.backends.base import (
44 44 EmptyCommit, UpdateFailureReason, unicode_to_reference)
45 45 from rhodecode.lib.vcs.exceptions import (
46 46 CommitDoesNotExistError, RepositoryRequirementError, EmptyRepositoryError)
47 47 from rhodecode.model.changeset_status import ChangesetStatusModel
48 48 from rhodecode.model.comment import CommentsModel
49 49 from rhodecode.model.db import (
50 50 func, false, or_, PullRequest, ChangesetComment, ChangesetStatus, Repository,
51 51 PullRequestReviewers)
52 52 from rhodecode.model.forms import PullRequestForm
53 53 from rhodecode.model.meta import Session
54 54 from rhodecode.model.pull_request import PullRequestModel, MergeCheck
55 55 from rhodecode.model.scm import ScmModel
56 56
57 57 log = logging.getLogger(__name__)
58 58
59 59
60 60 class RepoPullRequestsView(RepoAppView, DataGridAppView):
61 61
62 62 def load_default_context(self):
63 63 c = self._get_local_tmpl_context(include_app_defaults=True)
64 64 c.REVIEW_STATUS_APPROVED = ChangesetStatus.STATUS_APPROVED
65 65 c.REVIEW_STATUS_REJECTED = ChangesetStatus.STATUS_REJECTED
66 66 # backward compat.: old PRs are rendered with a plain renderer
67 67 c.renderer = 'plain'
68 68 return c
69 69
70 70 def _get_pull_requests_list(
71 71 self, repo_name, source, filter_type, opened_by, statuses):
72 72
73 73 draw, start, limit = self._extract_chunk(self.request)
74 74 search_q, order_by, order_dir = self._extract_ordering(self.request)
75 75 _render = self.request.get_partial_renderer(
76 76 'rhodecode:templates/data_table/_dt_elements.mako')
77 77
78 78 # pagination
79 79
80 80 if filter_type == 'awaiting_review':
81 81 pull_requests = PullRequestModel().get_awaiting_review(
82 82 repo_name,
83 83 search_q=search_q, statuses=statuses,
84 84 offset=start, length=limit, order_by=order_by, order_dir=order_dir)
85 85 pull_requests_total_count = PullRequestModel().count_awaiting_review(
86 86 repo_name,
87 87 search_q=search_q, statuses=statuses)
88 88 elif filter_type == 'awaiting_my_review':
89 89 pull_requests = PullRequestModel().get_awaiting_my_review(
90 90 repo_name, self._rhodecode_user.user_id,
91 91 search_q=search_q, statuses=statuses,
92 92 offset=start, length=limit, order_by=order_by, order_dir=order_dir)
93 93 pull_requests_total_count = PullRequestModel().count_awaiting_my_review(
94 94 repo_name, self._rhodecode_user.user_id,
95 95 search_q=search_q, statuses=statuses)
96 96 else:
97 97 pull_requests = PullRequestModel().get_all(
98 98 repo_name, search_q=search_q, source=source, opened_by=opened_by,
99 99 statuses=statuses, offset=start, length=limit,
100 100 order_by=order_by, order_dir=order_dir)
101 101 pull_requests_total_count = PullRequestModel().count_all(
102 102 repo_name, search_q=search_q, source=source, statuses=statuses,
103 103 opened_by=opened_by)
104 104
105 105 data = []
106 106 comments_model = CommentsModel()
107 107 for pr in pull_requests:
108 108 comments_count = comments_model.get_all_comments(
109 109 self.db_repo.repo_id, pull_request=pr,
110 110 include_drafts=False, count_only=True)
111 111
112 112 review_statuses = pr.reviewers_statuses(user=self._rhodecode_db_user)
113 113 my_review_status = ChangesetStatus.STATUS_NOT_REVIEWED
114 114 if review_statuses and review_statuses[4]:
115 115 _review_obj, _user, _reasons, _mandatory, statuses = review_statuses
116 116 my_review_status = statuses[0][1].status
117 117
118 118 data.append({
119 119 'name': _render('pullrequest_name',
120 120 pr.pull_request_id, pr.pull_request_state,
121 121 pr.work_in_progress, pr.target_repo.repo_name,
122 122 short=True),
123 123 'name_raw': pr.pull_request_id,
124 124 'status': _render('pullrequest_status',
125 125 pr.calculated_review_status()),
126 126 'my_status': _render('pullrequest_status',
127 127 my_review_status),
128 128 'title': _render('pullrequest_title', pr.title, pr.description),
129 129 'description': h.escape(pr.description),
130 130 'updated_on': _render('pullrequest_updated_on',
131 131 h.datetime_to_time(pr.updated_on),
132 132 pr.versions_count),
133 133 'updated_on_raw': h.datetime_to_time(pr.updated_on),
134 134 'created_on': _render('pullrequest_updated_on',
135 135 h.datetime_to_time(pr.created_on)),
136 136 'created_on_raw': h.datetime_to_time(pr.created_on),
137 137 'state': pr.pull_request_state,
138 138 'author': _render('pullrequest_author',
139 139 pr.author.full_contact, ),
140 140 'author_raw': pr.author.full_name,
141 141 'comments': _render('pullrequest_comments', comments_count),
142 142 'comments_raw': comments_count,
143 143 'closed': pr.is_closed(),
144 144 })
145 145
146 146 data = ({
147 147 'draw': draw,
148 148 'data': data,
149 149 'recordsTotal': pull_requests_total_count,
150 150 'recordsFiltered': pull_requests_total_count,
151 151 })
152 152 return data
153 153
154 154 @LoginRequired()
155 155 @HasRepoPermissionAnyDecorator(
156 156 'repository.read', 'repository.write', 'repository.admin')
157 157 def pull_request_list(self):
158 158 c = self.load_default_context()
159 159
160 160 req_get = self.request.GET
161 161 c.source = str2bool(req_get.get('source'))
162 162 c.closed = str2bool(req_get.get('closed'))
163 163 c.my = str2bool(req_get.get('my'))
164 164 c.awaiting_review = str2bool(req_get.get('awaiting_review'))
165 165 c.awaiting_my_review = str2bool(req_get.get('awaiting_my_review'))
166 166
167 167 c.active = 'open'
168 168 if c.my:
169 169 c.active = 'my'
170 170 if c.closed:
171 171 c.active = 'closed'
172 172 if c.awaiting_review and not c.source:
173 173 c.active = 'awaiting'
174 174 if c.source and not c.awaiting_review:
175 175 c.active = 'source'
176 176 if c.awaiting_my_review:
177 177 c.active = 'awaiting_my'
178 178
179 179 return self._get_template_context(c)
180 180
181 181 @LoginRequired()
182 182 @HasRepoPermissionAnyDecorator(
183 183 'repository.read', 'repository.write', 'repository.admin')
184 184 def pull_request_list_data(self):
185 185 self.load_default_context()
186 186
187 187 # additional filters
188 188 req_get = self.request.GET
189 189 source = str2bool(req_get.get('source'))
190 190 closed = str2bool(req_get.get('closed'))
191 191 my = str2bool(req_get.get('my'))
192 192 awaiting_review = str2bool(req_get.get('awaiting_review'))
193 193 awaiting_my_review = str2bool(req_get.get('awaiting_my_review'))
194 194
195 195 filter_type = 'awaiting_review' if awaiting_review \
196 196 else 'awaiting_my_review' if awaiting_my_review \
197 197 else None
198 198
199 199 opened_by = None
200 200 if my:
201 201 opened_by = [self._rhodecode_user.user_id]
202 202
203 203 statuses = [PullRequest.STATUS_NEW, PullRequest.STATUS_OPEN]
204 204 if closed:
205 205 statuses = [PullRequest.STATUS_CLOSED]
206 206
207 207 data = self._get_pull_requests_list(
208 208 repo_name=self.db_repo_name, source=source,
209 209 filter_type=filter_type, opened_by=opened_by, statuses=statuses)
210 210
211 211 return data
212 212
213 213 def _is_diff_cache_enabled(self, target_repo):
214 214 caching_enabled = self._get_general_setting(
215 215 target_repo, 'rhodecode_diff_cache')
216 216 log.debug('Diff caching enabled: %s', caching_enabled)
217 217 return caching_enabled
218 218
219 219 def _get_diffset(self, source_repo_name, source_repo,
220 220 ancestor_commit,
221 221 source_ref_id, target_ref_id,
222 222 target_commit, source_commit, diff_limit, file_limit,
223 223 fulldiff, hide_whitespace_changes, diff_context, use_ancestor=True):
224 224
225 225 target_commit_final = target_commit
226 226 source_commit_final = source_commit
227 227
228 228 if use_ancestor:
229 229 # we might not want to use it for versions
230 230 target_ref_id = ancestor_commit.raw_id
231 231 target_commit_final = ancestor_commit
232 232
233 233 vcs_diff = PullRequestModel().get_diff(
234 234 source_repo, source_ref_id, target_ref_id,
235 235 hide_whitespace_changes, diff_context)
236 236
237 237 diff_processor = diffs.DiffProcessor(
238 238 vcs_diff, format='newdiff', diff_limit=diff_limit,
239 239 file_limit=file_limit, show_full_diff=fulldiff)
240 240
241 241 _parsed = diff_processor.prepare()
242 242
243 243 diffset = codeblocks.DiffSet(
244 244 repo_name=self.db_repo_name,
245 245 source_repo_name=source_repo_name,
246 246 source_node_getter=codeblocks.diffset_node_getter(target_commit_final),
247 247 target_node_getter=codeblocks.diffset_node_getter(source_commit_final),
248 248 )
249 249 diffset = self.path_filter.render_patchset_filtered(
250 250 diffset, _parsed, target_ref_id, source_ref_id)
251 251
252 252 return diffset
253 253
254 254 def _get_range_diffset(self, source_scm, source_repo,
255 255 commit1, commit2, diff_limit, file_limit,
256 256 fulldiff, hide_whitespace_changes, diff_context):
257 257 vcs_diff = source_scm.get_diff(
258 258 commit1, commit2,
259 259 ignore_whitespace=hide_whitespace_changes,
260 260 context=diff_context)
261 261
262 262 diff_processor = diffs.DiffProcessor(
263 263 vcs_diff, format='newdiff', diff_limit=diff_limit,
264 264 file_limit=file_limit, show_full_diff=fulldiff)
265 265
266 266 _parsed = diff_processor.prepare()
267 267
268 268 diffset = codeblocks.DiffSet(
269 269 repo_name=source_repo.repo_name,
270 270 source_node_getter=codeblocks.diffset_node_getter(commit1),
271 271 target_node_getter=codeblocks.diffset_node_getter(commit2))
272 272
273 273 diffset = self.path_filter.render_patchset_filtered(
274 274 diffset, _parsed, commit1.raw_id, commit2.raw_id)
275 275
276 276 return diffset
277 277
278 278 def register_comments_vars(self, c, pull_request, versions, include_drafts=True):
279 279 comments_model = CommentsModel()
280 280
281 281 # GENERAL COMMENTS with versions #
282 282 q = comments_model._all_general_comments_of_pull_request(pull_request)
283 283 q = q.order_by(ChangesetComment.comment_id.asc())
284 284 if not include_drafts:
285 285 q = q.filter(ChangesetComment.draft == false())
286 286 general_comments = q
287 287
288 288 # pick comments we want to render at current version
289 289 c.comment_versions = comments_model.aggregate_comments(
290 290 general_comments, versions, c.at_version_num)
291 291
292 292 # INLINE COMMENTS with versions #
293 293 q = comments_model._all_inline_comments_of_pull_request(pull_request)
294 294 q = q.order_by(ChangesetComment.comment_id.asc())
295 295 if not include_drafts:
296 296 q = q.filter(ChangesetComment.draft == false())
297 297 inline_comments = q
298 298
299 299 c.inline_versions = comments_model.aggregate_comments(
300 300 inline_comments, versions, c.at_version_num, inline=True)
301 301
302 302 # Comments inline+general
303 303 if c.at_version:
304 304 c.inline_comments_flat = c.inline_versions[c.at_version_num]['display']
305 305 c.comments = c.comment_versions[c.at_version_num]['display']
306 306 else:
307 307 c.inline_comments_flat = c.inline_versions[c.at_version_num]['until']
308 308 c.comments = c.comment_versions[c.at_version_num]['until']
309 309
310 310 return general_comments, inline_comments
311 311
312 312 @LoginRequired()
313 313 @HasRepoPermissionAnyDecorator(
314 314 'repository.read', 'repository.write', 'repository.admin')
315 315 def pull_request_show(self):
316 316 _ = self.request.translate
317 317 c = self.load_default_context()
318 318
319 319 pull_request = PullRequest.get_or_404(
320 320 self.request.matchdict['pull_request_id'])
321 321 pull_request_id = pull_request.pull_request_id
322 322
323 323 c.state_progressing = pull_request.is_state_changing()
324 324 c.pr_broadcast_channel = channelstream.pr_channel(pull_request)
325 325
326 326 _new_state = {
327 327 'created': PullRequest.STATE_CREATED,
328 328 }.get(self.request.GET.get('force_state'))
329 329 can_force_state = c.is_super_admin or HasRepoPermissionAny('repository.admin')(c.repo_name)
330 330
331 331 if can_force_state and _new_state:
332 332 with pull_request.set_state(PullRequest.STATE_UPDATING, final_state=_new_state):
333 333 h.flash(
334 334 _('Pull Request state was force changed to `{}`').format(_new_state),
335 335 category='success')
336 336 Session().commit()
337 337
338 338 raise HTTPFound(h.route_path(
339 339 'pullrequest_show', repo_name=self.db_repo_name,
340 340 pull_request_id=pull_request_id))
341 341
342 342 version = self.request.GET.get('version')
343 343 from_version = self.request.GET.get('from_version') or version
344 344 merge_checks = self.request.GET.get('merge_checks')
345 345 c.fulldiff = str2bool(self.request.GET.get('fulldiff'))
346 346 force_refresh = str2bool(self.request.GET.get('force_refresh'))
347 347 c.range_diff_on = self.request.GET.get('range-diff') == "1"
348 348
349 349 # fetch global flags of ignore ws or context lines
350 350 diff_context = diffs.get_diff_context(self.request)
351 351 hide_whitespace_changes = diffs.get_diff_whitespace_flag(self.request)
352 352
353 353 (pull_request_latest,
354 354 pull_request_at_ver,
355 355 pull_request_display_obj,
356 356 at_version) = PullRequestModel().get_pr_version(
357 357 pull_request_id, version=version)
358 358
359 359 pr_closed = pull_request_latest.is_closed()
360 360
361 361 if pr_closed and (version or from_version):
362 362 # do not allow browsing versions of a closed PR
363 363 raise HTTPFound(h.route_path(
364 364 'pullrequest_show', repo_name=self.db_repo_name,
365 365 pull_request_id=pull_request_id))
366 366
367 367 versions = pull_request_display_obj.versions()
368 368
369 369 c.commit_versions = PullRequestModel().pr_commits_versions(versions)
370 370
371 371 # used to store per-commit range diffs
372 372 c.changes = collections.OrderedDict()
373 373
374 374 c.at_version = at_version
375 375 c.at_version_num = (at_version
376 376 if at_version and at_version != PullRequest.LATEST_VER
377 377 else None)
378 378
379 379 c.at_version_index = ChangesetComment.get_index_from_version(
380 380 c.at_version_num, versions)
381 381
382 382 (prev_pull_request_latest,
383 383 prev_pull_request_at_ver,
384 384 prev_pull_request_display_obj,
385 385 prev_at_version) = PullRequestModel().get_pr_version(
386 386 pull_request_id, version=from_version)
387 387
388 388 c.from_version = prev_at_version
389 389 c.from_version_num = (prev_at_version
390 390 if prev_at_version and prev_at_version != PullRequest.LATEST_VER
391 391 else None)
392 392 c.from_version_index = ChangesetComment.get_index_from_version(
393 393 c.from_version_num, versions)
394 394
395 395 # define if we're in COMPARE mode or VIEW at version mode
396 396 compare = at_version != prev_at_version
397 397
398 398 # the repo_name this pull request was opened against,
399 399 # i.e. the target_repo must match
400 400 if self.db_repo_name != pull_request_at_ver.target_repo.repo_name:
401 401 log.warning('Mismatch between the current repo: %s, and target %s',
402 402 self.db_repo_name, pull_request_at_ver.target_repo.repo_name)
403 403 raise HTTPNotFound()
404 404
405 405 c.shadow_clone_url = PullRequestModel().get_shadow_clone_url(pull_request_at_ver)
406 406
407 407 c.pull_request = pull_request_display_obj
408 408 c.renderer = pull_request_at_ver.description_renderer or c.renderer
409 409 c.pull_request_latest = pull_request_latest
410 410
411 411 # inject latest version
412 412 latest_ver = PullRequest.get_pr_display_object(pull_request_latest, pull_request_latest)
413 413 c.versions = versions + [latest_ver]
414 414
415 415 if compare or (at_version and not at_version == PullRequest.LATEST_VER):
416 416 c.allowed_to_change_status = False
417 417 c.allowed_to_update = False
418 418 c.allowed_to_merge = False
419 419 c.allowed_to_delete = False
420 420 c.allowed_to_comment = False
421 421 c.allowed_to_close = False
422 422 else:
423 423 can_change_status = PullRequestModel().check_user_change_status(
424 424 pull_request_at_ver, self._rhodecode_user)
425 425 c.allowed_to_change_status = can_change_status and not pr_closed
426 426
427 427 c.allowed_to_update = PullRequestModel().check_user_update(
428 428 pull_request_latest, self._rhodecode_user) and not pr_closed
429 429 c.allowed_to_merge = PullRequestModel().check_user_merge(
430 430 pull_request_latest, self._rhodecode_user) and not pr_closed
431 431 c.allowed_to_delete = PullRequestModel().check_user_delete(
432 432 pull_request_latest, self._rhodecode_user) and not pr_closed
433 433 c.allowed_to_comment = not pr_closed
434 434 c.allowed_to_close = c.allowed_to_merge and not pr_closed
435 435
436 436 c.forbid_adding_reviewers = False
437 437
438 438 if pull_request_latest.reviewer_data and \
439 439 'rules' in pull_request_latest.reviewer_data:
440 440 rules = pull_request_latest.reviewer_data['rules'] or {}
441 441 try:
442 442 c.forbid_adding_reviewers = rules.get('forbid_adding_reviewers')
443 443 except Exception:
444 444 pass
445 445
446 446 # check merge capabilities
447 447 _merge_check = MergeCheck.validate(
448 448 pull_request_latest, auth_user=self._rhodecode_user,
449 449 translator=self.request.translate,
450 450 force_shadow_repo_refresh=force_refresh)
451 451
452 452 c.pr_merge_errors = _merge_check.error_details
453 453 c.pr_merge_possible = not _merge_check.failed
454 454 c.pr_merge_message = _merge_check.merge_msg
455 455 c.pr_merge_source_commit = _merge_check.source_commit
456 456 c.pr_merge_target_commit = _merge_check.target_commit
457 457
458 458 c.pr_merge_info = MergeCheck.get_merge_conditions(
459 459 pull_request_latest, translator=self.request.translate)
460 460
461 461 c.pull_request_review_status = _merge_check.review_status
462 462 if merge_checks:
463 463 self.request.override_renderer = \
464 464 'rhodecode:templates/pullrequests/pullrequest_merge_checks.mako'
465 465 return self._get_template_context(c)
466 466
467 467 c.reviewers_count = pull_request.reviewers_count
468 468 c.observers_count = pull_request.observers_count
469 469
470 470 # reviewers and statuses
471 471 c.pull_request_default_reviewers_data_json = json.dumps(pull_request.reviewer_data)
472 472 c.pull_request_set_reviewers_data_json = collections.OrderedDict({'reviewers': []})
473 473 c.pull_request_set_observers_data_json = collections.OrderedDict({'observers': []})
474 474
475 475 for review_obj, member, reasons, mandatory, status in pull_request_at_ver.reviewers_statuses():
476 476 member_reviewer = h.reviewer_as_json(
477 477 member, reasons=reasons, mandatory=mandatory,
478 478 role=review_obj.role,
479 479 user_group=review_obj.rule_user_group_data()
480 480 )
481 481
482 482 current_review_status = status[0][1].status if status else ChangesetStatus.STATUS_NOT_REVIEWED
483 483 member_reviewer['review_status'] = current_review_status
484 484 member_reviewer['review_status_label'] = h.commit_status_lbl(current_review_status)
485 485 member_reviewer['allowed_to_update'] = c.allowed_to_update
486 486 c.pull_request_set_reviewers_data_json['reviewers'].append(member_reviewer)
487 487
488 488 c.pull_request_set_reviewers_data_json = json.dumps(c.pull_request_set_reviewers_data_json)
489 489
490 490 for observer_obj, member in pull_request_at_ver.observers():
491 491 member_observer = h.reviewer_as_json(
492 492 member, reasons=[], mandatory=False,
493 493 role=observer_obj.role,
494 494 user_group=observer_obj.rule_user_group_data()
495 495 )
496 496 member_observer['allowed_to_update'] = c.allowed_to_update
497 497 c.pull_request_set_observers_data_json['observers'].append(member_observer)
498 498
499 499 c.pull_request_set_observers_data_json = json.dumps(c.pull_request_set_observers_data_json)
500 500
501 501 general_comments, inline_comments = \
502 502 self.register_comments_vars(c, pull_request_latest, versions)
503 503
504 504 # TODOs
505 505 c.unresolved_comments = CommentsModel() \
506 506 .get_pull_request_unresolved_todos(pull_request_latest)
507 507 c.resolved_comments = CommentsModel() \
508 508 .get_pull_request_resolved_todos(pull_request_latest)
509 509
510 510 # Drafts
511 511 c.draft_comments = CommentsModel().get_pull_request_drafts(
512 512 self._rhodecode_db_user.user_id,
513 513 pull_request_latest)
514 514
515 515 # if we view a specific version, do not show comments
516 516 # newer than that version
517 517 display_inline_comments = collections.defaultdict(
518 518 lambda: collections.defaultdict(list))
519 519 for co in inline_comments:
520 520 if c.at_version_num:
521 521 # pick comments up to (and including) the given version, so we
522 522 # don't render comments from higher versions
523 523 should_render = co.pull_request_version_id and \
524 524 co.pull_request_version_id <= c.at_version_num
525 525 else:
526 526 # showing all, for 'latest'
527 527 should_render = True
528 528
529 529 if should_render:
530 530 display_inline_comments[co.f_path][co.line_no].append(co)
531 531
532 532 # load diff data into template context, if we use compare mode then
533 533 # diff is calculated based on changes between versions of PR
534 534
535 535 source_repo = pull_request_at_ver.source_repo
536 536 source_ref_id = pull_request_at_ver.source_ref_parts.commit_id
537 537
538 538 target_repo = pull_request_at_ver.target_repo
539 539 target_ref_id = pull_request_at_ver.target_ref_parts.commit_id
540 540
541 541 if compare:
542 542 # in compare switch the diff base to latest commit from prev version
543 543 target_ref_id = prev_pull_request_display_obj.revisions[0]
544 544
545 545 # even when a PR was opened against bookmarks/branches/tags, we always
546 546 # convert the ref to a rev to prevent changes after a bookmark or branch moves
547 547 c.source_ref_type = 'rev'
548 548 c.source_ref = source_ref_id
549 549
550 550 c.target_ref_type = 'rev'
551 551 c.target_ref = target_ref_id
552 552
553 553 c.source_repo = source_repo
554 554 c.target_repo = target_repo
555 555
556 556 c.commit_ranges = []
557 557 source_commit = EmptyCommit()
558 558 target_commit = EmptyCommit()
559 559 c.missing_requirements = False
560 560
561 561 source_scm = source_repo.scm_instance()
562 562 target_scm = target_repo.scm_instance()
563 563
564 564 shadow_scm = None
565 565 try:
566 566 shadow_scm = pull_request_latest.get_shadow_repo()
567 567 except Exception:
568 568 log.debug('Failed to get shadow repo', exc_info=True)
569 569 # try first the existing source_repo, and then shadow
570 570 # repo if we can obtain one
571 571 commits_source_repo = source_scm
572 572 if shadow_scm:
573 573 commits_source_repo = shadow_scm
574 574
575 575 c.commits_source_repo = commits_source_repo
576 576 c.ancestor = None # set it to None, to hide it from PR view
577 577
578 578 # empty version means latest, so we keep this to prevent
579 579 # double caching
580 580 version_normalized = version or PullRequest.LATEST_VER
581 581 from_version_normalized = from_version or PullRequest.LATEST_VER
582 582
583 583 cache_path = self.rhodecode_vcs_repo.get_create_shadow_cache_pr_path(target_repo)
584 584 cache_file_path = diff_cache_exist(
585 585 cache_path, 'pull_request', pull_request_id, version_normalized,
586 586 from_version_normalized, source_ref_id, target_ref_id,
587 587 hide_whitespace_changes, diff_context, c.fulldiff)
588 588
589 589 caching_enabled = self._is_diff_cache_enabled(c.target_repo)
590 590 force_recache = self.get_recache_flag()
591 591
592 592 cached_diff = None
593 593 if caching_enabled:
594 594 cached_diff = load_cached_diff(cache_file_path)
595 595
596 596 has_proper_commit_cache = (
597 597 cached_diff and cached_diff.get('commits')
598 598 and len(cached_diff.get('commits', [])) == 5
599 599 and cached_diff.get('commits')[0]
600 600 and cached_diff.get('commits')[3])
601 601
602 602 if not force_recache and not c.range_diff_on and has_proper_commit_cache:
603 603 diff_commit_cache = \
604 604 (ancestor_commit, commit_cache, missing_requirements,
605 605 source_commit, target_commit) = cached_diff['commits']
606 606 else:
607 607 # NOTE(marcink): we may need to read potentially unreachable commits when a PR has
608 608 # merge errors, which can leave hidden commits in the shadow repo.
609 609 maybe_unreachable = _merge_check.MERGE_CHECK in _merge_check.error_details \
610 610 and _merge_check.merge_response
611 611 maybe_unreachable = maybe_unreachable \
612 612 and _merge_check.merge_response.metadata.get('unresolved_files')
613 613 log.debug("Using unreachable commits due to MERGE_CHECK in merge simulation")
614 614 diff_commit_cache = \
615 615 (ancestor_commit, commit_cache, missing_requirements,
616 616 source_commit, target_commit) = self.get_commits(
617 617 commits_source_repo,
618 618 pull_request_at_ver,
619 619 source_commit,
620 620 source_ref_id,
621 621 source_scm,
622 622 target_commit,
623 623 target_ref_id,
624 624 target_scm,
625 625 maybe_unreachable=maybe_unreachable)
626 626
627 627 # register our commit range
628 628 for comm in commit_cache.values():
629 629 c.commit_ranges.append(comm)
630 630
631 631 c.missing_requirements = missing_requirements
632 632 c.ancestor_commit = ancestor_commit
633 633 c.statuses = source_repo.statuses(
634 634 [x.raw_id for x in c.commit_ranges])
635 635
636 636 # auto collapse if we have more than limit
637 637 collapse_limit = diffs.DiffProcessor._collapse_commits_over
638 638 c.collapse_all_commits = len(c.commit_ranges) > collapse_limit
639 639 c.compare_mode = compare
640 640
641 641 # diff_limit is the old behavior: it cuts off the whole diff
642 642 # if the limit is hit, otherwise it just hides the
643 643 # big files from the front-end
644 644 diff_limit = c.visual.cut_off_limit_diff
645 645 file_limit = c.visual.cut_off_limit_file
646 646
647 647 c.missing_commits = False
648 648 if (c.missing_requirements
649 649 or isinstance(source_commit, EmptyCommit)
650 650 or source_commit == target_commit):
651 651
652 652 c.missing_commits = True
653 653 else:
654 654 c.inline_comments = display_inline_comments
655 655
656 656 use_ancestor = True
657 657 if from_version_normalized != version_normalized:
658 658 use_ancestor = False
659 659
660 660 has_proper_diff_cache = cached_diff and cached_diff.get('commits')
661 661 if not force_recache and has_proper_diff_cache:
662 662 c.diffset = cached_diff['diff']
663 663 else:
664 664 try:
665 665 c.diffset = self._get_diffset(
666 666 c.source_repo.repo_name, commits_source_repo,
667 667 c.ancestor_commit,
668 668 source_ref_id, target_ref_id,
669 669 target_commit, source_commit,
670 670 diff_limit, file_limit, c.fulldiff,
671 671 hide_whitespace_changes, diff_context,
672 672 use_ancestor=use_ancestor
673 673 )
674 674
675 675 # save cached diff
676 676 if caching_enabled:
677 677 cache_diff(cache_file_path, c.diffset, diff_commit_cache)
678 678 except CommitDoesNotExistError:
679 679 log.exception('Failed to generate diffset')
680 680 c.missing_commits = True
681 681
682 682 if not c.missing_commits:
683 683
684 684 c.limited_diff = c.diffset.limited_diff
685 685
686 686 # calculate removed files that are bound to comments
687 687 comment_deleted_files = [
688 688 fname for fname in display_inline_comments
689 689 if fname not in c.diffset.file_stats]
690 690
691 691 c.deleted_files_comments = collections.defaultdict(dict)
692 692 for fname, per_line_comments in display_inline_comments.items():
693 693 if fname in comment_deleted_files:
694 694 c.deleted_files_comments[fname]['stats'] = 0
695 695 c.deleted_files_comments[fname]['comments'] = list()
696 696 for lno, comments in per_line_comments.items():
697 697 c.deleted_files_comments[fname]['comments'].extend(comments)
698 698
699 699 # maybe calculate the range diff
700 700 if c.range_diff_on:
701 701 # TODO(marcink): set whitespace/context
702 702 context_lcl = 3
703 703 ign_whitespace_lcl = False
704 704
705 705 for commit in c.commit_ranges:
706 706 commit2 = commit
707 707 commit1 = commit.first_parent
708 708
709 709 range_diff_cache_file_path = diff_cache_exist(
710 710 cache_path, 'diff', commit.raw_id,
711 711 ign_whitespace_lcl, context_lcl, c.fulldiff)
712 712
713 713 cached_diff = None
714 714 if caching_enabled:
715 715 cached_diff = load_cached_diff(range_diff_cache_file_path)
716 716
717 717 has_proper_diff_cache = cached_diff and cached_diff.get('diff')
718 718 if not force_recache and has_proper_diff_cache:
719 719 diffset = cached_diff['diff']
720 720 else:
721 721 diffset = self._get_range_diffset(
722 722 commits_source_repo, source_repo,
723 723 commit1, commit2, diff_limit, file_limit,
724 724 c.fulldiff, ign_whitespace_lcl, context_lcl
725 725 )
726 726
727 727 # save cached diff
728 728 if caching_enabled:
729 729 cache_diff(range_diff_cache_file_path, diffset, None)
730 730
731 731 c.changes[commit.raw_id] = diffset
732 732
733 733 # this is a hack to properly display links, when creating PR, the
734 734 # compare view and others use a different notation, and
735 735 # compare_commits.mako renders links based on the target_repo.
736 736 # We need to swap that here to generate it properly on the html side
737 737 c.target_repo = c.source_repo
738 738
739 739 c.commit_statuses = ChangesetStatus.STATUSES
740 740
741 741 c.show_version_changes = not pr_closed
742 742 if c.show_version_changes:
743 743 cur_obj = pull_request_at_ver
744 744 prev_obj = prev_pull_request_at_ver
745 745
746 746 old_commit_ids = prev_obj.revisions
747 747 new_commit_ids = cur_obj.revisions
748 748 commit_changes = PullRequestModel()._calculate_commit_id_changes(
749 749 old_commit_ids, new_commit_ids)
750 750 c.commit_changes_summary = commit_changes
751 751
752 752 # calculate the diff for commits between versions
753 753 c.commit_changes = []
754 754
755 755 def mark(cs, fw):
756 return list(h.itertools.izip_longest([], cs, fillvalue=fw))
756 return list(h.itertools.zip_longest([], cs, fillvalue=fw))
757 757
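A note on the zip_longest swap above: pairing an empty first iterable with the commit list makes every commit come out tagged with the fillvalue marker, e.g. (stdlib-only sketch with made-up commit ids):

    import itertools

    marked = list(itertools.zip_longest([], ['abc1', 'def2'], fillvalue='a'))
    assert marked == [('a', 'abc1'), ('a', 'def2')]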
758 758 for c_type, raw_id in mark(commit_changes.added, 'a') \
759 759 + mark(commit_changes.removed, 'r') \
760 760 + mark(commit_changes.common, 'c'):
761 761
762 762 if raw_id in commit_cache:
763 763 commit = commit_cache[raw_id]
764 764 else:
765 765 try:
766 766 commit = commits_source_repo.get_commit(raw_id)
767 767 except CommitDoesNotExistError:
768 768 # if extraction fails, still use a "dummy" commit
769 769 # for display in the commit diff
770 770 commit = h.AttributeDict(
771 771 {'raw_id': raw_id,
772 772 'message': 'EMPTY or MISSING COMMIT'})
773 773 c.commit_changes.append([c_type, commit])
774 774
775 775 # current user review statuses for each version
776 776 c.review_versions = {}
777 777 is_reviewer = PullRequestModel().is_user_reviewer(
778 778 pull_request, self._rhodecode_user)
779 779 if is_reviewer:
780 780 for co in general_comments:
781 781 if co.author.user_id == self._rhodecode_user.user_id:
782 782 status = co.status_change
783 783 if status:
784 784 _ver_pr = status[0].comment.pull_request_version_id
785 785 c.review_versions[_ver_pr] = status[0]
786 786
787 787 return self._get_template_context(c)
788 788
789 789 def get_commits(
790 790 self, commits_source_repo, pull_request_at_ver, source_commit,
791 791 source_ref_id, source_scm, target_commit, target_ref_id, target_scm,
792 792 maybe_unreachable=False):
793 793
794 794 commit_cache = collections.OrderedDict()
795 795 missing_requirements = False
796 796
797 797 try:
798 798 pre_load = ["author", "date", "message", "branch", "parents"]
799 799
800 800 pull_request_commits = pull_request_at_ver.revisions
801 801 log.debug('Loading %s commits from %s',
802 802 len(pull_request_commits), commits_source_repo)
803 803
804 804 for rev in pull_request_commits:
805 805 comm = commits_source_repo.get_commit(commit_id=rev, pre_load=pre_load,
806 806 maybe_unreachable=maybe_unreachable)
807 807 commit_cache[comm.raw_id] = comm
808 808
809 809 # Order matters here: we first need to get the target, and then
810 810 # the source
811 811 target_commit = commits_source_repo.get_commit(
812 812 commit_id=safe_str(target_ref_id))
813 813
814 814 source_commit = commits_source_repo.get_commit(
815 815 commit_id=safe_str(source_ref_id), maybe_unreachable=True)
816 816 except CommitDoesNotExistError:
817 817 log.warning('Failed to get commit from `{}` repo'.format(
818 818 commits_source_repo), exc_info=True)
819 819 except RepositoryRequirementError:
820 820 log.warning('Failed to get all required data from repo', exc_info=True)
821 821 missing_requirements = True
822 822
823 823 pr_ancestor_id = pull_request_at_ver.common_ancestor_id
824 824
825 825 try:
826 826 ancestor_commit = source_scm.get_commit(pr_ancestor_id)
827 827 except Exception:
828 828 ancestor_commit = None
829 829
830 830 return ancestor_commit, commit_cache, missing_requirements, source_commit, target_commit
831 831
832 832 def assure_not_empty_repo(self):
833 833 _ = self.request.translate
834 834
835 835 try:
836 836 self.db_repo.scm_instance().get_commit()
837 837 except EmptyRepositoryError:
838 838 h.flash(h.literal(_('There are no commits yet')),
839 839 category='warning')
840 840 raise HTTPFound(
841 841 h.route_path('repo_summary', repo_name=self.db_repo.repo_name))
842 842
843 843 @LoginRequired()
844 844 @NotAnonymous()
845 845 @HasRepoPermissionAnyDecorator(
846 846 'repository.read', 'repository.write', 'repository.admin')
847 847 def pull_request_new(self):
848 848 _ = self.request.translate
849 849 c = self.load_default_context()
850 850
851 851 self.assure_not_empty_repo()
852 852 source_repo = self.db_repo
853 853
854 854 commit_id = self.request.GET.get('commit')
855 855 branch_ref = self.request.GET.get('branch')
856 856 bookmark_ref = self.request.GET.get('bookmark')
857 857
858 858 try:
859 859 source_repo_data = PullRequestModel().generate_repo_data(
860 860 source_repo, commit_id=commit_id,
861 861 branch=branch_ref, bookmark=bookmark_ref,
862 862 translator=self.request.translate)
863 863 except CommitDoesNotExistError as e:
864 864 log.exception(e)
865 865 h.flash(_('Commit does not exist'), 'error')
866 866 raise HTTPFound(
867 867 h.route_path('pullrequest_new', repo_name=source_repo.repo_name))
868 868
869 869 default_target_repo = source_repo
870 870
871 871 if source_repo.parent and c.has_origin_repo_read_perm:
872 872 parent_vcs_obj = source_repo.parent.scm_instance()
873 873 if parent_vcs_obj and not parent_vcs_obj.is_empty():
874 874 # change default if we have a parent repo
875 875 default_target_repo = source_repo.parent
876 876
877 877 target_repo_data = PullRequestModel().generate_repo_data(
878 878 default_target_repo, translator=self.request.translate)
879 879
880 880 selected_source_ref = source_repo_data['refs']['selected_ref']
881 881 title_source_ref = ''
882 882 if selected_source_ref:
883 883 title_source_ref = selected_source_ref.split(':', 2)[1]
884 884 c.default_title = PullRequestModel().generate_pullrequest_title(
885 885 source=source_repo.repo_name,
886 886 source_ref=title_source_ref,
887 887 target=default_target_repo.repo_name
888 888 )
889 889
890 890 c.default_repo_data = {
891 891 'source_repo_name': source_repo.repo_name,
892 892 'source_refs_json': json.dumps(source_repo_data),
893 893 'target_repo_name': default_target_repo.repo_name,
894 894 'target_refs_json': json.dumps(target_repo_data),
895 895 }
896 896 c.default_source_ref = selected_source_ref
897 897
898 898 return self._get_template_context(c)
899 899
900 900 @LoginRequired()
901 901 @NotAnonymous()
902 902 @HasRepoPermissionAnyDecorator(
903 903 'repository.read', 'repository.write', 'repository.admin')
904 904 def pull_request_repo_refs(self):
905 905 self.load_default_context()
906 906 target_repo_name = self.request.matchdict['target_repo_name']
907 907 repo = Repository.get_by_repo_name(target_repo_name)
908 908 if not repo:
909 909 raise HTTPNotFound()
910 910
911 911 target_perm = HasRepoPermissionAny(
912 912 'repository.read', 'repository.write', 'repository.admin')(
913 913 target_repo_name)
914 914 if not target_perm:
915 915 raise HTTPNotFound()
916 916
917 917 return PullRequestModel().generate_repo_data(
918 918 repo, translator=self.request.translate)
919 919
920 920 @LoginRequired()
921 921 @NotAnonymous()
922 922 @HasRepoPermissionAnyDecorator(
923 923 'repository.read', 'repository.write', 'repository.admin')
924 924 def pullrequest_repo_targets(self):
925 925 _ = self.request.translate
926 926 filter_query = self.request.GET.get('query')
927 927
928 928 # get the parents
929 929 parent_target_repos = []
930 930 if self.db_repo.parent:
931 931 parents_query = Repository.query() \
932 932 .order_by(func.length(Repository.repo_name)) \
933 933 .filter(Repository.fork_id == self.db_repo.parent.repo_id)
934 934
935 935 if filter_query:
936 936 ilike_expression = u'%{}%'.format(safe_unicode(filter_query))
937 937 parents_query = parents_query.filter(
938 938 Repository.repo_name.ilike(ilike_expression))
939 939 parents = parents_query.limit(20).all()
940 940
941 941 for parent in parents:
942 942 parent_vcs_obj = parent.scm_instance()
943 943 if parent_vcs_obj and not parent_vcs_obj.is_empty():
944 944 parent_target_repos.append(parent)
945 945
946 946 # get other forks, and repo itself
947 947 query = Repository.query() \
948 948 .order_by(func.length(Repository.repo_name)) \
949 949 .filter(
950 950 or_(Repository.repo_id == self.db_repo.repo_id, # repo itself
951 951 Repository.fork_id == self.db_repo.repo_id) # forks of this repo
952 952 ) \
953 953 .filter(~Repository.repo_id.in_([x.repo_id for x in parent_target_repos]))
954 954
955 955 if filter_query:
956 956 ilike_expression = u'%{}%'.format(safe_unicode(filter_query))
957 957 query = query.filter(Repository.repo_name.ilike(ilike_expression))
958 958
959 959 limit = max(20 - len(parent_target_repos), 5) # not less than 5
960 960 target_repos = query.limit(limit).all()
961 961
962 962 all_target_repos = target_repos + parent_target_repos
963 963
964 964 repos = []
965 965 # ScmModel().get_repos applies permission checks, filtering out repositories the user cannot read
966 966 for obj in ScmModel().get_repos(all_target_repos):
967 967 repos.append({
968 968 'id': obj['name'],
969 969 'text': obj['name'],
970 970 'type': 'repo',
971 971 'repo_id': obj['dbrepo']['repo_id'],
972 972 'repo_type': obj['dbrepo']['repo_type'],
973 973 'private': obj['dbrepo']['private'],
974 974
975 975 })
976 976
977 977 data = {
978 978 'more': False,
979 979 'results': [{
980 980 'text': _('Repositories'),
981 981 'children': repos
982 982 }] if repos else []
983 983 }
984 984 return data
985 985
986 986 @classmethod
987 987 def get_comment_ids(cls, post_data):
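# parses the submitted `comments` POST field into positive comment ids; e.g. {'comments': '12,13,abc'} yields [12, 13]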
988 988 return [comment_id for comment_id in map(safe_int, aslist(post_data.get('comments'), ',')) if comment_id and comment_id > 0]
989 989
990 990 @LoginRequired()
991 991 @NotAnonymous()
992 992 @HasRepoPermissionAnyDecorator(
993 993 'repository.read', 'repository.write', 'repository.admin')
994 994 def pullrequest_comments(self):
995 995 self.load_default_context()
996 996
997 997 pull_request = PullRequest.get_or_404(
998 998 self.request.matchdict['pull_request_id'])
999 999 pull_request_id = pull_request.pull_request_id
1000 1000 version = self.request.GET.get('version')
1001 1001
1002 1002 _render = self.request.get_partial_renderer(
1003 1003 'rhodecode:templates/base/sidebar.mako')
1004 1004 c = _render.get_call_context()
1005 1005
1006 1006 (pull_request_latest,
1007 1007 pull_request_at_ver,
1008 1008 pull_request_display_obj,
1009 1009 at_version) = PullRequestModel().get_pr_version(
1010 1010 pull_request_id, version=version)
1011 1011 versions = pull_request_display_obj.versions()
1012 1012 latest_ver = PullRequest.get_pr_display_object(pull_request_latest, pull_request_latest)
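# append a synthetic 'latest' version entry so the current state always appears alongside the stored versions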
1013 1013 c.versions = versions + [latest_ver]
1014 1014
1015 1015 c.at_version = at_version
1016 1016 c.at_version_num = (at_version
1017 1017 if at_version and at_version != PullRequest.LATEST_VER
1018 1018 else None)
1019 1019
1020 1020 self.register_comments_vars(c, pull_request_latest, versions, include_drafts=False)
1021 1021 all_comments = c.inline_comments_flat + c.comments
1022 1022
1023 1023 existing_ids = self.get_comment_ids(self.request.POST)
1024 1024 return _render('comments_table', all_comments, len(all_comments),
1025 1025 existing_ids=existing_ids)
1026 1026
1027 1027 @LoginRequired()
1028 1028 @NotAnonymous()
1029 1029 @HasRepoPermissionAnyDecorator(
1030 1030 'repository.read', 'repository.write', 'repository.admin')
1031 1031 def pullrequest_todos(self):
1032 1032 self.load_default_context()
1033 1033
1034 1034 pull_request = PullRequest.get_or_404(
1035 1035 self.request.matchdict['pull_request_id'])
1036 1036 pull_request_id = pull_request.pull_request_id
1037 1037 version = self.request.GET.get('version')
1038 1038
1039 1039 _render = self.request.get_partial_renderer(
1040 1040 'rhodecode:templates/base/sidebar.mako')
1041 1041 c = _render.get_call_context()
1042 1042 (pull_request_latest,
1043 1043 pull_request_at_ver,
1044 1044 pull_request_display_obj,
1045 1045 at_version) = PullRequestModel().get_pr_version(
1046 1046 pull_request_id, version=version)
1047 1047 versions = pull_request_display_obj.versions()
1048 1048 latest_ver = PullRequest.get_pr_display_object(pull_request_latest, pull_request_latest)
1049 1049 c.versions = versions + [latest_ver]
1050 1050
1051 1051 c.at_version = at_version
1052 1052 c.at_version_num = (at_version
1053 1053 if at_version and at_version != PullRequest.LATEST_VER
1054 1054 else None)
1055 1055
1056 1056 c.unresolved_comments = CommentsModel() \
1057 1057 .get_pull_request_unresolved_todos(pull_request, include_drafts=False)
1058 1058 c.resolved_comments = CommentsModel() \
1059 1059 .get_pull_request_resolved_todos(pull_request, include_drafts=False)
1060 1060
1061 1061 all_comments = c.unresolved_comments + c.resolved_comments
1062 1062 existing_ids = self.get_comment_ids(self.request.POST)
1063 1063 return _render('comments_table', all_comments, len(c.unresolved_comments),
1064 1064 todo_comments=True, existing_ids=existing_ids)
1065 1065
1066 1066 @LoginRequired()
1067 1067 @NotAnonymous()
1068 1068 @HasRepoPermissionAnyDecorator(
1069 1069 'repository.read', 'repository.write', 'repository.admin')
1070 1070 def pullrequest_drafts(self):
1071 1071 self.load_default_context()
1072 1072
1073 1073 pull_request = PullRequest.get_or_404(
1074 1074 self.request.matchdict['pull_request_id'])
1075 1075 pull_request_id = pull_request.pull_request_id
1076 1076 version = self.request.GET.get('version')
1077 1077
1078 1078 _render = self.request.get_partial_renderer(
1079 1079 'rhodecode:templates/base/sidebar.mako')
1080 1080 c = _render.get_call_context()
1081 1081
1082 1082 (pull_request_latest,
1083 1083 pull_request_at_ver,
1084 1084 pull_request_display_obj,
1085 1085 at_version) = PullRequestModel().get_pr_version(
1086 1086 pull_request_id, version=version)
1087 1087 versions = pull_request_display_obj.versions()
1088 1088 latest_ver = PullRequest.get_pr_display_object(pull_request_latest, pull_request_latest)
1089 1089 c.versions = versions + [latest_ver]
1090 1090
1091 1091 c.at_version = at_version
1092 1092 c.at_version_num = (at_version
1093 1093 if at_version and at_version != PullRequest.LATEST_VER
1094 1094 else None)
1095 1095
1096 1096 c.draft_comments = CommentsModel() \
1097 1097 .get_pull_request_drafts(self._rhodecode_db_user.user_id, pull_request)
1098 1098
1099 1099 all_comments = c.draft_comments
1100 1100
1101 1101 existing_ids = self.get_comment_ids(self.request.POST)
1102 1102 return _render('comments_table', all_comments, len(all_comments),
1103 1103 existing_ids=existing_ids, draft_comments=True)
1104 1104
1105 1105 @LoginRequired()
1106 1106 @NotAnonymous()
1107 1107 @HasRepoPermissionAnyDecorator(
1108 1108 'repository.read', 'repository.write', 'repository.admin')
1109 1109 @CSRFRequired()
1110 1110 def pull_request_create(self):
1111 1111 _ = self.request.translate
1112 1112 self.assure_not_empty_repo()
1113 1113 self.load_default_context()
1114 1114
1115 1115 controls = peppercorn.parse(self.request.POST.items())
1116 1116
1117 1117 try:
1118 1118 form = PullRequestForm(
1119 1119 self.request.translate, self.db_repo.repo_id)()
1120 1120 _form = form.to_python(controls)
1121 1121 except formencode.Invalid as errors:
1122 1122 if errors.error_dict.get('revisions'):
1123 1123 msg = 'Revisions: %s' % errors.error_dict['revisions']
1124 1124 elif errors.error_dict.get('pullrequest_title'):
1125 1125 msg = errors.error_dict.get('pullrequest_title')
1126 1126 else:
1127 1127 msg = _('Error creating pull request: {}').format(errors)
1128 1128 log.exception(msg)
1129 1129 h.flash(msg, 'error')
1130 1130
1131 1131 # would rather just go back to form ...
1132 1132 raise HTTPFound(
1133 1133 h.route_path('pullrequest_new', repo_name=self.db_repo_name))
1134 1134
1135 1135 source_repo = _form['source_repo']
1136 1136 source_ref = _form['source_ref']
1137 1137 target_repo = _form['target_repo']
1138 1138 target_ref = _form['target_ref']
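# revisions arrive from the form newest-first, hence the reversal below into chronological (oldest-first) order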
1139 1139 commit_ids = _form['revisions'][::-1]
1140 1140 common_ancestor_id = _form['common_ancestor']
1141 1141
1142 1142 # find the ancestor for this pr
1143 1143 source_db_repo = Repository.get_by_repo_name(_form['source_repo'])
1144 1144 target_db_repo = Repository.get_by_repo_name(_form['target_repo'])
1145 1145
1146 1146 if not (source_db_repo and target_db_repo):
1147 1147 h.flash(_('Source repository or target repository not found.'), category='error')
1148 1148 raise HTTPFound(
1149 1149 h.route_path('pullrequest_new', repo_name=self.db_repo_name))
1150 1150
1151 1151 # re-check permissions here:
1152 1152 # on the source repo we must have read permissions
1153 1153
1154 1154 source_perm = HasRepoPermissionAny(
1155 1155 'repository.read', 'repository.write', 'repository.admin')(
1156 1156 source_db_repo.repo_name)
1157 1157 if not source_perm:
1158 1158 msg = _('Not enough permissions to source repo `{}`.').format(
1159 1159 source_db_repo.repo_name)
1160 1160 h.flash(msg, category='error')
1161 1161 # copy the args back to redirect
1162 1162 org_query = self.request.GET.mixed()
1163 1163 raise HTTPFound(
1164 1164 h.route_path('pullrequest_new', repo_name=self.db_repo_name,
1165 1165 _query=org_query))
1166 1166
1167 1167 # on the target repo we must also have read permissions; later on
1168 1168 # we want to check branch permissions here as well
1169 1169 target_perm = HasRepoPermissionAny(
1170 1170 'repository.read', 'repository.write', 'repository.admin')(
1171 1171 target_db_repo.repo_name)
1172 1172 if not target_perm:
1173 1173 msg = _('Not enough permissions to target repo `{}`.').format(
1174 1174 target_db_repo.repo_name)
1175 1175 h.flash(msg, category='error')
1176 1176 # copy the args back to redirect
1177 1177 org_query = self.request.GET.mixed()
1178 1178 raise HTTPFound(
1179 1179 h.route_path('pullrequest_new', repo_name=self.db_repo_name,
1180 1180 _query=org_query))
1181 1181
1182 1182 source_scm = source_db_repo.scm_instance()
1183 1183 target_scm = target_db_repo.scm_instance()
1184 1184
1185 1185 source_ref_obj = unicode_to_reference(source_ref)
1186 1186 target_ref_obj = unicode_to_reference(target_ref)
1187 1187
1188 1188 source_commit = source_scm.get_commit(source_ref_obj.commit_id)
1189 1189 target_commit = target_scm.get_commit(target_ref_obj.commit_id)
1190 1190
1191 1191 ancestor = source_scm.get_common_ancestor(
1192 1192 source_commit.raw_id, target_commit.raw_id, target_scm)
1193 1193
1194 1194 # recalculate the target ref based on the common ancestor, so the diff is computed against the actual merge base
1195 1195 target_ref = ':'.join((target_ref_obj.type, target_ref_obj.name, ancestor))
1196 1196
1197 1197 get_default_reviewers_data, validate_default_reviewers, validate_observers = \
1198 1198 PullRequestModel().get_reviewer_functions()
1199 1199
1200 1200 # recalculate reviewers logic, to make sure we can validate this
1201 1201 reviewer_rules = get_default_reviewers_data(
1202 1202 self._rhodecode_db_user,
1203 1203 source_db_repo,
1204 1204 source_ref_obj,
1205 1205 target_db_repo,
1206 1206 target_ref_obj,
1207 1207 include_diff_info=False)
1208 1208
1209 1209 reviewers = validate_default_reviewers(_form['review_members'], reviewer_rules)
1210 1210 observers = validate_observers(_form['observer_members'], reviewer_rules)
1211 1211
1212 1212 pullrequest_title = _form['pullrequest_title']
1213 1213 title_source_ref = source_ref_obj.name
1214 1214 if not pullrequest_title:
1215 1215 pullrequest_title = PullRequestModel().generate_pullrequest_title(
1216 1216 source=source_repo,
1217 1217 source_ref=title_source_ref,
1218 1218 target=target_repo
1219 1219 )
1220 1220
1221 1221 description = _form['pullrequest_desc']
1222 1222 description_renderer = _form['description_renderer']
1223 1223
1224 1224 try:
1225 1225 pull_request = PullRequestModel().create(
1226 1226 created_by=self._rhodecode_user.user_id,
1227 1227 source_repo=source_repo,
1228 1228 source_ref=source_ref,
1229 1229 target_repo=target_repo,
1230 1230 target_ref=target_ref,
1231 1231 revisions=commit_ids,
1232 1232 common_ancestor_id=common_ancestor_id,
1233 1233 reviewers=reviewers,
1234 1234 observers=observers,
1235 1235 title=pullrequest_title,
1236 1236 description=description,
1237 1237 description_renderer=description_renderer,
1238 1238 reviewer_data=reviewer_rules,
1239 1239 auth_user=self._rhodecode_user
1240 1240 )
1241 1241 Session().commit()
1242 1242
1243 1243 h.flash(_('Successfully opened new pull request'),
1244 1244 category='success')
1245 1245 except Exception:
1246 1246 msg = _('Error occurred during creation of this pull request.')
1247 1247 log.exception(msg)
1248 1248 h.flash(msg, category='error')
1249 1249
1250 1250 # copy the args back to redirect
1251 1251 org_query = self.request.GET.mixed()
1252 1252 raise HTTPFound(
1253 1253 h.route_path('pullrequest_new', repo_name=self.db_repo_name,
1254 1254 _query=org_query))
1255 1255
1256 1256 raise HTTPFound(
1257 1257 h.route_path('pullrequest_show', repo_name=target_repo,
1258 1258 pull_request_id=pull_request.pull_request_id))
1259 1259
1260 1260 @LoginRequired()
1261 1261 @NotAnonymous()
1262 1262 @HasRepoPermissionAnyDecorator(
1263 1263 'repository.read', 'repository.write', 'repository.admin')
1264 1264 @CSRFRequired()
1265 1265 def pull_request_update(self):
1266 1266 pull_request = PullRequest.get_or_404(
1267 1267 self.request.matchdict['pull_request_id'])
1268 1268 _ = self.request.translate
1269 1269
1270 1270 c = self.load_default_context()
1271 1271 redirect_url = None
1272 1272 # we do this check first, because we want to know ASAP in the flow
1273 1273 # whether the PR is currently updating
1274 1274 is_state_changing = pull_request.is_state_changing()
1275 1275
1276 1276 if pull_request.is_closed():
1277 1277 log.debug('update: forbidden because pull request is closed')
1278 1278 msg = _(u'Cannot update closed pull requests.')
1279 1279 h.flash(msg, category='error')
1280 1280 return {'response': True,
1281 1281 'redirect_url': redirect_url}
1282 1282
1283 1283 c.pr_broadcast_channel = channelstream.pr_channel(pull_request)
1284 1284
1285 1285 # only the owner or an admin can update it
1286 1286 allowed_to_update = PullRequestModel().check_user_update(
1287 1287 pull_request, self._rhodecode_user)
1288 1288
1289 1289 if allowed_to_update:
1290 1290 controls = peppercorn.parse(self.request.POST.items())
1291 1291 force_refresh = str2bool(self.request.POST.get('force_refresh', 'false'))
1292 1292 do_update_commits = str2bool(self.request.POST.get('update_commits', 'false'))
1293 1293
1294 1294 if 'review_members' in controls:
1295 1295 self._update_reviewers(
1296 1296 c,
1297 1297 pull_request, controls['review_members'],
1298 1298 pull_request.reviewer_data,
1299 1299 PullRequestReviewers.ROLE_REVIEWER)
1300 1300 elif 'observer_members' in controls:
1301 1301 self._update_reviewers(
1302 1302 c,
1303 1303 pull_request, controls['observer_members'],
1304 1304 pull_request.reviewer_data,
1305 1305 PullRequestReviewers.ROLE_OBSERVER)
1306 1306 elif do_update_commits:
1307 1307 if is_state_changing:
1308 1308 log.debug('commits update: forbidden because pull request is in state %s',
1309 1309 pull_request.pull_request_state)
1310 1310 msg = _(u'Cannot update pull request commits in a state other than `{}`. '
1311 1311 u'Current state is: `{}`').format(
1312 1312 PullRequest.STATE_CREATED, pull_request.pull_request_state)
1313 1313 h.flash(msg, category='error')
1314 1314 return {'response': True,
1315 1315 'redirect_url': redirect_url}
1316 1316
1317 1317 self._update_commits(c, pull_request)
1318 1318 if force_refresh:
1319 1319 redirect_url = h.route_path(
1320 1320 'pullrequest_show', repo_name=self.db_repo_name,
1321 1321 pull_request_id=pull_request.pull_request_id,
1322 1322 _query={"force_refresh": 1})
1323 1323 elif str2bool(self.request.POST.get('edit_pull_request', 'false')):
1324 1324 self._edit_pull_request(pull_request)
1325 1325 else:
1326 1326 log.error('Unhandled update data.')
1327 1327 raise HTTPBadRequest()
1328 1328
1329 1329 return {'response': True,
1330 1330 'redirect_url': redirect_url}
1331 1331 raise HTTPForbidden()
1332 1332
1333 1333 def _edit_pull_request(self, pull_request):
1334 1334 """
1335 1335 Edit title and description
1336 1336 """
1337 1337 _ = self.request.translate
1338 1338
1339 1339 try:
1340 1340 PullRequestModel().edit(
1341 1341 pull_request,
1342 1342 self.request.POST.get('title'),
1343 1343 self.request.POST.get('description'),
1344 1344 self.request.POST.get('description_renderer'),
1345 1345 self._rhodecode_user)
1346 1346 except ValueError:
1347 1347 msg = _(u'Cannot update closed pull requests.')
1348 1348 h.flash(msg, category='error')
1349 1349 return
1350 1350 else:
1351 1351 Session().commit()
1352 1352
1353 1353 msg = _(u'Pull request title & description updated.')
1354 1354 h.flash(msg, category='success')
1355 1355 return
1356 1356
1357 1357 def _update_commits(self, c, pull_request):
1358 1358 _ = self.request.translate
1359 1359 log.debug('pull-request: running update commits actions')
1360 1360
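# updating commits can fail transiently (e.g. under concurrent pushes), so retry up to 3 times with a 2 second delay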
1361 1361 @retry(exception=Exception, n_tries=3, delay=2)
1362 1362 def commits_update():
1363 1363 return PullRequestModel().update_commits(
1364 1364 pull_request, self._rhodecode_db_user)
1365 1365
1366 1366 with pull_request.set_state(PullRequest.STATE_UPDATING):
1367 1367 resp = commits_update() # retry x3
1368 1368
1369 1369 if resp.executed:
1370 1370
1371 1371 if resp.target_changed and resp.source_changed:
1372 1372 changed = 'target and source repositories'
1373 1373 elif resp.target_changed and not resp.source_changed:
1374 1374 changed = 'target repository'
1375 1375 elif not resp.target_changed and resp.source_changed:
1376 1376 changed = 'source repository'
1377 1377 else:
1378 1378 changed = 'nothing'
1379 1379
1380 1380 msg = _(u'Pull request updated to "{source_commit_id}" with '
1381 1381 u'{count_added} added, {count_removed} removed commits. '
1382 1382 u'Source of changes: {change_source}.')
1383 1383 msg = msg.format(
1384 1384 source_commit_id=pull_request.source_ref_parts.commit_id,
1385 1385 count_added=len(resp.changes.added),
1386 1386 count_removed=len(resp.changes.removed),
1387 1387 change_source=changed)
1388 1388 h.flash(msg, category='success')
1389 1389 channelstream.pr_update_channelstream_push(
1390 1390 self.request, c.pr_broadcast_channel, self._rhodecode_user, msg)
1391 1391 else:
1392 1392 msg = PullRequestModel.UPDATE_STATUS_MESSAGES[resp.reason]
1393 1393 warning_reasons = [
1394 1394 UpdateFailureReason.NO_CHANGE,
1395 1395 UpdateFailureReason.WRONG_REF_TYPE,
1396 1396 ]
1397 1397 category = 'warning' if resp.reason in warning_reasons else 'error'
1398 1398 h.flash(msg, category=category)
1399 1399
1400 1400 def _update_reviewers(self, c, pull_request, review_members, reviewer_rules, role):
1401 1401 _ = self.request.translate
1402 1402
1403 1403 get_default_reviewers_data, validate_default_reviewers, validate_observers = \
1404 1404 PullRequestModel().get_reviewer_functions()
1405 1405
1406 1406 if role == PullRequestReviewers.ROLE_REVIEWER:
1407 1407 try:
1408 1408 reviewers = validate_default_reviewers(review_members, reviewer_rules)
1409 1409 except ValueError as e:
1410 1410 log.error('Reviewers Validation: %s', e)
1411 1411 h.flash(e, category='error')
1412 1412 return
1413 1413
1414 1414 old_calculated_status = pull_request.calculated_review_status()
1415 1415 PullRequestModel().update_reviewers(
1416 1416 pull_request, reviewers, self._rhodecode_db_user)
1417 1417
1418 1418 Session().commit()
1419 1419
1420 1420 msg = _('Pull request reviewers updated.')
1421 1421 h.flash(msg, category='success')
1422 1422 channelstream.pr_update_channelstream_push(
1423 1423 self.request, c.pr_broadcast_channel, self._rhodecode_user, msg)
1424 1424
1425 1425 # trigger a status-change event if the reviewer change alters the calculated status
1426 1426 calculated_status = pull_request.calculated_review_status()
1427 1427 if old_calculated_status != calculated_status:
1428 1428 PullRequestModel().trigger_pull_request_hook(
1429 1429 pull_request, self._rhodecode_user, 'review_status_change',
1430 1430 data={'status': calculated_status})
1431 1431
1432 1432 elif role == PullRequestReviewers.ROLE_OBSERVER:
1433 1433 try:
1434 1434 observers = validate_observers(review_members, reviewer_rules)
1435 1435 except ValueError as e:
1436 1436 log.error('Observers Validation: %s', e)
1437 1437 h.flash(e, category='error')
1438 1438 return
1439 1439
1440 1440 PullRequestModel().update_observers(
1441 1441 pull_request, observers, self._rhodecode_db_user)
1442 1442
1443 1443 Session().commit()
1444 1444 msg = _('Pull request observers updated.')
1445 1445 h.flash(msg, category='success')
1446 1446 channelstream.pr_update_channelstream_push(
1447 1447 self.request, c.pr_broadcast_channel, self._rhodecode_user, msg)
1448 1448
1449 1449 @LoginRequired()
1450 1450 @NotAnonymous()
1451 1451 @HasRepoPermissionAnyDecorator(
1452 1452 'repository.read', 'repository.write', 'repository.admin')
1453 1453 @CSRFRequired()
1454 1454 def pull_request_merge(self):
1455 1455 """
1456 1456 Merge will perform a server-side merge of the specified
1457 1457 pull request, if the pull request is approved and mergeable.
1458 1458 After successful merging, the pull request is automatically
1459 1459 closed, with a relevant comment.
1460 1460 """
1461 1461 pull_request = PullRequest.get_or_404(
1462 1462 self.request.matchdict['pull_request_id'])
1463 1463 _ = self.request.translate
1464 1464
1465 1465 if pull_request.is_state_changing():
1466 1466 log.debug('show: forbidden because pull request is in state %s',
1467 1467 pull_request.pull_request_state)
1468 1468 msg = _(u'Cannot merge pull requests in a state other than `{}`. '
1469 1469 u'Current state is: `{}`').format(PullRequest.STATE_CREATED,
1470 1470 pull_request.pull_request_state)
1471 1471 h.flash(msg, category='error')
1472 1472 raise HTTPFound(
1473 1473 h.route_path('pullrequest_show',
1474 1474 repo_name=pull_request.target_repo.repo_name,
1475 1475 pull_request_id=pull_request.pull_request_id))
1476 1476
1477 1477 self.load_default_context()
1478 1478
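# mark the PR as UPDATING while merge checks run; concurrent update/merge requests check this state and bail out in the meantime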
1479 1479 with pull_request.set_state(PullRequest.STATE_UPDATING):
1480 1480 check = MergeCheck.validate(
1481 1481 pull_request, auth_user=self._rhodecode_user,
1482 1482 translator=self.request.translate)
1483 1483 merge_possible = not check.failed
1484 1484
1485 1485 for err_type, error_msg in check.errors:
1486 1486 h.flash(error_msg, category=err_type)
1487 1487
1488 1488 if merge_possible:
1489 1489 log.debug("Pre-conditions checked, trying to merge.")
1490 1490 extras = vcs_operation_context(
1491 1491 self.request.environ, repo_name=pull_request.target_repo.repo_name,
1492 1492 username=self._rhodecode_db_user.username, action='push',
1493 1493 scm=pull_request.target_repo.repo_type)
1494 1494 with pull_request.set_state(PullRequest.STATE_UPDATING):
1495 1495 self._merge_pull_request(
1496 1496 pull_request, self._rhodecode_db_user, extras)
1497 1497 else:
1498 1498 log.debug("Pre-conditions failed, NOT merging.")
1499 1499
1500 1500 raise HTTPFound(
1501 1501 h.route_path('pullrequest_show',
1502 1502 repo_name=pull_request.target_repo.repo_name,
1503 1503 pull_request_id=pull_request.pull_request_id))
1504 1504
1505 1505 def _merge_pull_request(self, pull_request, user, extras):
1506 1506 _ = self.request.translate
1507 1507 merge_resp = PullRequestModel().merge_repo(pull_request, user, extras=extras)
1508 1508
1509 1509 if merge_resp.executed:
1510 1510 log.debug("The merge was successful, closing the pull request.")
1511 1511 PullRequestModel().close_pull_request(
1512 1512 pull_request.pull_request_id, user)
1513 1513 Session().commit()
1514 1514 msg = _('Pull request was successfully merged and closed.')
1515 1515 h.flash(msg, category='success')
1516 1516 else:
1517 1517 log.debug(
1518 1518 "The merge was not successful. Merge response: %s", merge_resp)
1519 1519 msg = merge_resp.merge_status_message
1520 1520 h.flash(msg, category='error')
1521 1521
1522 1522 @LoginRequired()
1523 1523 @NotAnonymous()
1524 1524 @HasRepoPermissionAnyDecorator(
1525 1525 'repository.read', 'repository.write', 'repository.admin')
1526 1526 @CSRFRequired()
1527 1527 def pull_request_delete(self):
1528 1528 _ = self.request.translate
1529 1529
1530 1530 pull_request = PullRequest.get_or_404(
1531 1531 self.request.matchdict['pull_request_id'])
1532 1532 self.load_default_context()
1533 1533
1534 1534 pr_closed = pull_request.is_closed()
1535 1535 allowed_to_delete = PullRequestModel().check_user_delete(
1536 1536 pull_request, self._rhodecode_user) and not pr_closed
1537 1537
1538 1538 # only the owner can delete it!
1539 1539 if allowed_to_delete:
1540 1540 PullRequestModel().delete(pull_request, self._rhodecode_user)
1541 1541 Session().commit()
1542 1542 h.flash(_('Successfully deleted pull request'),
1543 1543 category='success')
1544 1544 raise HTTPFound(h.route_path('pullrequest_show_all',
1545 1545 repo_name=self.db_repo_name))
1546 1546
1547 1547 log.warning('user %s tried to delete pull request without access',
1548 1548 self._rhodecode_user)
1549 1549 raise HTTPNotFound()
1550 1550
1551 1551 def _pull_request_comments_create(self, pull_request, comments):
1552 1552 _ = self.request.translate
1553 1553 data = {}
1554 1554 if not comments:
1555 1555 return
1556 1556 pull_request_id = pull_request.pull_request_id
1557 1557
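# True only when every submitted comment is a draft; used at the end to skip channelstream pushes for draft-only batches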
1558 1558 all_drafts = all(str2bool(x['is_draft']) for x in comments)
1559 1559
1560 1560 for entry in comments:
1561 1561 c = self.load_default_context()
1562 1562 comment_type = entry['comment_type']
1563 1563 text = entry['text']
1564 1564 status = entry['status']
1565 1565 is_draft = str2bool(entry['is_draft'])
1566 1566 resolves_comment_id = entry['resolves_comment_id']
1567 1567 close_pull_request = entry['close_pull_request']
1568 1568 f_path = entry['f_path']
1569 1569 line_no = entry['line']
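# `target_elem_id` below matches the DOM id of the file block in the diff view; the frontend presumably uses it to anchor the comment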
1570 1570 target_elem_id = 'file-{}'.format(h.safeid(h.safe_unicode(f_path)))
1571 1571
1572 1572 # the logic here works as follows: if we submit a close-PR
1573 1573 # comment, use the `close_pull_request_with_comment` function;
1574 1574 # otherwise handle the regular comment logic
1575 1575
1576 1576 if close_pull_request:
1577 1577 # only the owner, an admin, or a user with write permissions
1578 1578 allowed_to_close = PullRequestModel().check_user_update(
1579 1579 pull_request, self._rhodecode_user)
1580 1580 if not allowed_to_close:
1581 1581 log.debug('comment: forbidden because not allowed to close '
1582 1582 'pull request %s', pull_request_id)
1583 1583 raise HTTPForbidden()
1584 1584
1585 1585 # This also triggers `review_status_change`
1586 1586 comment, status = PullRequestModel().close_pull_request_with_comment(
1587 1587 pull_request, self._rhodecode_user, self.db_repo, message=text,
1588 1588 auth_user=self._rhodecode_user)
1589 1589 Session().flush()
1590 1590 is_inline = comment.is_inline
1591 1591
1592 1592 PullRequestModel().trigger_pull_request_hook(
1593 1593 pull_request, self._rhodecode_user, 'comment',
1594 1594 data={'comment': comment})
1595 1595
1596 1596 else:
1597 1597 # regular comment case: could be inline, or one with a status change;
1598 1598 # for the latter we also check permissions
1599 1599 # additionally, ENSURE that if somehow a draft is sent we are unable to change the status
1600 1600 allowed_to_change_status = PullRequestModel().check_user_change_status(
1601 1601 pull_request, self._rhodecode_user) and not is_draft
1602 1602
1603 1603 if status and allowed_to_change_status:
1604 1604 message = (_('Status change %(transition_icon)s %(status)s')
1605 1605 % {'transition_icon': '>',
1606 1606 'status': ChangesetStatus.get_status_lbl(status)})
1607 1607 text = text or message
1608 1608
1609 1609 comment = CommentsModel().create(
1610 1610 text=text,
1611 1611 repo=self.db_repo.repo_id,
1612 1612 user=self._rhodecode_user.user_id,
1613 1613 pull_request=pull_request,
1614 1614 f_path=f_path,
1615 1615 line_no=line_no,
1616 1616 status_change=(ChangesetStatus.get_status_lbl(status)
1617 1617 if status and allowed_to_change_status else None),
1618 1618 status_change_type=(status
1619 1619 if status and allowed_to_change_status else None),
1620 1620 comment_type=comment_type,
1621 1621 is_draft=is_draft,
1622 1622 resolves_comment_id=resolves_comment_id,
1623 1623 auth_user=self._rhodecode_user,
1624 1624 send_email=not is_draft, # skip notification for draft comments
1625 1625 )
1626 1626 is_inline = comment.is_inline
1627 1627
1628 1628 if allowed_to_change_status:
1629 1629 # calculate old status before we change it
1630 1630 old_calculated_status = pull_request.calculated_review_status()
1631 1631
1632 1632 # get the status if it is set!
1633 1633 if status:
1634 1634 ChangesetStatusModel().set_status(
1635 1635 self.db_repo.repo_id,
1636 1636 status,
1637 1637 self._rhodecode_user.user_id,
1638 1638 comment,
1639 1639 pull_request=pull_request
1640 1640 )
1641 1641
1642 1642 Session().flush()
1643 1643 # the refresh is required so relationships lazily loaded on the
1644 1644 # comment become accessible after the flush
1645 1645 Session().refresh(comment)
1646 1646
1647 1647 # skip notifications for drafts
1648 1648 if not is_draft:
1649 1649 PullRequestModel().trigger_pull_request_hook(
1650 1650 pull_request, self._rhodecode_user, 'comment',
1651 1651 data={'comment': comment})
1652 1652
1653 1653 # we now calculate the status of the pull request, and based on
1654 1654 # that calculation we set the commit statuses
1655 1655 calculated_status = pull_request.calculated_review_status()
1656 1656 if old_calculated_status != calculated_status:
1657 1657 PullRequestModel().trigger_pull_request_hook(
1658 1658 pull_request, self._rhodecode_user, 'review_status_change',
1659 1659 data={'status': calculated_status})
1660 1660
1661 1661 comment_id = comment.comment_id
1662 1662 data[comment_id] = {
1663 1663 'target_id': target_elem_id
1664 1664 }
1665 1665 Session().flush()
1666 1666
1667 1667 c.co = comment
1668 1668 c.at_version_num = None
1669 1669 c.is_new = True
1670 1670 rendered_comment = render(
1671 1671 'rhodecode:templates/changeset/changeset_comment_block.mako',
1672 1672 self._get_template_context(c), self.request)
1673 1673
1674 1674 data[comment_id].update(comment.get_dict())
1675 1675 data[comment_id].update({'rendered_text': rendered_comment})
1676 1676
1677 1677 Session().commit()
1678 1678
1679 1679 # skip channelstream when the submission consists solely of drafts
1680 1680 if not all_drafts:
1681 1681 comment_broadcast_channel = channelstream.comment_channel(
1682 1682 self.db_repo_name, pull_request_obj=pull_request)
1683 1683
1684 1684 comment_data = data
1685 1685 posted_comment_type = 'inline' if is_inline else 'general'
1686 1686 if len(data) == 1:
1687 1687 msg = _('posted {} new {} comment').format(len(data), posted_comment_type)
1688 1688 else:
1689 1689 msg = _('posted {} new {} comments').format(len(data), posted_comment_type)
1690 1690
1691 1691 channelstream.comment_channelstream_push(
1692 1692 self.request, comment_broadcast_channel, self._rhodecode_user, msg,
1693 1693 comment_data=comment_data)
1694 1694
1695 1695 return data
1696 1696
1697 1697 @LoginRequired()
1698 1698 @NotAnonymous()
1699 1699 @HasRepoPermissionAnyDecorator(
1700 1700 'repository.read', 'repository.write', 'repository.admin')
1701 1701 @CSRFRequired()
1702 1702 def pull_request_comment_create(self):
1703 1703 _ = self.request.translate
1704 1704
1705 1705 pull_request = PullRequest.get_or_404(self.request.matchdict['pull_request_id'])
1706 1706
1707 1707 if pull_request.is_closed():
1708 1708 log.debug('comment: forbidden because pull request is closed')
1709 1709 raise HTTPForbidden()
1710 1710
1711 1711 allowed_to_comment = PullRequestModel().check_user_comment(
1712 1712 pull_request, self._rhodecode_user)
1713 1713 if not allowed_to_comment:
1714 1714 log.debug('comment: forbidden because pull request is from forbidden repo')
1715 1715 raise HTTPForbidden()
1716 1716
1717 1717 comment_data = {
1718 1718 'comment_type': self.request.POST.get('comment_type'),
1719 1719 'text': self.request.POST.get('text'),
1720 1720 'status': self.request.POST.get('changeset_status', None),
1721 1721 'is_draft': self.request.POST.get('draft'),
1722 1722 'resolves_comment_id': self.request.POST.get('resolves_comment_id', None),
1723 1723 'close_pull_request': self.request.POST.get('close_pull_request'),
1724 1724 'f_path': self.request.POST.get('f_path'),
1725 1725 'line': self.request.POST.get('line'),
1726 1726 }
1727 1727 data = self._pull_request_comments_create(pull_request, [comment_data])
1728 1728
1729 1729 return data
1730 1730
1731 1731 @LoginRequired()
1732 1732 @NotAnonymous()
1733 1733 @HasRepoPermissionAnyDecorator(
1734 1734 'repository.read', 'repository.write', 'repository.admin')
1735 1735 @CSRFRequired()
1736 1736 def pull_request_comment_delete(self):
1737 1737 pull_request = PullRequest.get_or_404(
1738 1738 self.request.matchdict['pull_request_id'])
1739 1739
1740 1740 comment = ChangesetComment.get_or_404(
1741 1741 self.request.matchdict['comment_id'])
1742 1742 comment_id = comment.comment_id
1743 1743
1744 1744 if comment.immutable:
1745 1745 # don't allow deleting comments that are immutable
1746 1746 raise HTTPForbidden()
1747 1747
1748 1748 if pull_request.is_closed():
1749 1749 log.debug('comment: forbidden because pull request is closed')
1750 1750 raise HTTPForbidden()
1751 1751
1752 1752 if not comment:
1753 1753 log.debug('Comment with id:%s not found, skipping', comment_id)
1754 1754 # the comment was probably already deleted in another call
1755 1755 return True
1756 1756
1757 1757 if comment.pull_request.is_closed():
1758 1758 # don't allow deleting comments on a closed pull request
1759 1759 raise HTTPForbidden()
1760 1760
1761 1761 is_repo_admin = h.HasRepoPermissionAny('repository.admin')(self.db_repo_name)
1762 1762 super_admin = h.HasPermissionAny('hg.admin')()
1763 1763 comment_owner = comment.author.user_id == self._rhodecode_user.user_id
1764 1764 is_repo_comment = comment.repo.repo_name == self.db_repo_name
1765 1765 comment_repo_admin = is_repo_admin and is_repo_comment
1766 1766
1767 1767 if comment.draft and not comment_owner:
1768 1768 # we never allow anyone but the owner to delete draft comments
1769 1769 raise HTTPNotFound()
1770 1770
1771 1771 if super_admin or comment_owner or comment_repo_admin:
1772 1772 old_calculated_status = comment.pull_request.calculated_review_status()
1773 1773 CommentsModel().delete(comment=comment, auth_user=self._rhodecode_user)
1774 1774 Session().commit()
1775 1775 calculated_status = comment.pull_request.calculated_review_status()
1776 1776 if old_calculated_status != calculated_status:
1777 1777 PullRequestModel().trigger_pull_request_hook(
1778 1778 comment.pull_request, self._rhodecode_user, 'review_status_change',
1779 1779 data={'status': calculated_status})
1780 1780 return True
1781 1781 else:
1782 1782 log.warning('No permissions for user %s to delete comment_id: %s',
1783 1783 self._rhodecode_db_user, comment_id)
1784 1784 raise HTTPNotFound()
1785 1785
1786 1786 @LoginRequired()
1787 1787 @NotAnonymous()
1788 1788 @HasRepoPermissionAnyDecorator(
1789 1789 'repository.read', 'repository.write', 'repository.admin')
1790 1790 @CSRFRequired()
1791 1791 def pull_request_comment_edit(self):
1792 1792 self.load_default_context()
1793 1793
1794 1794 pull_request = PullRequest.get_or_404(
1795 1795 self.request.matchdict['pull_request_id']
1796 1796 )
1797 1797 comment = ChangesetComment.get_or_404(
1798 1798 self.request.matchdict['comment_id']
1799 1799 )
1800 1800 comment_id = comment.comment_id
1801 1801
1802 1802 if comment.immutable:
1803 1803 # don't allow editing comments that are immutable
1804 1804 raise HTTPForbidden()
1805 1805
1806 1806 if pull_request.is_closed():
1807 1807 log.debug('comment: forbidden because pull request is closed')
1808 1808 raise HTTPForbidden()
1809 1809
1810 1810 if comment.pull_request.is_closed():
1811 1811 # don't allow editing comments on a closed pull request
1812 1812 raise HTTPForbidden()
1813 1813
1814 1814 is_repo_admin = h.HasRepoPermissionAny('repository.admin')(self.db_repo_name)
1815 1815 super_admin = h.HasPermissionAny('hg.admin')()
1816 1816 comment_owner = comment.author.user_id == self._rhodecode_user.user_id
1817 1817 is_repo_comment = comment.repo.repo_name == self.db_repo_name
1818 1818 comment_repo_admin = is_repo_admin and is_repo_comment
1819 1819
1820 1820 if super_admin or comment_owner or comment_repo_admin:
1821 1821 text = self.request.POST.get('text')
1822 1822 version = self.request.POST.get('version', '')
1823 1823 if text == comment.text:
1824 1824 log.warning(
1825 1825 'Comment(PR): '
1826 1826 'Trying to create a new version of comment {} '
1827 1827 'with the same comment body'.format(
1828 1828 comment_id,
1829 1829 )
1830 1830 )
1831 1831 raise HTTPNotFound()
1832 1832
1833 1833 if version.isdigit():
1834 1834 version = int(version)
1835 1835 else:
1836 1836 log.warning(
1837 1837 'Comment(PR): Wrong version type {} {} '
1838 1838 'for comment {}'.format(
1839 1839 version,
1840 1840 type(version),
1841 1841 comment_id,
1842 1842 )
1843 1843 )
1844 1844 raise HTTPNotFound()
1845 1845
1846 1846 try:
1847 1847 comment_history = CommentsModel().edit(
1848 1848 comment_id=comment_id,
1849 1849 text=text,
1850 1850 auth_user=self._rhodecode_user,
1851 1851 version=version,
1852 1852 )
1853 1853 except CommentVersionMismatch:
1854 1854 raise HTTPConflict()
1855 1855
1856 1856 if not comment_history:
1857 1857 raise HTTPNotFound()
1858 1858
1859 1859 Session().commit()
1860 1860 if not comment.draft:
1861 1861 PullRequestModel().trigger_pull_request_hook(
1862 1862 pull_request, self._rhodecode_user, 'comment_edit',
1863 1863 data={'comment': comment})
1864 1864
1865 1865 return {
1866 1866 'comment_history_id': comment_history.comment_history_id,
1867 1867 'comment_id': comment.comment_id,
1868 1868 'comment_version': comment_history.version,
1869 1869 'comment_author_username': comment_history.author.username,
1870 1870 'comment_author_gravatar': h.gravatar_url(comment_history.author.email, 16),
1871 1871 'comment_created_on': h.age_component(comment_history.created_on,
1872 1872 time_is_local=True),
1873 1873 }
1874 1874 else:
1875 1875 log.warning('No permissions for user %s to edit comment_id: %s',
1876 1876 self._rhodecode_db_user, comment_id)
1877 1877 raise HTTPNotFound()
@@ -1,127 +1,125 b''
1 1 # -*- coding: utf-8 -*-
2 2
3 3 # Copyright (C) 2017-2020 RhodeCode GmbH
4 4 #
5 5 # This program is free software: you can redistribute it and/or modify
6 6 # it under the terms of the GNU Affero General Public License, version 3
7 7 # (only), as published by the Free Software Foundation.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU Affero General Public License
15 15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 16 #
17 17 # This program is dual-licensed. If you wish to learn more about the
18 18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20 20
21 21 import logging
22 22
23 23 from pyramid.httpexceptions import HTTPFound, HTTPNotFound
24 24
25 25 import formencode
26 26
27 27 from rhodecode.apps._base import RepoAppView
28 28 from rhodecode.lib import audit_logger
29 29 from rhodecode.lib import helpers as h
30 30 from rhodecode.lib.auth import (
31 31 LoginRequired, HasRepoPermissionAnyDecorator, CSRFRequired)
32 32 from rhodecode.model.forms import IssueTrackerPatternsForm
33 33 from rhodecode.model.meta import Session
34 34 from rhodecode.model.settings import SettingsModel
35 35
36 36 log = logging.getLogger(__name__)
37 37
38 38
39 39 class RepoSettingsIssueTrackersView(RepoAppView):
40 40 def load_default_context(self):
41 41 c = self._get_local_tmpl_context()
42
43
44 42 return c
45 43
46 44 @LoginRequired()
47 45 @HasRepoPermissionAnyDecorator('repository.admin')
48 46 def repo_issuetracker(self):
49 47 c = self.load_default_context()
50 48 c.active = 'issuetracker'
51 49 c.data = 'data'
52 50
53 51 c.settings_model = self.db_repo_patterns
54 52 c.global_patterns = c.settings_model.get_global_settings()
55 53 c.repo_patterns = c.settings_model.get_repo_settings()
56 54
57 55 return self._get_template_context(c)
58 56
59 57 @LoginRequired()
60 58 @HasRepoPermissionAnyDecorator('repository.admin')
61 59 @CSRFRequired()
62 60 def repo_issuetracker_test(self):
63 61 return h.urlify_commit_message(
64 62 self.request.POST.get('test_text', ''),
65 63 self.db_repo_name)
66 64
67 65 @LoginRequired()
68 66 @HasRepoPermissionAnyDecorator('repository.admin')
69 67 @CSRFRequired()
70 68 def repo_issuetracker_delete(self):
71 69 _ = self.request.translate
72 70 uid = self.request.POST.get('uid')
73 71 repo_settings = self.db_repo_patterns
74 72 try:
75 73 repo_settings.delete_entries(uid)
76 74 except Exception:
77 75 h.flash(_('Error occurred while deleting the issue tracker entry'),
78 76 category='error')
79 77 raise HTTPNotFound()
80 78
81 79 SettingsModel().invalidate_settings_cache()
82 80 h.flash(_('Removed issue tracker entry.'), category='success')
83 81
84 82 return {'deleted': uid}
85 83
86 84 def _update_patterns(self, form, repo_settings):
87 85 for uid in form['delete_patterns']:
88 86 repo_settings.delete_entries(uid)
89 87
90 88 for pattern_data in form['patterns']:
91 89 for setting_key, pattern, type_ in pattern_data:
92 90 sett = repo_settings.create_or_update_setting(
93 91 setting_key, pattern.strip(), type_)
94 92 Session().add(sett)
95 93
96 94 Session().commit()
97 95
98 96 @LoginRequired()
99 97 @HasRepoPermissionAnyDecorator('repository.admin')
100 98 @CSRFRequired()
101 99 def repo_issuetracker_update(self):
102 100 _ = self.request.translate
103 101 # Save inheritance
104 102 repo_settings = self.db_repo_patterns
105 103 inherited = (
106 104 self.request.POST.get('inherit_global_issuetracker') == "inherited")
107 105 repo_settings.inherit_global_settings = inherited
108 106 Session().commit()
109 107
110 108 try:
111 109 form = IssueTrackerPatternsForm(self.request.translate)().to_python(self.request.POST)
112 110 except formencode.Invalid as errors:
113 111 log.exception('Failed to add new pattern')
114 112 error = errors
115 113 h.flash(_('Invalid issue tracker pattern: {}').format(error),
116 114 category='error')
117 115 raise HTTPFound(
118 116 h.route_path('edit_repo_issuetracker',
119 117 repo_name=self.db_repo_name))
120 118
121 119 if form:
122 120 self._update_patterns(form, repo_settings)
123 121
124 122 h.flash(_('Updated issue tracker entries'), category='success')
125 123 raise HTTPFound(
126 124 h.route_path('edit_repo_issuetracker', repo_name=self.db_repo_name))
127 125
@@ -1,290 +1,290 b''
1 1 # -*- coding: utf-8 -*-
2 2
3 3 # Copyright (C) 2011-2020 RhodeCode GmbH
4 4 #
5 5 # This program is free software: you can redistribute it and/or modify
6 6 # it under the terms of the GNU Affero General Public License, version 3
7 7 # (only), as published by the Free Software Foundation.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU Affero General Public License
15 15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 16 #
17 17 # This program is dual-licensed. If you wish to learn more about the
18 18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20 20
21 21 import logging
22 22 import string
23 23 import time
24 24
25 25 import rhodecode
26 26
27 27
28 28
29 29 from rhodecode.lib.view_utils import get_format_ref_id
30 30 from rhodecode.apps._base import RepoAppView
31 31 from rhodecode.config.conf import (LANGUAGES_EXTENSIONS_MAP)
32 32 from rhodecode.lib import helpers as h, rc_cache
33 33 from rhodecode.lib.utils2 import safe_str, safe_int
34 34 from rhodecode.lib.auth import LoginRequired, HasRepoPermissionAnyDecorator
35 35 from rhodecode.lib.ext_json import json
36 36 from rhodecode.lib.vcs.backends.base import EmptyCommit
37 37 from rhodecode.lib.vcs.exceptions import (
38 38 CommitError, EmptyRepositoryError, CommitDoesNotExistError)
39 39 from rhodecode.model.db import Statistics, CacheKey, User
40 40 from rhodecode.model.meta import Session
41 41 from rhodecode.model.scm import ScmModel
42 42
43 43 log = logging.getLogger(__name__)
44 44
45 45
46 46 class RepoSummaryView(RepoAppView):
47 47
48 48 def load_default_context(self):
49 49 c = self._get_local_tmpl_context(include_app_defaults=True)
50 50 c.rhodecode_repo = None
51 51 if not c.repository_requirements_missing:
52 52 c.rhodecode_repo = self.rhodecode_vcs_repo
53 53 return c
54 54
55 55 def _load_commits_context(self, c):
56 56 p = safe_int(self.request.GET.get('page'), 1)
57 57 size = safe_int(self.request.GET.get('size'), 10)
58 58
59 59 def url_generator(page_num):
60 60 query_params = {
61 61 'page': page_num,
62 62 'size': size
63 63 }
64 64 return h.route_path(
65 65 'repo_summary_commits',
66 66 repo_name=c.rhodecode_db_repo.repo_name, _query=query_params)
67 67
68 68 pre_load = self.get_commit_preload_attrs()
69 69
70 70 try:
71 71 collection = self.rhodecode_vcs_repo.get_commits(
72 72 pre_load=pre_load, translate_tags=False)
73 73 except EmptyRepositoryError:
74 74 collection = self.rhodecode_vcs_repo
75 75
76 76 c.repo_commits = h.RepoPage(
77 77 collection, page=p, items_per_page=size, url_maker=url_generator)
78 78 page_ids = [x.raw_id for x in c.repo_commits]
79 79 c.comments = self.db_repo.get_comments(page_ids)
80 80 c.statuses = self.db_repo.statuses(page_ids)
81 81
82 82 def _prepare_and_set_clone_url(self, c):
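# build the three clone URL flavours offered on the summary page: HTTP by repo name, HTTP by repo id, and SSH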
83 83 username = ''
84 84 if self._rhodecode_user.username != User.DEFAULT_USER:
85 85 username = safe_str(self._rhodecode_user.username)
86 86
87 87 _def_clone_uri = c.clone_uri_tmpl
88 88 _def_clone_uri_id = c.clone_uri_id_tmpl
89 89 _def_clone_uri_ssh = c.clone_uri_ssh_tmpl
90 90
91 91 c.clone_repo_url = self.db_repo.clone_url(
92 92 user=username, uri_tmpl=_def_clone_uri)
93 93 c.clone_repo_url_id = self.db_repo.clone_url(
94 94 user=username, uri_tmpl=_def_clone_uri_id)
95 95 c.clone_repo_url_ssh = self.db_repo.clone_url(
96 96 uri_tmpl=_def_clone_uri_ssh, ssh=True)
97 97
98 98 @LoginRequired()
99 99 @HasRepoPermissionAnyDecorator(
100 100 'repository.read', 'repository.write', 'repository.admin')
101 101 def summary_commits(self):
102 102 c = self.load_default_context()
103 103 self._prepare_and_set_clone_url(c)
104 104 self._load_commits_context(c)
105 105 return self._get_template_context(c)
106 106
107 107 @LoginRequired()
108 108 @HasRepoPermissionAnyDecorator(
109 109 'repository.read', 'repository.write', 'repository.admin')
110 110 def summary(self):
111 111 c = self.load_default_context()
112 112
113 113 # Prepare the clone URL
114 114 self._prepare_and_set_clone_url(c)
115 115
116 116 # If enabled, get statistics data
117 117 c.show_stats = bool(self.db_repo.enable_statistics)
118 118
119 119 stats = Session().query(Statistics) \
120 120 .filter(Statistics.repository == self.db_repo) \
121 121 .scalar()
122 122
123 123 c.stats_percentage = 0
124 124
125 125 if stats and stats.languages:
126 126 c.no_data = self.db_repo.enable_statistics is False
127 127 lang_stats_d = json.loads(stats.languages)
128 128
129 129 # Sort first by decreasing count and second by the file extension,
130 130 # so we have a consistent output.
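# only the ten most frequent extensions are kept for display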
131 131 lang_stats_items = sorted(lang_stats_d.items(),
132 132 key=lambda k: (-k[1], k[0]))[:10]
133 133 lang_stats = [(x, {"count": y,
134 134 "desc": LANGUAGES_EXTENSIONS_MAP.get(x)})
135 135 for x, y in lang_stats_items]
136 136
137 137 c.trending_languages = json.dumps(lang_stats)
138 138 else:
139 139 c.no_data = True
140 140 c.trending_languages = json.dumps({})
141 141
142 142 scm_model = ScmModel()
143 143 c.enable_downloads = self.db_repo.enable_downloads
144 144 c.repository_followers = scm_model.get_followers(self.db_repo)
145 145 c.repository_forks = scm_model.get_forks(self.db_repo)
146 146
147 147 # first interaction with the VCS instance after here...
148 148 if c.repository_requirements_missing:
149 149 self.request.override_renderer = \
150 150 'rhodecode:templates/summary/missing_requirements.mako'
151 151 return self._get_template_context(c)
152 152
153 153 c.readme_data, c.readme_file = \
154 154 self._get_readme_data(self.db_repo, c.visual.default_renderer)
155 155
156 156 # loads the summary commits template context
157 157 self._load_commits_context(c)
158 158
159 159 return self._get_template_context(c)
160 160
161 161 @LoginRequired()
162 162 @HasRepoPermissionAnyDecorator(
163 163 'repository.read', 'repository.write', 'repository.admin')
164 164 def repo_stats(self):
165 165 show_stats = bool(self.db_repo.enable_statistics)
166 166 repo_id = self.db_repo.repo_id
167 167
168 168 landing_commit = self.db_repo.get_landing_commit()
169 169 if isinstance(landing_commit, EmptyCommit):
170 170 return {'size': 0, 'code_stats': {}}
171 171
172 172 cache_seconds = safe_int(rhodecode.CONFIG.get('rc_cache.cache_repo.expiration_time'))
173 173 cache_on = cache_seconds > 0
174 174
175 175 log.debug(
176 176 'Computing REPO STATS for repo_id %s commit_id `%s` '
177 177 'with caching: %s[TTL: %ss]',
178 178 repo_id, landing_commit, cache_on, cache_seconds or 0)
179 179
180 180 cache_namespace_uid = 'cache_repo.{}'.format(repo_id)
181 181 region = rc_cache.get_or_create_region('cache_repo', cache_namespace_uid)
182 182
183 183 @region.conditional_cache_on_arguments(namespace=cache_namespace_uid,
184 184 condition=cache_on)
185 185 def compute_stats(repo_id, commit_id, _show_stats):
186 186 code_stats = {}
187 187 size = 0
188 188 try:
189 189 commit = self.db_repo.get_commit(commit_id)
190 190
191 191 for node in commit.get_filenodes_generator():
192 192 size += node.size
193 193 if not _show_stats:
194 194 continue
195 ext = string.lower(node.extension)
195 ext = node.extension.lower()
196 196 ext_info = LANGUAGES_EXTENSIONS_MAP.get(ext)
197 197 if ext_info:
198 198 if ext in code_stats:
199 199 code_stats[ext]['count'] += 1
200 200 else:
201 201 code_stats[ext] = {"count": 1, "desc": ext_info}
202 202 except (EmptyRepositoryError, CommitDoesNotExistError):
203 203 pass
204 204 return {'size': h.format_byte_size_binary(size),
205 205 'code_stats': code_stats}
206 206
207 207 stats = compute_stats(self.db_repo.repo_id, landing_commit.raw_id, show_stats)
208 208 return stats
209 209
210 210 @LoginRequired()
211 211 @HasRepoPermissionAnyDecorator(
212 212 'repository.read', 'repository.write', 'repository.admin')
213 213 def repo_refs_data(self):
214 214 _ = self.request.translate
215 215 self.load_default_context()
216 216
217 217 repo = self.rhodecode_vcs_repo
218 218 refs_to_create = [
219 219 (_("Branch"), repo.branches, 'branch'),
220 220 (_("Tag"), repo.tags, 'tag'),
221 221 (_("Bookmark"), repo.bookmarks, 'book'),
222 222 ]
223 223 res = self._create_reference_data(repo, self.db_repo_name, refs_to_create)
224 224 data = {
225 225 'more': False,
226 226 'results': res
227 227 }
228 228 return data
229 229
230 230 @LoginRequired()
231 231 @HasRepoPermissionAnyDecorator(
232 232 'repository.read', 'repository.write', 'repository.admin')
233 233 def repo_refs_changelog_data(self):
234 234 _ = self.request.translate
235 235 self.load_default_context()
236 236
237 237 repo = self.rhodecode_vcs_repo
238 238
239 239 refs_to_create = [
240 240 (_("Branches"), repo.branches, 'branch'),
241 241 (_("Closed branches"), repo.branches_closed, 'branch_closed'),
242 242 # TODO: enable when vcs can handle bookmarks filters
243 243 # (_("Bookmarks"), repo.bookmarks, "book"),
244 244 ]
245 245 res = self._create_reference_data(
246 246 repo, self.db_repo_name, refs_to_create)
247 247 data = {
248 248 'more': False,
249 249 'results': res
250 250 }
251 251 return data
252 252
253 253 def _create_reference_data(self, repo, full_repo_name, refs_to_create):
254 254 format_ref_id = get_format_ref_id(repo)
255 255
256 256 result = []
257 257 for title, refs, ref_type in refs_to_create:
258 258 if refs:
259 259 result.append({
260 260 'text': title,
261 261 'children': self._create_reference_items(
262 262 repo, full_repo_name, refs, ref_type,
263 263 format_ref_id),
264 264 })
265 265 return result
266 266
267 267 def _create_reference_items(self, repo, full_repo_name, refs, ref_type, format_ref_id):
268 268 result = []
269 269 is_svn = h.is_svn(repo)
270 270 for ref_name, raw_id in refs.items():
271 271 files_url = self._create_files_url(
272 272 repo, full_repo_name, ref_name, raw_id, is_svn)
273 273 result.append({
274 274 'text': ref_name,
275 275 'id': format_ref_id(ref_name, raw_id),
276 276 'raw_id': raw_id,
277 277 'type': ref_type,
278 278 'files_url': files_url,
279 279 'idx': 0,
280 280 })
281 281 return result
282 282
283 283 def _create_files_url(self, repo, full_repo_name, ref_name, raw_id, is_svn):
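# refs with '/' in the name would be ambiguous in the URL path, and SVN appears to address everything by commit id, hence the fallback to raw_id below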
284 284 use_commit_id = '/' in ref_name or is_svn
285 285 return h.route_path(
286 286 'repo_files',
287 287 repo_name=full_repo_name,
288 288 f_path=ref_name if is_svn else '',
289 289 commit_id=raw_id if use_commit_id else ref_name,
290 290 _query=dict(at=ref_name))
@@ -1,172 +1,172 b''
1 1 # -*- coding: utf-8 -*-
2 2
3 3 # Copyright (C) 2012-2020 RhodeCode GmbH
4 4 #
5 5 # This program is free software: you can redistribute it and/or modify
6 6 # it under the terms of the GNU Affero General Public License, version 3
7 7 # (only), as published by the Free Software Foundation.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU Affero General Public License
15 15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 16 #
17 17 # This program is dual-licensed. If you wish to learn more about the
18 18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20 20
21 21 """
22 22 RhodeCode authentication library for PAM
23 23 """
24 24
25 25 import colander
26 26 import grp
27 27 import logging
28 28 import pam
29 29 import pwd
30 30 import re
31 31 import socket
32 32
33 33 from rhodecode.translation import _
34 34 from rhodecode.authentication.base import (
35 35 RhodeCodeExternalAuthPlugin, hybrid_property)
36 36 from rhodecode.authentication.schema import AuthnPluginSettingsSchemaBase
37 37 from rhodecode.authentication.routes import AuthnPluginResourceBase
38 38 from rhodecode.lib.colander_utils import strip_whitespace
39 39
40 40 log = logging.getLogger(__name__)
41 41
42 42
43 43 def plugin_factory(plugin_id, *args, **kwargs):
44 44 """
45 45 Factory function that is called during plugin discovery.
46 46 It returns the plugin instance.
47 47 """
48 48 plugin = RhodeCodeAuthPlugin(plugin_id)
49 49 return plugin
50 50
51 51
52 52 class PamAuthnResource(AuthnPluginResourceBase):
53 53 pass
54 54
55 55
56 56 class PamSettingsSchema(AuthnPluginSettingsSchemaBase):
57 57 service = colander.SchemaNode(
58 58 colander.String(),
59 59 default='login',
60 60 description=_('PAM service name to use for authentication.'),
61 61 preparer=strip_whitespace,
62 62 title=_('PAM service name'),
63 63 widget='string')
64 64 gecos = colander.SchemaNode(
65 65 colander.String(),
66 default='(?P<last_name>.+),\s*(?P<first_name>\w+)',
66 default=r'(?P<last_name>.+),\s*(?P<first_name>\w+)',
67 67 description=_('Regular expression for extracting user name/email etc. '
68 68 'from Unix userinfo.'),
69 69 preparer=strip_whitespace,
70 70 title=_('Gecos Regex'),
71 71 widget='string')
72 72
73 73
74 74 class RhodeCodeAuthPlugin(RhodeCodeExternalAuthPlugin):
75 75 uid = 'pam'
76 76 # PAM authentication can be slow. Repository operations involve a lot of
77 77 # auth calls. A little caching helps speed up push/pull operations significantly
78 78 AUTH_CACHE_TTL = 4
79 79
80 80 def includeme(self, config):
81 81 config.add_authn_plugin(self)
82 82 config.add_authn_resource(self.get_id(), PamAuthnResource(self))
83 83 config.add_view(
84 84 'rhodecode.authentication.views.AuthnPluginViewBase',
85 85 attr='settings_get',
86 86 renderer='rhodecode:templates/admin/auth/plugin_settings.mako',
87 87 request_method='GET',
88 88 route_name='auth_home',
89 89 context=PamAuthnResource)
90 90 config.add_view(
91 91 'rhodecode.authentication.views.AuthnPluginViewBase',
92 92 attr='settings_post',
93 93 renderer='rhodecode:templates/admin/auth/plugin_settings.mako',
94 94 request_method='POST',
95 95 route_name='auth_home',
96 96 context=PamAuthnResource)
97 97
98 98 def get_display_name(self, load_from_settings=False):
99 99 return _('PAM')
100 100
101 101 @classmethod
102 102 def docs(cls):
103 103 return "https://docs.rhodecode.com/RhodeCode-Enterprise/auth/auth-pam.html"
104 104
105 105 @hybrid_property
106 106 def name(self):
107 107 return u"pam"
108 108
109 109 def get_settings_schema(self):
110 110 return PamSettingsSchema()
111 111
112 112 def use_fake_password(self):
113 113 return True
114 114
115 115 def auth(self, userobj, username, password, settings, **kwargs):
116 116 if not username or not password:
117 117 log.debug('Empty username or password, skipping...')
118 118 return None
119 119 _pam = pam.pam()
120 120 auth_result = _pam.authenticate(username, password, settings["service"])
121 121
122 122 if not auth_result:
123 123 log.error("PAM was unable to authenticate user: %s", username)
124 124 return None
125 125
126 126 log.debug('Got PAM response %s', auth_result)
127 127
128 128 # old attrs fetched from RhodeCode database
129 129 default_email = "%s@%s" % (username, socket.gethostname())
130 130 admin = getattr(userobj, 'admin', False)
131 131 active = getattr(userobj, 'active', True)
132 132 email = getattr(userobj, 'email', '') or default_email
133 133 username = getattr(userobj, 'username', username)
134 134 firstname = getattr(userobj, 'firstname', '')
135 135 lastname = getattr(userobj, 'lastname', '')
136 136 extern_type = getattr(userobj, 'extern_type', '')
137 137
138 138 user_attrs = {
139 139 'username': username,
140 140 'firstname': firstname,
141 141 'lastname': lastname,
142 142 'groups': [g.gr_name for g in grp.getgrall()
143 143 if username in g.gr_mem],
144 144 'user_group_sync': True,
145 145 'email': email,
146 146 'admin': admin,
147 147 'active': active,
148 148 'active_from_extern': None,
149 149 'extern_name': username,
150 150 'extern_type': extern_type,
151 151 }
152 152
153 153 try:
154 154 user_data = pwd.getpwnam(username)
155 155 regex = settings["gecos"]
156 156 match = re.search(regex, user_data.pw_gecos)
157 157 if match:
158 158 user_attrs["firstname"] = match.group('first_name')
159 159 user_attrs["lastname"] = match.group('last_name')
160 160 except Exception:
161 161 log.warning("Cannot extract additional info for PAM user")
162 162 pass
163 163
164 164 log.debug("pamuser: %s", user_attrs)
165 165 log.info('user `%s` authenticated correctly', user_attrs['username'],
166 166 extra={"action": "user_auth_ok", "auth_module": "auth_pam", "username": user_attrs["username"]})
167 167 return user_attrs
168 168
169 169
170 170 def includeme(config):
171 171 plugin_id = 'egg:rhodecode-enterprise-ce#{}'.format(RhodeCodeAuthPlugin.uid)
172 172 plugin_factory(plugin_id).includeme(config)
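For illustration, a minimal standalone sketch (using a made-up GECOS value) of how
the default gecos regex above splits a Unix userinfo field, mirroring what auth()
does with the result of pwd.getpwnam:

    import re

    GECOS_RE = r'(?P<last_name>.+),\s*(?P<first_name>\w+)'

    sample_gecos = 'Doe, John'  # hypothetical pw_gecos value
    match = re.search(GECOS_RE, sample_gecos)
    if match:
        first_name = match.group('first_name')  # 'John'
        last_name = match.group('last_name')    # 'Doe'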
@@ -1,89 +1,90 b''
1 1 # -*- coding: utf-8 -*-
2 2
3 3 # Copyright (C) 2010-2020 RhodeCode GmbH
4 4 #
5 5 # This program is free software: you can redistribute it and/or modify
6 6 # it under the terms of the GNU Affero General Public License, version 3
7 7 # (only), as published by the Free Software Foundation.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU Affero General Public License
15 15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 16 #
17 17 # This program is dual-licensed. If you wish to learn more about the
18 18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20 20
21 21 import os
22 22 import logging
23 23 import rhodecode
24 import collections
24 25
25 26 from rhodecode.config import utils
26 27
27 28 from rhodecode.lib.utils import load_rcextensions
28 29 from rhodecode.lib.utils2 import str2bool
29 30 from rhodecode.lib.vcs import connect_vcs
30 31
31 32 log = logging.getLogger(__name__)
32 33
33 34
34 35 def load_pyramid_environment(global_config, settings):
35 36 # Some parts of the code expect a merge of global and app settings.
36 37 settings_merged = global_config.copy()
37 38 settings_merged.update(settings)
38 39
39 40 # TODO(marcink): probably not required anymore
40 41 # configure channelstream,
41 42 settings_merged['channelstream_config'] = {
42 43 'enabled': str2bool(settings_merged.get('channelstream.enabled', False)),
43 44 'server': settings_merged.get('channelstream.server'),
44 45 'secret': settings_merged.get('channelstream.secret')
45 46 }
46 47
47 48 # If this is a test run we prepare the test environment like
48 49 # creating a test database, test search index and test repositories.
49 50 # This has to be done before the database connection is initialized.
50 51 if settings['is_test']:
51 52 rhodecode.is_test = True
52 53 rhodecode.disable_error_handler = True
53 54 from rhodecode import authentication
54 55 authentication.plugin_default_auth_ttl = 0
55 56
56 57 utils.initialize_test_environment(settings_merged)
57 58
58 59 # Initialize the database connection.
59 60 utils.initialize_database(settings_merged)
60 61
61 62 load_rcextensions(root_path=settings_merged['here'])
62 63
63 64 # Limit backends to `vcs.backends` from configuration, and preserve the order
64 65 for alias in rhodecode.BACKENDS.keys():
65 66 if alias not in settings['vcs.backends']:
66 67 del rhodecode.BACKENDS[alias]
67 68
68 69 _sorted_backend = sorted(rhodecode.BACKENDS.items(),
69 70 key=lambda item: settings['vcs.backends'].index(item[0]))
70 rhodecode.BACKENDS = rhodecode.OrderedDict(_sorted_backend)
71 rhodecode.BACKENDS = collections.OrderedDict(_sorted_backend)
71 72
72 73 log.info('Enabled VCS backends: %s', rhodecode.BACKENDS.keys())
73 74
74 75 # initialize vcs client and optionally run the server if enabled
75 76 vcs_server_uri = settings['vcs.server']
76 77 vcs_server_enabled = settings['vcs.server.enable']
77 78
78 79 utils.configure_vcs(settings)
79 80
80 81 # Store the settings to make them available to other modules.
81 82
82 83 rhodecode.PYRAMID_SETTINGS = settings_merged
83 84 rhodecode.CONFIG = settings_merged
84 85 rhodecode.CONFIG['default_user_id'] = utils.get_default_user_id()
85 86
86 87 if vcs_server_enabled:
87 88 connect_vcs(vcs_server_uri, utils.get_vcs_server_protocol(settings))
88 89 else:
89 90 log.warning('vcs-server not enabled, vcs connection unavailable')
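A minimal sketch of the backend-ordering step above: keep only the configured
backends, then rebuild the mapping as an OrderedDict sorted by position in the
configured list (the backend values here are placeholder strings, not the real
backend classes):

    import collections

    backends = {'git': 'GitRepo', 'hg': 'HgRepo', 'svn': 'SvnRepo'}  # stand-in for rhodecode.BACKENDS
    enabled = ['hg', 'git']  # stand-in for settings['vcs.backends']; svn disabled

    backends = {k: v for k, v in backends.items() if k in enabled}
    ordered = collections.OrderedDict(
        sorted(backends.items(), key=lambda item: enabled.index(item[0])))
    # list(ordered) -> ['hg', 'git']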
@@ -1,282 +1,280 b''
1 1 # -*- coding: utf-8 -*-
2 2 """
3 3 Adapters
4 4 --------
5 5
6 6 .. contents::
7 7 :backlinks: none
8 8
9 9 The :func:`authomatic.login` function needs access to functionality like
10 10 getting the **URL** of the handler where it is being called, getting the
11 11 **request params** and **cookies** and **writing the body**, **headers**
12 12 and **status** to the response.
13 13
14 14 Since implementation of these features varies across Python web frameworks,
15 15 the Authomatic library uses **adapters** to unify these differences into a
16 16 single interface.
17 17
18 18 Available Adapters
19 19 ^^^^^^^^^^^^^^^^^^
20 20
21 21 If you are missing an adapter for the framework of your choice, please
22 22 open an `enhancement issue <https://github.com/authomatic/authomatic/issues>`_
23 23 or consider a contribution to this module by
24 24 :ref:`implementing <implement_adapters>` one by yourself.
25 25 It's very easy and shouldn't take you more than a few minutes.
26 26
27 27 .. autoclass:: DjangoAdapter
28 28 :members:
29 29
30 30 .. autoclass:: Webapp2Adapter
31 31 :members:
32 32
33 33 .. autoclass:: WebObAdapter
34 34 :members:
35 35
36 36 .. autoclass:: WerkzeugAdapter
37 37 :members:
38 38
39 39 .. _implement_adapters:
40 40
41 41 Implementing an Adapter
42 42 ^^^^^^^^^^^^^^^^^^^^^^^
43 43
44 44 Implementing an adapter for a Python web framework is pretty easy.
45 45
46 46 Do it by subclassing the :class:`.BaseAdapter` abstract class.
47 47 There are only **six** members that you need to implement.
48 48
49 49 Moreover, if your framework is based on the |webob|_ or |werkzeug|_ package
50 50 you can subclass the :class:`.WebObAdapter` or :class:`.WerkzeugAdapter`
51 51 respectively.
52 52
53 53 .. autoclass:: BaseAdapter
54 54 :members:
55 55
56 56 """
57 57
58 58 import abc
59 59 from authomatic.core import Response
60 60
61 61
62 class BaseAdapter(object):
62 class BaseAdapter(object, metaclass=abc.ABCMeta):
63 63 """
64 64 Base class for platform adapters.
65 65
66 66 Defines common interface for WSGI framework specific functionality.
67 67
68 68 """
69 69
70 __metaclass__ = abc.ABCMeta
71
72 70 @abc.abstractproperty
73 71 def params(self):
74 72 """
75 73 Must return a :class:`dict` of all request parameters of any HTTP
76 74 method.
77 75
78 76 :returns:
79 77 :class:`dict`
80 78
81 79 """
82 80
83 81 @abc.abstractproperty
84 82 def url(self):
85 83 """
86 84 Must return the url of the actual request including path but without
87 85 query and fragment.
88 86
89 87 :returns:
90 88 :class:`str`
91 89
92 90 """
93 91
94 92 @abc.abstractproperty
95 93 def cookies(self):
96 94 """
97 95 Must return cookies as a :class:`dict`.
98 96
99 97 :returns:
100 98 :class:`dict`
101 99
102 100 """
103 101
104 102 @abc.abstractmethod
105 103 def write(self, value):
106 104 """
107 105 Must write specified value to response.
108 106
109 107 :param str value:
110 108 String to be written to response.
111 109
112 110 """
113 111
114 112 @abc.abstractmethod
115 113 def set_header(self, key, value):
116 114 """
117 115 Must set response headers to ``Key: value``.
118 116
119 117 :param str key:
120 118 Header name.
121 119
122 120 :param str value:
123 121 Header value.
124 122
125 123 """
126 124
127 125 @abc.abstractmethod
128 126 def set_status(self, status):
129 127 """
130 128 Must set the response status e.g. ``'302 Found'``.
131 129
132 130 :param str status:
133 131 The HTTP response status.
134 132
135 133 """
136 134
137 135
138 136 class DjangoAdapter(BaseAdapter):
139 137 """
140 138 Adapter for the |django|_ framework.
141 139 """
142 140
143 141 def __init__(self, request, response):
144 142 """
145 143 :param request:
146 144 An instance of the :class:`django.http.HttpRequest` class.
147 145
148 146 :param response:
149 147 An instance of the :class:`django.http.HttpResponse` class.
150 148 """
151 149 self.request = request
152 150 self.response = response
153 151
154 152 @property
155 153 def params(self):
156 154 params = {}
157 155 params.update(self.request.GET.dict())
158 156 params.update(self.request.POST.dict())
159 157 return params
160 158
161 159 @property
162 160 def url(self):
163 161 return self.request.build_absolute_uri(self.request.path)
164 162
165 163 @property
166 164 def cookies(self):
167 165 return dict(self.request.COOKIES)
168 166
169 167 def write(self, value):
170 168 self.response.write(value)
171 169
172 170 def set_header(self, key, value):
173 171 self.response[key] = value
174 172
175 173 def set_status(self, status):
176 174 status_code, reason = status.split(' ', 1)
177 175 self.response.status_code = int(status_code)
178 176
179 177
180 178 class WebObAdapter(BaseAdapter):
181 179 """
182 180 Adapter for the |webob|_ package.
183 181 """
184 182
185 183 def __init__(self, request, response):
186 184 """
187 185 :param request:
188 186 A |webob|_ :class:`Request` instance.
189 187
190 188 :param response:
191 189 A |webob|_ :class:`Response` instance.
192 190 """
193 191 self.request = request
194 192 self.response = response
195 193
196 194 # =========================================================================
197 195 # Request
198 196 # =========================================================================
199 197
200 198 @property
201 199 def url(self):
202 200 return self.request.path_url
203 201
204 202 @property
205 203 def params(self):
206 204 return dict(self.request.params)
207 205
208 206 @property
209 207 def cookies(self):
210 208 return dict(self.request.cookies)
211 209
212 210 # =========================================================================
213 211 # Response
214 212 # =========================================================================
215 213
216 214 def write(self, value):
217 215 self.response.write(value)
218 216
219 217 def set_header(self, key, value):
220 218 self.response.headers[key] = str(value)
221 219
222 220 def set_status(self, status):
223 221 self.response.status = status
224 222
225 223
226 224 class Webapp2Adapter(WebObAdapter):
227 225 """
228 226 Adapter for the |webapp2|_ framework.
229 227
230 228 Inherits from the :class:`.WebObAdapter`.
231 229
232 230 """
233 231
234 232 def __init__(self, handler):
235 233 """
236 234 :param handler:
237 235 A :class:`webapp2.RequestHandler` instance.
238 236 """
239 237 self.request = handler.request
240 238 self.response = handler.response
241 239
242 240
243 241 class WerkzeugAdapter(BaseAdapter):
244 242 """
245 243 Adapter for |flask|_ and other |werkzeug|_ based frameworks.
246 244
247 245 Thanks to `Mark Steve Samson <http://marksteve.com>`_.
248 246
249 247 """
250 248
251 249 @property
252 250 def params(self):
253 251 return self.request.args
254 252
255 253 @property
256 254 def url(self):
257 255 return self.request.base_url
258 256
259 257 @property
260 258 def cookies(self):
261 259 return self.request.cookies
262 260
263 261 def __init__(self, request, response):
264 262 """
265 263 :param request:
266 264 Instance of the :class:`werkzeug.wrappers.Request` class.
267 265
268 266 :param response:
269 267 Instance of the :class:`werkzeug.wrappers.Response` class.
270 268 """
271 269
272 270 self.request = request
273 271 self.response = response
274 272
275 273 def write(self, value):
276 274 self.response.data = self.response.data + value
277 275
278 276 def set_header(self, key, value):
279 277 self.response.headers[key] = value
280 278
281 279 def set_status(self, status):
282 280 self.response.status = status
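Following the "Implementing an Adapter" recipe from the module docstring, a hedged
sketch of an adapter for a hypothetical framework whose request/response objects are
plain dicts; the dict layout is invented for illustration, not taken from any real
framework:

    class DictAdapter(BaseAdapter):
        """Hypothetical adapter: request/response are plain dicts."""

        def __init__(self, request, response):
            self.request = request    # e.g. {'params': {}, 'url': '', 'cookies': {}}
            self.response = response  # e.g. {'body': '', 'headers': {}, 'status': ''}

        @property
        def params(self):
            return dict(self.request.get('params', {}))

        @property
        def url(self):
            return self.request.get('url', '')

        @property
        def cookies(self):
            return dict(self.request.get('cookies', {}))

        def write(self, value):
            self.response['body'] += value

        def set_header(self, key, value):
            self.response['headers'][key] = value

        def set_status(self, status):
            self.response['status'] = status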
@@ -1,305 +1,305 b''
1 1 # -*- coding: utf-8 -*-
2 2
3 3 # Copyright (C) 2017-2020 RhodeCode GmbH
4 4 #
5 5 # This program is free software: you can redistribute it and/or modify
6 6 # it under the terms of the GNU Affero General Public License, version 3
7 7 # (only), as published by the Free Software Foundation.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU Affero General Public License
15 15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 16 #
17 17 # This program is dual-licensed. If you wish to learn more about the
18 18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20 20
21 21 import logging
22 22 import datetime
23 23
24 24 from rhodecode.lib.jsonalchemy import JsonRaw
25 25 from rhodecode.model import meta
26 26 from rhodecode.model.db import User, UserLog, Repository
27 27
28 28
29 29 log = logging.getLogger(__name__)
30 30
31 31 # action as key, and expected action_data as value
32 32 ACTIONS_V1 = {
33 33 'user.login.success': {'user_agent': ''},
34 34 'user.login.failure': {'user_agent': ''},
35 35 'user.logout': {'user_agent': ''},
36 36 'user.register': {},
37 37 'user.password.reset_request': {},
38 38 'user.push': {'user_agent': '', 'commit_ids': []},
39 39 'user.pull': {'user_agent': ''},
40 40
41 41 'user.create': {'data': {}},
42 42 'user.delete': {'old_data': {}},
43 43 'user.edit': {'old_data': {}},
44 44 'user.edit.permissions': {},
45 45 'user.edit.ip.add': {'ip': {}, 'user': {}},
46 46 'user.edit.ip.delete': {'ip': {}, 'user': {}},
47 47 'user.edit.token.add': {'token': {}, 'user': {}},
48 48 'user.edit.token.delete': {'token': {}, 'user': {}},
49 49 'user.edit.email.add': {'email': ''},
50 50 'user.edit.email.delete': {'email': ''},
51 51 'user.edit.ssh_key.add': {'token': {}, 'user': {}},
52 52 'user.edit.ssh_key.delete': {'token': {}, 'user': {}},
53 53 'user.edit.password_reset.enabled': {},
54 54 'user.edit.password_reset.disabled': {},
55 55
56 56 'user_group.create': {'data': {}},
57 57 'user_group.delete': {'old_data': {}},
58 58 'user_group.edit': {'old_data': {}},
59 59 'user_group.edit.permissions': {},
60 60 'user_group.edit.member.add': {'user': {}},
61 61 'user_group.edit.member.delete': {'user': {}},
62 62
63 63 'repo.create': {'data': {}},
64 64 'repo.fork': {'data': {}},
65 65 'repo.edit': {'old_data': {}},
66 66 'repo.edit.permissions': {},
67 67 'repo.edit.permissions.branch': {},
68 68 'repo.archive': {'old_data': {}},
69 69 'repo.delete': {'old_data': {}},
70 70
71 71 'repo.archive.download': {'user_agent': '', 'archive_name': '',
72 72 'archive_spec': '', 'archive_cached': ''},
73 73
74 74 'repo.permissions.branch_rule.create': {},
75 75 'repo.permissions.branch_rule.edit': {},
76 76 'repo.permissions.branch_rule.delete': {},
77 77
78 78 'repo.pull_request.create': '',
79 79 'repo.pull_request.edit': '',
80 80 'repo.pull_request.delete': '',
81 81 'repo.pull_request.close': '',
82 82 'repo.pull_request.merge': '',
83 83 'repo.pull_request.vote': '',
84 84 'repo.pull_request.comment.create': '',
85 85 'repo.pull_request.comment.edit': '',
86 86 'repo.pull_request.comment.delete': '',
87 87
88 88 'repo.pull_request.reviewer.add': '',
89 89 'repo.pull_request.reviewer.delete': '',
90 90
91 91 'repo.pull_request.observer.add': '',
92 92 'repo.pull_request.observer.delete': '',
93 93
94 94 'repo.commit.strip': {'commit_id': ''},
95 95 'repo.commit.comment.create': {'data': {}},
96 96 'repo.commit.comment.delete': {'data': {}},
97 97 'repo.commit.comment.edit': {'data': {}},
98 98 'repo.commit.vote': '',
99 99
100 100 'repo.artifact.add': '',
101 101 'repo.artifact.delete': '',
102 102
103 103 'repo_group.create': {'data': {}},
104 104 'repo_group.edit': {'old_data': {}},
105 105 'repo_group.edit.permissions': {},
106 106 'repo_group.delete': {'old_data': {}},
107 107 }
108 108
109 109 ACTIONS = ACTIONS_V1
110 110
111 111 SOURCE_WEB = 'source_web'
112 112 SOURCE_API = 'source_api'
113 113
114 114
115 115 class UserWrap(object):
116 116 """
117 117 Fake object used to imitate AuthUser
118 118 """
119 119
120 120 def __init__(self, user_id=None, username=None, ip_addr=None):
121 121 self.user_id = user_id
122 122 self.username = username
123 123 self.ip_addr = ip_addr
124 124
125 125
126 126 class RepoWrap(object):
127 127 """
128 128 Fake object used to imitate RepoObject that audit logger requires
129 129 """
130 130
131 131 def __init__(self, repo_id=None, repo_name=None):
132 132 self.repo_id = repo_id
133 133 self.repo_name = repo_name
134 134
135 135
136 136 def _store_log(action_name, action_data, user_id, username, user_data,
137 137 ip_address, repository_id, repository_name):
138 138 user_log = UserLog()
139 139 user_log.version = UserLog.VERSION_2
140 140
141 141 user_log.action = action_name
142 user_log.action_data = action_data or JsonRaw(u'{}')
142 user_log.action_data = action_data or JsonRaw('{}')
143 143
144 144 user_log.user_ip = ip_address
145 145
146 146 user_log.user_id = user_id
147 147 user_log.username = username
148 user_log.user_data = user_data or JsonRaw(u'{}')
148 user_log.user_data = user_data or JsonRaw('{}')
149 149
150 150 user_log.repository_id = repository_id
151 151 user_log.repository_name = repository_name
152 152
153 153 user_log.action_date = datetime.datetime.now()
154 154
155 155 return user_log
156 156
157 157
158 158 def store_web(*args, **kwargs):
159 159 action_data = {}
160 160 org_action_data = kwargs.pop('action_data', {})
161 161 action_data.update(org_action_data)
162 162 action_data['source'] = SOURCE_WEB
163 163 kwargs['action_data'] = action_data
164 164
165 165 return store(*args, **kwargs)
166 166
167 167
168 168 def store_api(*args, **kwargs):
169 169 action_data = {}
170 170 org_action_data = kwargs.pop('action_data', {})
171 171 action_data.update(org_action_data)
172 172 action_data['source'] = SOURCE_API
173 173 kwargs['action_data'] = action_data
174 174
175 175 return store(*args, **kwargs)
176 176
177 177
178 178 def store(action, user, action_data=None, user_data=None, ip_addr=None,
179 179 repo=None, sa_session=None, commit=False):
180 180 """
181 181 Audit logger for various actions made by users; typically this
182 182 results in a call such as::
183 183
184 184 from rhodecode.lib import audit_logger
185 185
186 186 audit_logger.store(
187 187 'repo.edit', user=self._rhodecode_user)
188 188 audit_logger.store(
189 189 'repo.delete', action_data={'data': repo_data},
190 190 user=audit_logger.UserWrap(username='itried-login', ip_addr='8.8.8.8'))
191 191
192 192 # repo action
193 193 audit_logger.store(
194 194 'repo.delete',
195 195 user=audit_logger.UserWrap(username='itried-login', ip_addr='8.8.8.8'),
196 196 repo=audit_logger.RepoWrap(repo_name='some-repo'))
197 197
198 198 # repo action, when we know and have the repository object already
199 199 audit_logger.store(
200 200 'repo.delete', action_data={'source': audit_logger.SOURCE_WEB, },
201 201 user=self._rhodecode_user,
202 202 repo=repo_object)
203 203
204 204 # alternative wrapper to the above
205 205 audit_logger.store_web(
206 206 'repo.delete', action_data={},
207 207 user=self._rhodecode_user,
208 208 repo=repo_object)
209 209
210 210 # without a user?
211 211 audit_logger.store(
212 212 'user.login.failure',
213 213 user=audit_logger.UserWrap(
214 214 username=self.request.params.get('username'),
215 215 ip_addr=self.request.remote_addr))
216 216
217 217 """
218 218 from rhodecode.lib.utils2 import safe_unicode
219 219 from rhodecode.lib.auth import AuthUser
220 220
221 221 action_spec = ACTIONS.get(action, None)
222 222 if action_spec is None:
223 223 raise ValueError('Action `{}` is not supported'.format(action))
224 224
225 225 if not sa_session:
226 226 sa_session = meta.Session()
227 227
228 228 try:
229 229 username = getattr(user, 'username', None)
230 230 if not username:
231 231 pass
232 232
233 233 user_id = getattr(user, 'user_id', None)
234 234 if not user_id:
235 235 # maybe we have username ? Try to figure user_id from username
236 236 if username:
237 237 user_id = getattr(
238 238 User.get_by_username(username), 'user_id', None)
239 239
240 240 ip_addr = ip_addr or getattr(user, 'ip_addr', None)
241 241 if not ip_addr:
242 242 pass
243 243
244 244 if not user_data:
245 245 # try to get this from the auth user
246 246 if isinstance(user, AuthUser):
247 247 user_data = {
248 248 'username': user.username,
249 249 'email': user.email,
250 250 }
251 251
252 252 repository_name = getattr(repo, 'repo_name', None)
253 253 repository_id = getattr(repo, 'repo_id', None)
254 254 if not repository_id:
255 255 # maybe we have repo_name ? Try to figure repo_id from repo_name
256 256 if repository_name:
257 257 repository_id = getattr(
258 258 Repository.get_by_repo_name(repository_name), 'repo_id', None)
259 259
260 260 action_name = safe_unicode(action)
261 261 ip_address = safe_unicode(ip_addr)
262 262
263 263 with sa_session.no_autoflush:
264 264
265 265 user_log = _store_log(
266 266 action_name=action_name,
267 267 action_data=action_data or {},
268 268 user_id=user_id,
269 269 username=username,
270 270 user_data=user_data or {},
271 271 ip_address=ip_address,
272 272 repository_id=repository_id,
273 273 repository_name=repository_name
274 274 )
275 275
276 276 sa_session.add(user_log)
277 277 if commit:
278 278 sa_session.commit()
279 279 entry_id = user_log.entry_id or ''
280 280
281 281 update_user_last_activity(sa_session, user_id)
282 282
283 283 if commit:
284 284 sa_session.commit()
285 285
286 286 log.info('AUDIT[%s]: Logging action: `%s` by user:id:%s[%s] ip:%s',
287 287 entry_id, action_name, user_id, username, ip_address,
288 288 extra={"entry_id": entry_id, "action": action_name,
289 289 "user_id": user_id, "ip": ip_address})
290 290
291 291 except Exception:
292 292 log.exception('AUDIT: failed to store audit log')
293 293
294 294
295 295 def update_user_last_activity(sa_session, user_id):
296 296 _last_activity = datetime.datetime.now()
297 297 try:
298 298 sa_session.query(User).filter(User.user_id == user_id).update(
299 299 {"last_activity": _last_activity})
300 300 log.debug(
301 301 'updated user `%s` last activity to:%s', user_id, _last_activity)
302 302 except Exception:
303 303 log.exception("Failed last activity update for user_id: %s", user_id)
304 304 sa_session.rollback()
305 305
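A small usage sketch complementing the docstring examples in store(): recording a
failed login through store_web(), which tags the entry with SOURCE_WEB before
delegating to store(). The username, IP, and user agent are placeholders:

    from rhodecode.lib import audit_logger

    audit_logger.store_web(
        'user.login.failure',
        action_data={'user_agent': 'curl/7.61.0'},  # hypothetical client
        user=audit_logger.UserWrap(username='unknown-user', ip_addr='10.0.0.1'),
        commit=True)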
@@ -1,371 +1,371 b''
1 1 # -*- coding: utf-8 -*-
2 2
3 3 # Copyright (C) 2016-2020 RhodeCode GmbH
4 4 #
5 5 # This program is free software: you can redistribute it and/or modify
6 6 # it under the terms of the GNU Affero General Public License, version 3
7 7 # (only), as published by the Free Software Foundation.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU Affero General Public License
15 15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 16 #
17 17 # This program is dual-licensed. If you wish to learn more about the
18 18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20 20
21 21 import os
22 22 import hashlib
23 23 import itsdangerous
24 24 import logging
25 25 import requests
26 26 import datetime
27 27
28 28 from dogpile.util.readwrite_lock import ReadWriteMutex
29 29 from pyramid.threadlocal import get_current_registry
30 30
31 31 import rhodecode.lib.helpers as h
32 32 from rhodecode.lib.auth import HasRepoPermissionAny
33 33 from rhodecode.lib.ext_json import json
34 34 from rhodecode.model.db import User
35 35
36 36 log = logging.getLogger(__name__)
37 37
38 38 LOCK = ReadWriteMutex()
39 39
40 40 USER_STATE_PUBLIC_KEYS = [
41 41 'id', 'username', 'first_name', 'last_name',
42 42 'icon_link', 'display_name', 'display_link']
43 43
44 44
45 45 class ChannelstreamException(Exception):
46 46 pass
47 47
48 48
49 49 class ChannelstreamConnectionException(ChannelstreamException):
50 50 pass
51 51
52 52
53 53 class ChannelstreamPermissionException(ChannelstreamException):
54 54 pass
55 55
56 56
57 57 def get_channelstream_server_url(config, endpoint):
58 58 return 'http://{}{}'.format(config['server'], endpoint)
59 59
60 60
61 61 def channelstream_request(config, payload, endpoint, raise_exc=True):
62 62 signer = itsdangerous.TimestampSigner(config['secret'])
63 63 sig_for_server = signer.sign(endpoint)
64 64 secret_headers = {'x-channelstream-secret': sig_for_server,
65 65 'x-channelstream-endpoint': endpoint,
66 66 'Content-Type': 'application/json'}
67 67 req_url = get_channelstream_server_url(config, endpoint)
68 68
69 69 log.debug('Sending a channelstream request to endpoint: `%s`', req_url)
70 70 response = None
71 71 try:
72 72 response = requests.post(req_url, data=json.dumps(payload),
73 73 headers=secret_headers).json()
74 74 except requests.ConnectionError:
75 75 log.exception('ConnectionError occurred for endpoint %s', req_url)
76 76 if raise_exc:
77 77 raise ChannelstreamConnectionException(req_url)
78 78 except Exception:
79 79 log.exception('Exception related to Channelstream happened')
80 80 if raise_exc:
81 81 raise ChannelstreamConnectionException()
82 82 log.debug('Got channelstream response: %s', response)
83 83 return response
84 84
85 85
86 86 def get_user_data(user_id):
87 87 user = User.get(user_id)
88 88 return {
89 89 'id': user.user_id,
90 90 'username': user.username,
91 91 'first_name': user.first_name,
92 92 'last_name': user.last_name,
93 93 'icon_link': h.gravatar_url(user.email, 60),
94 94 'display_name': h.person(user, 'username_or_name_or_email'),
95 95 'display_link': h.link_to_user(user),
96 96 'notifications': user.user_data.get('notification_status', True)
97 97 }
98 98
99 99
100 100 def broadcast_validator(channel_name):
101 101 """ checks if user can access the broadcast channel """
102 102 if channel_name == 'broadcast':
103 103 return True
104 104
105 105
106 106 def repo_validator(channel_name):
107 107 """ checks if user can access the broadcast channel """
108 108 channel_prefix = '/repo$'
109 109 if channel_name.startswith(channel_prefix):
110 110 elements = channel_name[len(channel_prefix):].split('$')
111 111 repo_name = elements[0]
112 112 can_access = HasRepoPermissionAny(
113 113 'repository.read',
114 114 'repository.write',
115 115 'repository.admin')(repo_name)
116 116 log.debug(
117 117 'permission check for %s channel resulted in %s',
118 118 repo_name, can_access)
119 119 if can_access:
120 120 return True
121 121 return False
122 122
123 123
124 124 def check_channel_permissions(channels, plugin_validators, should_raise=True):
125 125 valid_channels = []
126 126
127 127 validators = [broadcast_validator, repo_validator]
128 128 if plugin_validators:
129 129 validators.extend(plugin_validators)
130 130 for channel_name in channels:
131 131 is_valid = False
132 132 for validator in validators:
133 133 if validator(channel_name):
134 134 is_valid = True
135 135 break
136 136 if is_valid:
137 137 valid_channels.append(channel_name)
138 138 else:
139 139 if should_raise:
140 140 raise ChannelstreamPermissionException()
141 141 return valid_channels
142 142
143 143
144 144 def get_channels_info(config, channels):
145 145 payload = {'channels': channels}
146 146 # gather persistence info
147 147 return channelstream_request(config, payload, '/info')
148 148
149 149
150 150 def parse_channels_info(info_result, include_channel_info=None):
151 151 """
152 152 Returns data that contains only information that is safe to
153 153 present to clients
154 154 """
155 155 include_channel_info = include_channel_info or []
156 156
157 157 user_state_dict = {}
158 158 for userinfo in info_result['users']:
159 159 user_state_dict[userinfo['user']] = {
160 160 k: v for k, v in userinfo['state'].items()
161 161 if k in USER_STATE_PUBLIC_KEYS
162 162 }
163 163
164 164 channels_info = {}
165 165
166 166 for c_name, c_info in info_result['channels'].items():
167 167 if c_name not in include_channel_info:
168 168 continue
169 169 connected_list = []
170 170 for username in c_info['users']:
171 171 connected_list.append({
172 172 'user': username,
173 173 'state': user_state_dict[username]
174 174 })
175 175 channels_info[c_name] = {'users': connected_list,
176 176 'history': c_info['history']}
177 177
178 178 return channels_info
179 179
180 180
181 181 def log_filepath(history_location, channel_name):
182 182 hasher = hashlib.sha256()
183 183 hasher.update(channel_name.encode('utf8'))
184 184 filename = '{}.log'.format(hasher.hexdigest())
185 185 filepath = os.path.join(history_location, filename)
186 186 return filepath
187 187
188 188
189 189 def read_history(history_location, channel_name):
190 190 filepath = log_filepath(history_location, channel_name)
191 191 if not os.path.exists(filepath):
192 192 return []
193 193 history_lines_limit = -100
194 194 history = []
195 195 with open(filepath, 'rb') as f:
196 196 for line in f.readlines()[history_lines_limit:]:
197 197 try:
198 198 history.append(json.loads(line))
199 199 except Exception:
200 200 log.exception('Failed to load history')
201 201 return history
202 202
203 203
204 204 def update_history_from_logs(config, channels, payload):
205 205 history_location = config.get('history.location')
206 206 for channel in channels:
207 207 history = read_history(history_location, channel)
208 208 payload['channels_info'][channel]['history'] = history
209 209
210 210
211 211 def write_history(config, message):
212 212 """ writes a message to a base64encoded filename """
213 213 history_location = config.get('history.location')
214 214 if not os.path.exists(history_location):
215 215 return
216 216 try:
217 217 LOCK.acquire_write_lock()
218 218 filepath = log_filepath(history_location, message['channel'])
219 219 json_message = json.dumps(message)
220 220 with open(filepath, 'ab') as f:
221 221 f.write(json_message.encode('utf8'))
222 222 f.write(b'\n')
223 223 finally:
224 224 LOCK.release_write_lock()
225 225
226 226
227 227 def get_connection_validators(registry):
228 228 validators = []
229 229 for k, config in registry.rhodecode_plugins.items():
230 230 validator = config.get('channelstream', {}).get('connect_validator')
231 231 if validator:
232 232 validators.append(validator)
233 233 return validators
234 234
235 235
236 236 def get_channelstream_config(registry=None):
237 237 if not registry:
238 238 registry = get_current_registry()
239 239
240 240 rhodecode_plugins = getattr(registry, 'rhodecode_plugins', {})
241 241 channelstream_config = rhodecode_plugins.get('channelstream', {})
242 242 return channelstream_config
243 243
244 244
245 245 def post_message(channel, message, username, registry=None):
246 246 channelstream_config = get_channelstream_config(registry)
247 247 if not channelstream_config.get('enabled'):
248 248 return
249 249
250 250 message_obj = message
251 251 if isinstance(message, str):
252 252 message_obj = {
253 253 'message': message,
254 254 'level': 'success',
255 255 'topic': '/notifications'
256 256 }
257 257
258 258 log.debug('Channelstream: sending notification to channel %s', channel)
259 259 payload = {
260 260 'type': 'message',
261 261 'timestamp': datetime.datetime.utcnow(),
262 262 'user': 'system',
263 263 'exclude_users': [username],
264 264 'channel': channel,
265 265 'message': message_obj
266 266 }
267 267
268 268 try:
269 269 return channelstream_request(
270 270 channelstream_config, [payload], '/message',
271 271 raise_exc=False)
272 272 except ChannelstreamException:
273 273 log.exception('Failed to send channelstream data')
274 274 raise
275 275
276 276
277 277 def _reload_link(label):
278 278 return (
279 279 '<a onclick="window.location.reload()">'
280 280 '<strong>{}</strong>'
281 281 '</a>'.format(label)
282 282 )
283 283
284 284
285 285 def pr_channel(pull_request):
286 286 repo_name = pull_request.target_repo.repo_name
287 287 pull_request_id = pull_request.pull_request_id
288 288 channel = '/repo${}$/pr/{}'.format(repo_name, pull_request_id)
289 289 log.debug('Getting pull-request channelstream broadcast channel: %s', channel)
290 290 return channel
291 291
292 292
293 293 def comment_channel(repo_name, commit_obj=None, pull_request_obj=None):
294 294 channel = None
295 295 if commit_obj:
296 channel = u'/repo${}$/commit/{}'.format(
296 channel = '/repo${}$/commit/{}'.format(
297 297 repo_name, commit_obj.raw_id
298 298 )
299 299 elif pull_request_obj:
300 channel = u'/repo${}$/pr/{}'.format(
300 channel = '/repo${}$/pr/{}'.format(
301 301 repo_name, pull_request_obj.pull_request_id
302 302 )
303 303 log.debug('Getting comment channelstream broadcast channel: %s', channel)
304 304
305 305 return channel
306 306
307 307
308 308 def pr_update_channelstream_push(request, pr_broadcast_channel, user, msg, **kwargs):
309 309 """
310 310 Channel push on pull request update
311 311 """
312 312 if not pr_broadcast_channel:
313 313 return
314 314
315 315 _ = request.translate
316 316
317 317 message = '{} {}'.format(
318 318 msg,
319 319 _reload_link(_(' Reload page to load changes')))
320 320
321 321 message_obj = {
322 322 'message': message,
323 323 'level': 'success',
324 324 'topic': '/notifications'
325 325 }
326 326
327 327 post_message(
328 328 pr_broadcast_channel, message_obj, user.username,
329 329 registry=request.registry)
330 330
331 331
332 332 def comment_channelstream_push(request, comment_broadcast_channel, user, msg, **kwargs):
333 333 """
334 334 Channelstream push on a comment action, on a commit or a pull-request
335 335 """
336 336 if not comment_broadcast_channel:
337 337 return
338 338
339 339 _ = request.translate
340 340
341 341 comment_data = kwargs.pop('comment_data', {})
342 342 user_data = kwargs.pop('user_data', {})
343 343 comment_id = list(comment_data.keys())[0] if comment_data else ''
344 344
345 345 message = '<strong>{}</strong> {} #{}'.format(
346 346 user.username,
347 347 msg,
348 348 comment_id,
349 349 )
350 350
351 351 message_obj = {
352 352 'message': message,
353 353 'level': 'success',
354 354 'topic': '/notifications'
355 355 }
356 356
357 357 post_message(
358 358 comment_broadcast_channel, message_obj, user.username,
359 359 registry=request.registry)
360 360
361 361 message_obj = {
362 362 'message': None,
363 363 'user': user.username,
364 364 'comment_id': comment_id,
365 365 'comment_data': comment_data,
366 366 'user_data': user_data,
367 367 'topic': '/comment'
368 368 }
369 369 post_message(
370 370 comment_broadcast_channel, message_obj, user.username,
371 371 registry=request.registry)
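A hedged sketch of the signing scheme used by channelstream_request() above: the
endpoint path is signed with the shared secret via itsdangerous, and the receiving
side can unsign it within a maximum age. The secret below is a placeholder:

    import itsdangerous

    secret = 'not-a-real-secret'  # placeholder for config['secret']
    endpoint = '/message'

    signer = itsdangerous.TimestampSigner(secret)
    sig_for_server = signer.sign(endpoint)  # sent as the x-channelstream-secret header

    # server-side check; raises if the signature is tampered with or too old
    assert signer.unsign(sig_for_server, max_age=60) == endpoint.encode('utf8')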
@@ -1,797 +1,797 b''
1 1 # -*- coding: utf-8 -*-
2 2
3 3 # Copyright (C) 2011-2020 RhodeCode GmbH
4 4 #
5 5 # This program is free software: you can redistribute it and/or modify
6 6 # it under the terms of the GNU Affero General Public License, version 3
7 7 # (only), as published by the Free Software Foundation.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU Affero General Public License
15 15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 16 #
17 17 # This program is dual-licensed. If you wish to learn more about the
18 18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20 20
21 21 import logging
22 22 import difflib
23 23 from itertools import groupby
24 24
25 25 from pygments import lex
26 26 from pygments.formatters.html import _get_ttype_class as pygment_token_class
27 27 from pygments.lexers.special import TextLexer, Token
28 28 from pygments.lexers import get_lexer_by_name
29 29
30 30 from rhodecode.lib.helpers import (
31 31 get_lexer_for_filenode, html_escape, get_custom_lexer)
32 32 from rhodecode.lib.utils2 import AttributeDict, StrictAttributeDict, safe_unicode
33 33 from rhodecode.lib.vcs.nodes import FileNode
34 34 from rhodecode.lib.vcs.exceptions import VCSError, NodeDoesNotExistError
35 35 from rhodecode.lib.diff_match_patch import diff_match_patch
36 36 from rhodecode.lib.diffs import LimitedDiffContainer, DEL_FILENODE, BIN_FILENODE
37 37
38 38
39 39 plain_text_lexer = get_lexer_by_name(
40 40 'text', stripall=False, stripnl=False, ensurenl=False)
41 41
42 42
43 43 log = logging.getLogger(__name__)
44 44
45 45
46 46 def filenode_as_lines_tokens(filenode, lexer=None):
47 47 org_lexer = lexer
48 48 lexer = lexer or get_lexer_for_filenode(filenode)
49 49 log.debug('Generating file node pygment tokens for %s, %s, org_lexer:%s',
50 50 lexer, filenode, org_lexer)
51 51 content = filenode.content
52 52 tokens = tokenize_string(content, lexer)
53 53 lines = split_token_stream(tokens, content)
54 54 rv = list(lines)
55 55 return rv
56 56
57 57
58 58 def tokenize_string(content, lexer):
59 59 """
60 60 Use pygments to tokenize some content based on a lexer
61 61 ensuring all original newlines and whitespace are preserved
62 62 """
63 63
64 64 lexer.stripall = False
65 65 lexer.stripnl = False
66 66 lexer.ensurenl = False
67 67
68 68 if isinstance(lexer, TextLexer):
69 69 lexed = [(Token.Text, content)]
70 70 else:
71 71 lexed = lex(content, lexer)
72 72
73 73 for token_type, token_text in lexed:
74 74 yield pygment_token_class(token_type), token_text
75 75
76 76
77 77 def split_token_stream(tokens, content):
78 78 """
79 79 Take a stream of (TokenType, text) tuples and regroup it into one
80 80 list of tuples per source line, splitting on newlines, e.g.:
81 81
82 82 split_token_stream([(TEXT, 'some\ntext'), (TEXT, 'more\n')], content)
83 83 yields [(TEXT, 'some')], [(TEXT, 'text'), (TEXT, 'more')], [(TEXT, '')]
84 84 """
85 85
86 86 token_buffer = []
87 87 for token_class, token_text in tokens:
88 88 parts = token_text.split('\n')
89 89 for part in parts[:-1]:
90 90 token_buffer.append((token_class, part))
91 91 yield token_buffer
92 92 token_buffer = []
93 93
94 94 token_buffer.append((token_class, parts[-1]))
95 95
96 96 if token_buffer:
97 97 yield token_buffer
98 98 elif content:
99 99 # this is a special case: we have the content, but tokenization didn't produce
100 100 # any results. This can happen if known file extensions like .css have some bogus
101 101 # unicode content without any newline characters
102 102 yield [(pygment_token_class(Token.Text), content)]
103 103
104 104
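# Illustrative sketch (not part of the original module): a tiny demo of how
# tokenize_string() and split_token_stream() cooperate; the helper name and
# sample content are invented.
def _demo_split_token_stream():
    content = 'x = 1\ny = 2'
    tokens = tokenize_string(content, get_lexer_by_name('python'))
    # one list of (token_class, token_text) tuples per source line
    return list(split_token_stream(tokens, content))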
105 105 def filenode_as_annotated_lines_tokens(filenode):
106 106 """
107 107 Take a file node and return a list of annotations => lines, if no annotation
108 108 is found, it will be None.
109 109
110 110 eg:
111 111
112 112 [
113 113 (annotation1, [
114 114 (1, line1_tokens_list),
115 115 (2, line2_tokens_list),
116 116 ]),
117 117 (annotation2, [
118 118 (3, line1_tokens_list),
119 119 ]),
120 120 (None, [
121 121 (4, line1_tokens_list),
122 122 ]),
123 123 (annotation1, [
124 124 (5, line1_tokens_list),
125 125 (6, line2_tokens_list),
126 126 ])
127 127 ]
128 128 """
129 129
130 130 commit_cache = {} # cache commit_getter lookups
131 131
132 132 def _get_annotation(commit_id, commit_getter):
133 133 if commit_id not in commit_cache:
134 134 commit_cache[commit_id] = commit_getter()
135 135 return commit_cache[commit_id]
136 136
137 137 annotation_lookup = {
138 138 line_no: _get_annotation(commit_id, commit_getter)
139 139 for line_no, commit_id, commit_getter, line_content
140 140 in filenode.annotate
141 141 }
142 142
143 143 annotations_lines = ((annotation_lookup.get(line_no), line_no, tokens)
144 144 for line_no, tokens
145 145 in enumerate(filenode_as_lines_tokens(filenode), 1))
146 146
147 147 grouped_annotations_lines = groupby(annotations_lines, lambda x: x[0])
148 148
149 149 for annotation, group in grouped_annotations_lines:
150 150 yield (
151 151 annotation, [(line_no, tokens)
152 152 for (_, line_no, tokens) in group]
153 153 )
154 154
155 155
156 156 def render_tokenstream(tokenstream):
157 157 result = []
158 158 for token_class, token_ops_texts in rollup_tokenstream(tokenstream):
159 159
160 160 if token_class:
161 result.append(u'<span class="%s">' % token_class)
161 result.append('<span class="%s">' % token_class)
162 162 else:
163 result.append(u'<span>')
163 result.append('<span>')
164 164
165 165 for op_tag, token_text in token_ops_texts:
166 166
167 167 if op_tag:
168 result.append(u'<%s>' % op_tag)
168 result.append('<%s>' % op_tag)
169 169
170 170 # NOTE(marcink): in some cases of mixed encodings, we might run into
171 171 # trouble in html_escape; in that case we force token_text to unicode,
172 172 # which ensures "correct" data even at the cost of some mangled characters
173 173 try:
174 174 escaped_text = html_escape(token_text)
175 175 except TypeError:
176 176 escaped_text = html_escape(safe_unicode(token_text))
177 177
178 178 # TODO: dan: investigate showing hidden characters like space/nl/tab
179 179 # escaped_text = escaped_text.replace(' ', '<sp> </sp>')
180 180 # escaped_text = escaped_text.replace('\n', '<nl>\n</nl>')
181 181 # escaped_text = escaped_text.replace('\t', '<tab>\t</tab>')
182 182
183 183 result.append(escaped_text)
184 184
185 185 if op_tag:
186 result.append(u'</%s>' % op_tag)
186 result.append('</%s>' % op_tag)
187 187
188 result.append(u'</span>')
188 result.append('</span>')
189 189
190 190 html = ''.join(result)
191 191 return html
192 192
193 193
194 194 def rollup_tokenstream(tokenstream):
195 195 """
196 196 Group a token stream of the format:
197 197
198 198 ('class', 'op', 'text')
199 199 or
200 200 ('class', 'text')
201 201
202 202 into
203 203
204 204 [('class1',
205 205 [('op1', 'text'),
206 206 ('op2', 'text')]),
207 207 ('class2',
208 208 [('op3', 'text')])]
209 209
210 210 This is used to get the minimal tags necessary when
211 211 rendering to html, e.g. for a token stream:
212 212
213 213 <span class="A"><ins>he</ins>llo</span>
214 214 vs
215 215 <span class="A"><ins>he</ins></span><span class="A">llo</span>
216 216
217 217 If a 2 tuple is passed in, the output op will be an empty string.
218 218
219 219 eg:
220 220
221 221 >>> rollup_tokenstream([('classA', '', 'h'),
222 222 ('classA', 'del', 'ell'),
223 223 ('classA', '', 'o'),
224 224 ('classB', '', ' '),
225 225 ('classA', '', 'the'),
226 226 ('classA', '', 're'),
227 227 ])
228 228
229 229 [('classA', [('', 'h'), ('del', 'ell'), ('', 'o')],
230 230 ('classB', [('', ' ')],
231 231 ('classA', [('', 'there')]]
232 232
233 233 """
234 234 if tokenstream and len(tokenstream[0]) == 2:
235 235 tokenstream = ((t[0], '', t[1]) for t in tokenstream)
236 236
237 237 result = []
238 238 for token_class, op_list in groupby(tokenstream, lambda t: t[0]):
239 239 ops = []
240 240 for token_op, token_text_list in groupby(op_list, lambda o: o[1]):
241 241 text_buffer = []
242 242 for t_class, t_op, t_text in token_text_list:
243 243 text_buffer.append(t_text)
244 244 ops.append((token_op, ''.join(text_buffer)))
245 245 result.append((token_class, ops))
246 246 return result
247 247
248 248
249 249 def tokens_diff(old_tokens, new_tokens, use_diff_match_patch=True):
250 250 """
251 251 Converts a list of (token_class, token_text) tuples to a list of
252 252 (token_class, token_op, token_text) tuples where token_op is one of
253 253 ('ins', 'del', '')
254 254
255 255 :param old_tokens: list of (token_class, token_text) tuples of old line
256 256 :param new_tokens: list of (token_class, token_text) tuples of new line
257 257 :param use_diff_match_patch: boolean, will use google's diff match patch
258 258 library which has options to 'smooth' out the character by character
259 259 differences making nicer ins/del blocks
260 260 """
261 261
262 262 old_tokens_result = []
263 263 new_tokens_result = []
264 264
265 265 similarity = difflib.SequenceMatcher(None,
266 266 ''.join(token_text for token_class, token_text in old_tokens),
267 267 ''.join(token_text for token_class, token_text in new_tokens)
268 268 ).ratio()
269 269
270 270 if similarity < 0.6: # return, the blocks are too different
271 271 for token_class, token_text in old_tokens:
272 272 old_tokens_result.append((token_class, '', token_text))
273 273 for token_class, token_text in new_tokens:
274 274 new_tokens_result.append((token_class, '', token_text))
275 275 return old_tokens_result, new_tokens_result, similarity
276 276
277 277 token_sequence_matcher = difflib.SequenceMatcher(None,
278 278 [x[1] for x in old_tokens],
279 279 [x[1] for x in new_tokens])
280 280
281 281 for tag, o1, o2, n1, n2 in token_sequence_matcher.get_opcodes():
282 282 # check the differences by token block types first to give a more
283 283 # nicer "block" level replacement vs character diffs
284 284
285 285 if tag == 'equal':
286 286 for token_class, token_text in old_tokens[o1:o2]:
287 287 old_tokens_result.append((token_class, '', token_text))
288 288 for token_class, token_text in new_tokens[n1:n2]:
289 289 new_tokens_result.append((token_class, '', token_text))
290 290 elif tag == 'delete':
291 291 for token_class, token_text in old_tokens[o1:o2]:
292 292 old_tokens_result.append((token_class, 'del', token_text))
293 293 elif tag == 'insert':
294 294 for token_class, token_text in new_tokens[n1:n2]:
295 295 new_tokens_result.append((token_class, 'ins', token_text))
296 296 elif tag == 'replace':
297 297 # if same type token blocks must be replaced, do a diff on the
298 298 # characters in the token blocks to show individual changes
299 299
300 300 old_char_tokens = []
301 301 new_char_tokens = []
302 302 for token_class, token_text in old_tokens[o1:o2]:
303 303 for char in token_text:
304 304 old_char_tokens.append((token_class, char))
305 305
306 306 for token_class, token_text in new_tokens[n1:n2]:
307 307 for char in token_text:
308 308 new_char_tokens.append((token_class, char))
309 309
310 310 old_string = ''.join([token_text for
311 311 token_class, token_text in old_char_tokens])
312 312 new_string = ''.join([token_text for
313 313 token_class, token_text in new_char_tokens])
314 314
315 315 char_sequence = difflib.SequenceMatcher(
316 316 None, old_string, new_string)
317 317 copcodes = char_sequence.get_opcodes()
318 318 obuffer, nbuffer = [], []
319 319
320 320 if use_diff_match_patch:
321 321 dmp = diff_match_patch()
322 322 dmp.Diff_EditCost = 11 # TODO: dan: extract this to a setting
323 323 reps = dmp.diff_main(old_string, new_string)
324 324 dmp.diff_cleanupEfficiency(reps)
325 325
326 326 a, b = 0, 0
327 327 for op, rep in reps:
328 328 l = len(rep)
329 329 if op == 0:
330 330 for i, c in enumerate(rep):
331 331 obuffer.append((old_char_tokens[a+i][0], '', c))
332 332 nbuffer.append((new_char_tokens[b+i][0], '', c))
333 333 a += l
334 334 b += l
335 335 elif op == -1:
336 336 for i, c in enumerate(rep):
337 337 obuffer.append((old_char_tokens[a+i][0], 'del', c))
338 338 a += l
339 339 elif op == 1:
340 340 for i, c in enumerate(rep):
341 341 nbuffer.append((new_char_tokens[b+i][0], 'ins', c))
342 342 b += l
343 343 else:
344 344 for ctag, co1, co2, cn1, cn2 in copcodes:
345 345 if ctag == 'equal':
346 346 for token_class, token_text in old_char_tokens[co1:co2]:
347 347 obuffer.append((token_class, '', token_text))
348 348 for token_class, token_text in new_char_tokens[cn1:cn2]:
349 349 nbuffer.append((token_class, '', token_text))
350 350 elif ctag == 'delete':
351 351 for token_class, token_text in old_char_tokens[co1:co2]:
352 352 obuffer.append((token_class, 'del', token_text))
353 353 elif ctag == 'insert':
354 354 for token_class, token_text in new_char_tokens[cn1:cn2]:
355 355 nbuffer.append((token_class, 'ins', token_text))
356 356 elif ctag == 'replace':
357 357 for token_class, token_text in old_char_tokens[co1:co2]:
358 358 obuffer.append((token_class, 'del', token_text))
359 359 for token_class, token_text in new_char_tokens[cn1:cn2]:
360 360 nbuffer.append((token_class, 'ins', token_text))
361 361
362 362 old_tokens_result.extend(obuffer)
363 363 new_tokens_result.extend(nbuffer)
364 364
365 365 return old_tokens_result, new_tokens_result, similarity
366 366
367 367
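# Illustrative sketch (made-up token lists): tokens_diff() marks per-character
# ins/del ops when two similar lines are compared; the demo name is invented.
def _demo_tokens_diff():
    old = [('k', 'def'), ('', ' foo')]
    new = [('k', 'def'), ('', ' food')]
    old_res, new_res, similarity = tokens_diff(old, new)
    # new_res ends with ('', 'ins', 'd') for the appended character
    return old_res, new_res, similarity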
368 368 def diffset_node_getter(commit):
369 369 def get_node(fname):
370 370 try:
371 371 return commit.get_node(fname)
372 372 except NodeDoesNotExistError:
373 373 return None
374 374
375 375 return get_node
376 376
377 377
378 378 class DiffSet(object):
379 379 """
380 380 An object for parsing the diff result from diffs.DiffProcessor and
381 381 adding highlighting, side by side/unified renderings and line diffs
382 382 """
383 383
384 384 HL_REAL = 'REAL' # highlights using original file, slow
385 385 HL_FAST = 'FAST' # highlights using just the line, fast but not correct
386 386 # in the case of multiline code
387 387 HL_NONE = 'NONE' # no highlighting, fastest
388 388
389 389 def __init__(self, highlight_mode=HL_REAL, repo_name=None,
390 390 source_repo_name=None,
391 391 source_node_getter=lambda filename: None,
392 392 target_repo_name=None,
393 393 target_node_getter=lambda filename: None,
394 394 source_nodes=None, target_nodes=None,
395 395 # files over this size will use fast highlighting
396 396 max_file_size_limit=150 * 1024,
397 397 ):
398 398
399 399 self.highlight_mode = highlight_mode
400 400 self.highlighted_filenodes = {
401 401 'before': {},
402 402 'after': {}
403 403 }
404 404 self.source_node_getter = source_node_getter
405 405 self.target_node_getter = target_node_getter
406 406 self.source_nodes = source_nodes or {}
407 407 self.target_nodes = target_nodes or {}
408 408 self.repo_name = repo_name
409 409 self.target_repo_name = target_repo_name or repo_name
410 410 self.source_repo_name = source_repo_name or repo_name
411 411 self.max_file_size_limit = max_file_size_limit
412 412
413 413 def render_patchset(self, patchset, source_ref=None, target_ref=None):
414 414 diffset = AttributeDict(dict(
415 415 lines_added=0,
416 416 lines_deleted=0,
417 417 changed_files=0,
418 418 files=[],
419 419 file_stats={},
420 420 limited_diff=isinstance(patchset, LimitedDiffContainer),
421 421 repo_name=self.repo_name,
422 422 target_repo_name=self.target_repo_name,
423 423 source_repo_name=self.source_repo_name,
424 424 source_ref=source_ref,
425 425 target_ref=target_ref,
426 426 ))
427 427 for patch in patchset:
428 428 diffset.file_stats[patch['filename']] = patch['stats']
429 429 filediff = self.render_patch(patch)
430 430 filediff.diffset = StrictAttributeDict(dict(
431 431 source_ref=diffset.source_ref,
432 432 target_ref=diffset.target_ref,
433 433 repo_name=diffset.repo_name,
434 434 source_repo_name=diffset.source_repo_name,
435 435 target_repo_name=diffset.target_repo_name,
436 436 ))
437 437 diffset.files.append(filediff)
438 438 diffset.changed_files += 1
439 439 if not patch['stats']['binary']:
440 440 diffset.lines_added += patch['stats']['added']
441 441 diffset.lines_deleted += patch['stats']['deleted']
442 442
443 443 return diffset
444 444
445 445 _lexer_cache = {}
446 446
447 447 def _get_lexer_for_filename(self, filename, filenode=None):
448 448 # cached because we might need to call it twice for source/target
449 449 if filename not in self._lexer_cache:
450 450 if filenode:
451 451 lexer = filenode.lexer
452 452 extension = filenode.extension
453 453 else:
454 454 lexer = FileNode.get_lexer(filename=filename)
455 455 extension = filename.split('.')[-1]
456 456
457 457 lexer = get_custom_lexer(extension) or lexer
458 458 self._lexer_cache[filename] = lexer
459 459 return self._lexer_cache[filename]
460 460
461 461 def render_patch(self, patch):
462 462 log.debug('rendering diff for %r', patch['filename'])
463 463
464 464 source_filename = patch['original_filename']
465 465 target_filename = patch['filename']
466 466
467 467 source_lexer = plain_text_lexer
468 468 target_lexer = plain_text_lexer
469 469
470 470 if not patch['stats']['binary']:
471 471 node_hl_mode = self.HL_NONE if patch['chunks'] == [] else None
472 472 hl_mode = node_hl_mode or self.highlight_mode
473 473
474 474 if hl_mode == self.HL_REAL:
475 475 if (source_filename and patch['operation'] in ('D', 'M')
476 476 and source_filename not in self.source_nodes):
477 477 self.source_nodes[source_filename] = (
478 478 self.source_node_getter(source_filename))
479 479
480 480 if (target_filename and patch['operation'] in ('A', 'M')
481 481 and target_filename not in self.target_nodes):
482 482 self.target_nodes[target_filename] = (
483 483 self.target_node_getter(target_filename))
484 484
485 485 elif hl_mode == self.HL_FAST:
486 486 source_lexer = self._get_lexer_for_filename(source_filename)
487 487 target_lexer = self._get_lexer_for_filename(target_filename)
488 488
489 489 source_file = self.source_nodes.get(source_filename, source_filename)
490 490 target_file = self.target_nodes.get(target_filename, target_filename)
491 491 raw_id_uid = ''
492 492 if self.source_nodes.get(source_filename):
493 493 raw_id_uid = self.source_nodes[source_filename].commit.raw_id
494 494
495 495 if not raw_id_uid and self.target_nodes.get(target_filename):
496 496 # in case this is a new file we only have it in target
497 497 raw_id_uid = self.target_nodes[target_filename].commit.raw_id
498 498
499 499 source_filenode, target_filenode = None, None
500 500
501 501 # TODO: dan: FileNode.lexer works on the content of the file - which
502 502 # can be slow - issue #4289 explains a lexer clean up - which once
503 503 # done can allow caching a lexer for a filenode to avoid the file lookup
504 504 if isinstance(source_file, FileNode):
505 505 source_filenode = source_file
506 506 #source_lexer = source_file.lexer
507 507 source_lexer = self._get_lexer_for_filename(source_filename)
508 508 source_file.lexer = source_lexer
509 509
510 510 if isinstance(target_file, FileNode):
511 511 target_filenode = target_file
512 512 #target_lexer = target_file.lexer
513 513 target_lexer = self._get_lexer_for_filename(target_filename)
514 514 target_file.lexer = target_lexer
515 515
516 516 source_file_path, target_file_path = None, None
517 517
518 518 if source_filename != '/dev/null':
519 519 source_file_path = source_filename
520 520 if target_filename != '/dev/null':
521 521 target_file_path = target_filename
522 522
523 523 source_file_type = source_lexer.name
524 524 target_file_type = target_lexer.name
525 525
526 526 filediff = AttributeDict({
527 527 'source_file_path': source_file_path,
528 528 'target_file_path': target_file_path,
529 529 'source_filenode': source_filenode,
530 530 'target_filenode': target_filenode,
531 531 'source_file_type': source_file_type,
532 532 'target_file_type': target_file_type,
533 533 'patch': {'filename': patch['filename'], 'stats': patch['stats']},
534 534 'operation': patch['operation'],
535 535 'source_mode': patch['stats']['old_mode'],
536 536 'target_mode': patch['stats']['new_mode'],
537 537 'limited_diff': patch['is_limited_diff'],
538 538 'hunks': [],
539 539 'hunk_ops': None,
540 540 'diffset': self,
541 541 'raw_id': raw_id_uid,
542 542 })
543 543
544 544 file_chunks = patch['chunks'][1:]
545 545 for i, hunk in enumerate(file_chunks, 1):
546 546 hunkbit = self.parse_hunk(hunk, source_file, target_file)
547 547 hunkbit.source_file_path = source_file_path
548 548 hunkbit.target_file_path = target_file_path
549 549 hunkbit.index = i
550 550 filediff.hunks.append(hunkbit)
551 551
552 552 # Simulate a hunk for OPS-type lines, which don't really contain any diff;
553 553 # this allows commenting on those operations
554 554 if not file_chunks:
555 555 actions = []
556 556 for op_id, op_text in filediff.patch['stats']['ops'].items():
557 557 if op_id == DEL_FILENODE:
558 actions.append(u'file was removed')
558 actions.append('file was removed')
559 559 elif op_id == BIN_FILENODE:
560 actions.append(u'binary diff hidden')
560 actions.append('binary diff hidden')
561 561 else:
562 562 actions.append(safe_unicode(op_text))
563 action_line = u'NO CONTENT: ' + \
564 u', '.join(actions) or u'UNDEFINED_ACTION'
563 action_line = 'NO CONTENT: ' + \
564 (', '.join(actions) or 'UNDEFINED_ACTION')
565 565
566 566 hunk_ops = {'source_length': 0, 'source_start': 0,
567 567 'lines': [
568 568 {'new_lineno': 0, 'old_lineno': 1,
569 569 'action': 'unmod-no-hl', 'line': action_line}
570 570 ],
571 'section_header': u'', 'target_start': 1, 'target_length': 1}
571 'section_header': '', 'target_start': 1, 'target_length': 1}
572 572
573 573 hunkbit = self.parse_hunk(hunk_ops, source_file, target_file)
574 574 hunkbit.source_file_path = source_file_path
575 575 hunkbit.target_file_path = target_file_path
576 576 filediff.hunk_ops = hunkbit
577 577 return filediff
578 578
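
For orientation, a hypothetical minimal `patch` dict covering only the keys `render_patch` reads above; real patches come from RhodeCode's diff processor, which fills in more:

patch = {
    'filename': 'setup.py',            # target path
    'original_filename': 'setup.py',   # source path
    'operation': 'M',                  # A/M/D, as checked above
    'is_limited_diff': False,
    'stats': {
        'binary': False, 'added': 1, 'deleted': 1,
        'old_mode': '100644', 'new_mode': '100644',
        'ops': {},                     # op_id -> op_text, used by the synthetic hunk
    },
    'chunks': [[]],                    # chunks[0] is ops metadata; real hunks follow
}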
579 579 def parse_hunk(self, hunk, source_file, target_file):
580 580 result = AttributeDict(dict(
581 581 source_start=hunk['source_start'],
582 582 source_length=hunk['source_length'],
583 583 target_start=hunk['target_start'],
584 584 target_length=hunk['target_length'],
585 585 section_header=hunk['section_header'],
586 586 lines=[],
587 587 ))
588 588 before, after = [], []
589 589
590 590 for line in hunk['lines']:
591 591 if line['action'] in ['unmod', 'unmod-no-hl']:
592 592 no_hl = line['action'] == 'unmod-no-hl'
593 593 result.lines.extend(
594 594 self.parse_lines(before, after, source_file, target_file, no_hl=no_hl))
595 595 after.append(line)
596 596 before.append(line)
597 597 elif line['action'] == 'add':
598 598 after.append(line)
599 599 elif line['action'] == 'del':
600 600 before.append(line)
601 601 elif line['action'] == 'old-no-nl':
602 602 before.append(line)
603 603 elif line['action'] == 'new-no-nl':
604 604 after.append(line)
605 605
606 606 all_actions = [x['action'] for x in after] + [x['action'] for x in before]
607 607 no_hl = set(all_actions) == {'unmod-no-hl'}
608 608 result.lines.extend(
609 609 self.parse_lines(before, after, source_file, target_file, no_hl=no_hl))
610 610 # NOTE(marcink): we must keep list() call here so we can cache the result...
611 611 result.unified = list(self.as_unified(result.lines))
612 612 result.sideside = result.lines
613 613
614 614 return result
615 615
616 616 def parse_lines(self, before_lines, after_lines, source_file, target_file,
617 617 no_hl=False):
618 618 # TODO: dan: investigate doing the diff comparison and fast highlighting
619 619 # on the entire buffered block of before/after lines rather than line by
620 620 # line; this would give better 'fast' highlighting when the context
621 621 # allows it - eg.
622 622 # line 4: """
623 623 # line 5: this gets highlighted as a string
624 624 # line 6: """
625 625
626 626 lines = []
627 627
628 628 before_newline = AttributeDict()
629 629 after_newline = AttributeDict()
630 630 if before_lines and before_lines[-1]['action'] == 'old-no-nl':
631 631 before_newline_line = before_lines.pop(-1)
632 632 before_newline.content = '\n {}'.format(
633 633 render_tokenstream(
634 634 [(x[0], '', x[1])
635 635 for x in [('nonl', before_newline_line['line'])]]))
636 636
637 637 if after_lines and after_lines[-1]['action'] == 'new-no-nl':
638 638 after_newline_line = after_lines.pop(-1)
639 639 after_newline.content = '\n {}'.format(
640 640 render_tokenstream(
641 641 [(x[0], '', x[1])
642 642 for x in [('nonl', after_newline_line['line'])]]))
643 643
644 644 while before_lines or after_lines:
645 645 before, after = None, None
646 646 before_tokens, after_tokens = None, None
647 647
648 648 if before_lines:
649 649 before = before_lines.pop(0)
650 650 if after_lines:
651 651 after = after_lines.pop(0)
652 652
653 653 original = AttributeDict()
654 654 modified = AttributeDict()
655 655
656 656 if before:
657 657 if before['action'] == 'old-no-nl':
658 658 before_tokens = [('nonl', before['line'])]
659 659 else:
660 660 before_tokens = self.get_line_tokens(
661 661 line_text=before['line'], line_number=before['old_lineno'],
662 662 input_file=source_file, no_hl=no_hl, source='before')
663 663 original.lineno = before['old_lineno']
664 664 original.content = before['line']
665 665 original.action = self.action_to_op(before['action'])
666 666
667 667 original.get_comment_args = (
668 668 source_file, 'o', before['old_lineno'])
669 669
670 670 if after:
671 671 if after['action'] == 'new-no-nl':
672 672 after_tokens = [('nonl', after['line'])]
673 673 else:
674 674 after_tokens = self.get_line_tokens(
675 675 line_text=after['line'], line_number=after['new_lineno'],
676 676 input_file=target_file, no_hl=no_hl, source='after')
677 677 modified.lineno = after['new_lineno']
678 678 modified.content = after['line']
679 679 modified.action = self.action_to_op(after['action'])
680 680
681 681 modified.get_comment_args = (target_file, 'n', after['new_lineno'])
682 682
683 683 # diff the lines
684 684 if before_tokens and after_tokens:
685 685 o_tokens, m_tokens, similarity = tokens_diff(
686 686 before_tokens, after_tokens)
687 687 original.content = render_tokenstream(o_tokens)
688 688 modified.content = render_tokenstream(m_tokens)
689 689 elif before_tokens:
690 690 original.content = render_tokenstream(
691 691 [(x[0], '', x[1]) for x in before_tokens])
692 692 elif after_tokens:
693 693 modified.content = render_tokenstream(
694 694 [(x[0], '', x[1]) for x in after_tokens])
695 695
696 696 if not before_lines and before_newline:
697 697 original.content += before_newline.content
698 698 before_newline = None
699 699 if not after_lines and after_newline:
700 700 modified.content += after_newline.content
701 701 after_newline = None
702 702
703 703 lines.append(AttributeDict({
704 704 'original': original,
705 705 'modified': modified,
706 706 }))
707 707
708 708 return lines
709 709
710 710 def get_line_tokens(self, line_text, line_number, input_file=None, no_hl=False, source=''):
711 711 filenode = None
712 712 filename = None
713 713
714 714 if isinstance(input_file, str):
715 715 filename = input_file
716 716 elif isinstance(input_file, FileNode):
717 717 filenode = input_file
718 718 filename = input_file.unicode_path
719 719
720 720 hl_mode = self.HL_NONE if no_hl else self.highlight_mode
721 721 if hl_mode == self.HL_REAL and filenode:
722 722 lexer = self._get_lexer_for_filename(filename)
723 723 file_size_allowed = input_file.size < self.max_file_size_limit
724 724 if line_number and file_size_allowed:
725 725 return self.get_tokenized_filenode_line(input_file, line_number, lexer, source)
726 726
727 727 if hl_mode in (self.HL_REAL, self.HL_FAST) and filename:
728 728 lexer = self._get_lexer_for_filename(filename)
729 729 return list(tokenize_string(line_text, lexer))
730 730
731 731 return list(tokenize_string(line_text, plain_text_lexer))
732 732
733 733 def get_tokenized_filenode_line(self, filenode, line_number, lexer=None, source=''):
734 734
735 735 def tokenize(_filenode):
736 736 self.highlighted_filenodes[source][filenode] = filenode_as_lines_tokens(filenode, lexer)
737 737
738 738 if filenode not in self.highlighted_filenodes[source]:
739 739 tokenize(filenode)
740 740
741 741 try:
742 742 return self.highlighted_filenodes[source][filenode][line_number - 1]
743 743 except Exception:
744 744 log.exception('diff rendering error')
745 return [('', u'L{}: rhodecode diff rendering error'.format(line_number))]
745 return [('', 'L{}: rhodecode diff rendering error'.format(line_number))]
746 746
747 747 def action_to_op(self, action):
748 748 return {
749 749 'add': '+',
750 750 'del': '-',
751 751 'unmod': ' ',
752 752 'unmod-no-hl': ' ',
753 753 'old-no-nl': ' ',
754 754 'new-no-nl': ' ',
755 755 }.get(action, action)
756 756
757 757 def as_unified(self, lines):
758 758 """
759 759 Return a generator that yields the lines of a diff in unified order
760 760 """
761 761 def generator():
762 762 buf = []
763 763 for line in lines:
764 764
765 765 if buf and (not line.original or line.original.action == ' '):
766 766 for b in buf:
767 767 yield b
768 768 buf = []
769 769
770 770 if line.original:
771 771 if line.original.action == ' ':
772 772 yield (line.original.lineno, line.modified.lineno,
773 773 line.original.action, line.original.content,
774 774 line.original.get_comment_args)
775 775 continue
776 776
777 777 if line.original.action == '-':
778 778 yield (line.original.lineno, None,
779 779 line.original.action, line.original.content,
780 780 line.original.get_comment_args)
781 781
782 782 if line.modified.action == '+':
783 783 buf.append((
784 784 None, line.modified.lineno,
785 785 line.modified.action, line.modified.content,
786 786 line.modified.get_comment_args))
787 787 continue
788 788
789 789 if line.modified:
790 790 yield (None, line.modified.lineno,
791 791 line.modified.action, line.modified.content,
792 792 line.modified.get_comment_args)
793 793
794 794 for b in buf:
795 795 yield b
796 796
797 797 return generator()
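
A hedged sketch of what the generator does with one side-by-side row (AttributeDict is the dict-with-attribute-access helper used throughout this module; `diffset` stands for the renderer instance, i.e. `self` above). A '-'/'+' pair is emitted as the deletion first, with the addition buffered until the run of removals ends:

row = AttributeDict({
    'original': AttributeDict({'lineno': 1, 'action': '-', 'content': 'old',
                               'get_comment_args': None}),
    'modified': AttributeDict({'lineno': 1, 'action': '+', 'content': 'new',
                               'get_comment_args': None}),
})
for tup in diffset.as_unified([row]):
    print(tup)
# -> (1, None, '-', 'old', None), then (None, 1, '+', 'new', None)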
@@ -1,680 +1,680 b''
1 1 # -*- coding: utf-8 -*-
2 2
3 3 # Copyright (C) 2010-2020 RhodeCode GmbH
4 4 #
5 5 # This program is free software: you can redistribute it and/or modify
6 6 # it under the terms of the GNU Affero General Public License, version 3
7 7 # (only), as published by the Free Software Foundation.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU Affero General Public License
15 15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 16 #
17 17 # This program is dual-licensed. If you wish to learn more about the
18 18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20 20
21 21 """
22 22 Database creation and setup module for RhodeCode Enterprise. Used for creation
23 23 of the database as well as for migration operations
24 24 """
25 25
26 26 import os
27 27 import sys
28 28 import time
29 29 import uuid
30 30 import logging
31 31 import getpass
32 32 from os.path import dirname as dn, join as jn
33 33
34 34 from sqlalchemy.engine import create_engine
35 35
36 36 from rhodecode import __dbversion__
37 37 from rhodecode.model import init_model
38 38 from rhodecode.model.user import UserModel
39 39 from rhodecode.model.db import (
40 40 User, Permission, RhodeCodeUi, RhodeCodeSetting, UserToPerm,
41 41 DbMigrateVersion, RepoGroup, UserRepoGroupToPerm, CacheKey, Repository)
42 42 from rhodecode.model.meta import Session, Base
43 43 from rhodecode.model.permission import PermissionModel
44 44 from rhodecode.model.repo import RepoModel
45 45 from rhodecode.model.repo_group import RepoGroupModel
46 46 from rhodecode.model.settings import SettingsModel
47 47
48 48
49 49 log = logging.getLogger(__name__)
50 50
51 51
52 52 def notify(msg):
53 53 """
54 54 Notification for migration messages
55 55 """
56 56 ml = len(msg) + (4 * 2)
57 57 print(('\n%s\n*** %s ***\n%s' % ('*' * ml, msg, '*' * ml)).upper())
58 58
59 59
60 60 class DbManage(object):
61 61
62 62 def __init__(self, log_sql, dbconf, root, tests=False,
63 63 SESSION=None, cli_args=None):
64 64 self.dbname = dbconf.split('/')[-1]
65 65 self.tests = tests
66 66 self.root = root
67 67 self.dburi = dbconf
68 68 self.log_sql = log_sql
69 69 self.cli_args = cli_args or {}
70 70 self.init_db(SESSION=SESSION)
71 71 self.ask_ok = self.get_ask_ok_func(self.cli_args.get('force_ask'))
72 72
73 73 def db_exists(self):
74 74 if not self.sa:
75 75 self.init_db()
76 76 try:
77 77 self.sa.query(RhodeCodeUi)\
78 78 .filter(RhodeCodeUi.ui_key == '/')\
79 79 .scalar()
80 80 return True
81 81 except Exception:
82 82 return False
83 83 finally:
84 84 self.sa.rollback()
85 85
86 86 def get_ask_ok_func(self, param):
87 87 if param is not None:
88 88 # return a lambda that always returns the given param
89 89 return lambda *args, **kwargs: param
90 90 else:
91 91 from rhodecode.lib.utils import ask_ok
92 92 return ask_ok
93 93
94 94 def init_db(self, SESSION=None):
95 95 if SESSION:
96 96 self.sa = SESSION
97 97 else:
98 98 # init new sessions
99 99 engine = create_engine(self.dburi, echo=self.log_sql)
100 100 init_model(engine)
101 101 self.sa = Session()
102 102
103 103 def create_tables(self, override=False):
104 104 """
105 105 Create an auth database
106 106 """
107 107
108 108 log.info("Existing database with the same name is going to be destroyed.")
109 109 log.info("Setup command will run DROP ALL command on that database.")
110 110 if self.tests:
111 111 destroy = True
112 112 else:
113 113 destroy = self.ask_ok('Are you sure that you want to destroy the old database? [y/n]')
114 114 if not destroy:
115 115 log.info('db tables bootstrap: Nothing done.')
116 116 sys.exit(0)
117 117 if destroy:
118 118 Base.metadata.drop_all()
119 119
120 120 checkfirst = not override
121 121 Base.metadata.create_all(checkfirst=checkfirst)
122 122 log.info('Created tables for %s', self.dbname)
123 123
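
A hedged usage sketch under assumed values; the sqlite URI, root path, and cli_args below are illustrative, not defaults shipped with RhodeCode:

dbmanage = DbManage(
    log_sql=False,
    dbconf='sqlite:///rhodecode_example.db',  # hypothetical database URI
    root='/tmp/rc-example',                   # hypothetical install root
    tests=False,
    cli_args={'force_ask': True},  # makes get_ask_ok_func auto-confirm prompts
)
if not dbmanage.db_exists():
    # with force_ask=True this drops any existing tables before creating
    dbmanage.create_tables(override=False)
    dbmanage.set_db_version()  # defined just below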
124 124 def set_db_version(self):
125 125 ver = DbMigrateVersion()
126 126 ver.version = __dbversion__
127 127 ver.repository_id = 'rhodecode_db_migrations'
128 128 ver.repository_path = 'versions'
129 129 self.sa.add(ver)
130 130 log.info('db version set to: %s', __dbversion__)
131 131
132 132 def run_post_migration_tasks(self):
133 133 """
134 134 Run various tasks after the migrations have actually been performed
135 135 """
136 136 # delete cache keys on each upgrade
137 137 total = CacheKey.query().count()
138 138 log.info("Deleting (%s) cache keys now...", total)
139 139 CacheKey.delete_all_cache()
140 140
141 141 def upgrade(self, version=None):
142 142 """
143 143 Upgrades the given database schema to the given revision,
144 144 following all steps needed to perform the upgrade
145 145
146 146 """
147 147
148 148 from rhodecode.lib.dbmigrate.migrate.versioning import api
149 149 from rhodecode.lib.dbmigrate.migrate.exceptions import \
150 150 DatabaseNotControlledError
151 151
152 152 if 'sqlite' in self.dburi:
153 153 print(
154 154 '********************** WARNING **********************\n'
155 155 'Make sure your version of sqlite is at least 3.7.X. \n'
156 156 'Earlier versions are known to fail on some migrations\n'
157 157 '*****************************************************\n')
158 158
159 159 upgrade = self.ask_ok(
160 160 'You are about to perform a database upgrade. Make '
161 161 'sure you have backed up your database. '
162 162 'Continue ? [y/n]')
163 163 if not upgrade:
164 164 log.info('No upgrade performed')
165 165 sys.exit(0)
166 166
167 167 repository_path = jn(dn(dn(dn(os.path.realpath(__file__)))),
168 168 'rhodecode/lib/dbmigrate')
169 169 db_uri = self.dburi
170 170
171 171 if version:
172 172 DbMigrateVersion.set_version(version)
173 173
174 174 try:
175 175 curr_version = api.db_version(db_uri, repository_path)
176 176 msg = ('Found current database under version '
177 177 'control with version {}'.format(curr_version))
178 178
179 179 except (RuntimeError, DatabaseNotControlledError):
180 180 curr_version = 1
181 181 msg = ('Current database is not under version control. Setting '
182 182 'as version %s' % curr_version)
183 183 api.version_control(db_uri, repository_path, curr_version)
184 184
185 185 notify(msg)
186 186
187 187
188 188 if curr_version == __dbversion__:
189 189 log.info('This database is already at the newest version')
190 190 sys.exit(0)
191 191
192 192 upgrade_steps = range(curr_version + 1, __dbversion__ + 1)
193 193 notify('attempting to upgrade database from '
194 194 'version %s to version %s' % (curr_version, __dbversion__))
195 195
196 196 # CALL THE PROPER ORDER OF STEPS TO PERFORM FULL UPGRADE
197 197 _step = None
198 198 for step in upgrade_steps:
199 199 notify('performing upgrade step %s' % step)
200 200 time.sleep(0.5)
201 201
202 202 api.upgrade(db_uri, repository_path, step)
203 203 self.sa.rollback()
204 204 notify('schema upgrade for step %s completed' % (step,))
205 205
206 206 _step = step
207 207
208 208 self.run_post_migration_tasks()
209 209 notify('upgrade to version %s successful' % _step)
210 210
211 211 def fix_repo_paths(self):
212 212 """
213 213 Fixes an old RhodeCode version path into a new one without a '*'
214 214 """
215 215
216 216 paths = self.sa.query(RhodeCodeUi)\
217 217 .filter(RhodeCodeUi.ui_key == '/')\
218 218 .scalar()
219 219
220 220 paths.ui_value = paths.ui_value.replace('*', '')
221 221
222 222 try:
223 223 self.sa.add(paths)
224 224 self.sa.commit()
225 225 except Exception:
226 226 self.sa.rollback()
227 227 raise
228 228
229 229 def fix_default_user(self):
230 230 """
231 231 Fixes an old default user with some 'nicer' default values,
232 232 used mostly for anonymous access
233 233 """
234 234 def_user = self.sa.query(User)\
235 235 .filter(User.username == User.DEFAULT_USER)\
236 236 .one()
237 237
238 238 def_user.name = 'Anonymous'
239 239 def_user.lastname = 'User'
240 240 def_user.email = User.DEFAULT_USER_EMAIL
241 241
242 242 try:
243 243 self.sa.add(def_user)
244 244 self.sa.commit()
245 245 except Exception:
246 246 self.sa.rollback()
247 247 raise
248 248
249 249 def fix_settings(self):
250 250 """
251 251 Fixes rhodecode settings and adds ga_code key for google analytics
252 252 """
253 253
254 254 hgsettings3 = RhodeCodeSetting('ga_code', '')
255 255
256 256 try:
257 257 self.sa.add(hgsettings3)
258 258 self.sa.commit()
259 259 except Exception:
260 260 self.sa.rollback()
261 261 raise
262 262
263 263 def create_admin_and_prompt(self):
264 264
265 265 # defaults
266 266 defaults = self.cli_args
267 267 username = defaults.get('username')
268 268 password = defaults.get('password')
269 269 email = defaults.get('email')
270 270
271 271 if username is None:
272 272 username = input('Specify admin username:')
273 273 if password is None:
274 274 password = self._get_admin_password()
275 275 if not password:
276 276 # second try
277 277 password = self._get_admin_password()
278 278 if not password:
279 279 sys.exit()
280 280 if email is None:
281 281 email = input('Specify admin email:')
282 282 api_key = self.cli_args.get('api_key')
283 283 self.create_user(username, password, email, True,
284 284 strict_creation_check=False,
285 285 api_key=api_key)
286 286
287 287 def _get_admin_password(self):
288 288 password = getpass.getpass('Specify admin password '
289 289 '(min 6 chars):')
290 290 confirm = getpass.getpass('Confirm password:')
291 291
292 292 if password != confirm:
293 293 log.error('passwords mismatch')
294 294 return False
295 295 if len(password) < 6:
296 296 log.error('password is too short - use at least 6 characters')
297 297 return False
298 298
299 299 return password
300 300
301 301 def create_test_admin_and_users(self):
302 302 log.info('creating admin and regular test users')
303 303 from rhodecode.tests import TEST_USER_ADMIN_LOGIN, \
304 304 TEST_USER_ADMIN_PASS, TEST_USER_ADMIN_EMAIL, \
305 305 TEST_USER_REGULAR_LOGIN, TEST_USER_REGULAR_PASS, \
306 306 TEST_USER_REGULAR_EMAIL, TEST_USER_REGULAR2_LOGIN, \
307 307 TEST_USER_REGULAR2_PASS, TEST_USER_REGULAR2_EMAIL
308 308
309 309 self.create_user(TEST_USER_ADMIN_LOGIN, TEST_USER_ADMIN_PASS,
310 310 TEST_USER_ADMIN_EMAIL, True, api_key=True)
311 311
312 312 self.create_user(TEST_USER_REGULAR_LOGIN, TEST_USER_REGULAR_PASS,
313 313 TEST_USER_REGULAR_EMAIL, False, api_key=True)
314 314
315 315 self.create_user(TEST_USER_REGULAR2_LOGIN, TEST_USER_REGULAR2_PASS,
316 316 TEST_USER_REGULAR2_EMAIL, False, api_key=True)
317 317
318 318 def create_ui_settings(self, repo_store_path):
319 319 """
320 320 Creates ui settings, fills out hooks
321 321 and disables dotencode
322 322 """
323 323 settings_model = SettingsModel(sa=self.sa)
324 324 from rhodecode.lib.vcs.backends.hg import largefiles_store
325 325 from rhodecode.lib.vcs.backends.git import lfs_store
326 326
327 327 # Build HOOKS
328 328 hooks = [
329 329 (RhodeCodeUi.HOOK_REPO_SIZE, 'python:vcsserver.hooks.repo_size'),
330 330
331 331 # HG
332 332 (RhodeCodeUi.HOOK_PRE_PULL, 'python:vcsserver.hooks.pre_pull'),
333 333 (RhodeCodeUi.HOOK_PULL, 'python:vcsserver.hooks.log_pull_action'),
334 334 (RhodeCodeUi.HOOK_PRE_PUSH, 'python:vcsserver.hooks.pre_push'),
335 335 (RhodeCodeUi.HOOK_PRETX_PUSH, 'python:vcsserver.hooks.pre_push'),
336 336 (RhodeCodeUi.HOOK_PUSH, 'python:vcsserver.hooks.log_push_action'),
337 337 (RhodeCodeUi.HOOK_PUSH_KEY, 'python:vcsserver.hooks.key_push'),
338 338
339 339 ]
340 340
341 341 for key, value in hooks:
342 342 hook_obj = settings_model.get_ui_by_key(key)
343 343 hooks2 = hook_obj if hook_obj else RhodeCodeUi()
344 344 hooks2.ui_section = 'hooks'
345 345 hooks2.ui_key = key
346 346 hooks2.ui_value = value
347 347 self.sa.add(hooks2)
348 348
349 349 # enable largefiles
350 350 largefiles = RhodeCodeUi()
351 351 largefiles.ui_section = 'extensions'
352 352 largefiles.ui_key = 'largefiles'
353 353 largefiles.ui_value = ''
354 354 self.sa.add(largefiles)
355 355
356 356 # set default largefiles cache dir, defaults to
357 357 # /repo_store_location/.cache/largefiles
358 358 largefiles = RhodeCodeUi()
359 359 largefiles.ui_section = 'largefiles'
360 360 largefiles.ui_key = 'usercache'
361 361 largefiles.ui_value = largefiles_store(repo_store_path)
362 362
363 363 self.sa.add(largefiles)
364 364
365 365 # set default lfs cache dir, defaults to
366 366 # /repo_store_location/.cache/lfs_store
367 367 lfsstore = RhodeCodeUi()
368 368 lfsstore.ui_section = 'vcs_git_lfs'
369 369 lfsstore.ui_key = 'store_location'
370 370 lfsstore.ui_value = lfs_store(repo_store_path)
371 371
372 372 self.sa.add(lfsstore)
373 373
374 374 # add hgsubversion extension, disabled by default
375 375 hgsubversion = RhodeCodeUi()
376 376 hgsubversion.ui_section = 'extensions'
377 377 hgsubversion.ui_key = 'hgsubversion'
378 378 hgsubversion.ui_value = ''
379 379 hgsubversion.ui_active = False
380 380 self.sa.add(hgsubversion)
381 381
382 382 # add hgevolve extension, disabled by default
383 383 hgevolve = RhodeCodeUi()
384 384 hgevolve.ui_section = 'extensions'
385 385 hgevolve.ui_key = 'evolve'
386 386 hgevolve.ui_value = ''
387 387 hgevolve.ui_active = False
388 388 self.sa.add(hgevolve)
389 389
390 390 hgevolve = RhodeCodeUi()
391 391 hgevolve.ui_section = 'experimental'
392 392 hgevolve.ui_key = 'evolution'
393 393 hgevolve.ui_value = ''
394 394 hgevolve.ui_active = False
395 395 self.sa.add(hgevolve)
396 396
397 397 hgevolve = RhodeCodeUi()
398 398 hgevolve.ui_section = 'experimental'
399 399 hgevolve.ui_key = 'evolution.exchange'
400 400 hgevolve.ui_value = ''
401 401 hgevolve.ui_active = False
402 402 self.sa.add(hgevolve)
403 403
404 404 hgevolve = RhodeCodeUi()
405 405 hgevolve.ui_section = 'extensions'
406 406 hgevolve.ui_key = 'topic'
407 407 hgevolve.ui_value = ''
408 408 hgevolve.ui_active = False
409 409 self.sa.add(hgevolve)
410 410
411 411 # add hggit extension, disabled by default
412 412 hggit = RhodeCodeUi()
413 413 hggit.ui_section = 'extensions'
414 414 hggit.ui_key = 'hggit'
415 415 hggit.ui_value = ''
416 416 hggit.ui_active = False
417 417 self.sa.add(hggit)
418 418
419 419 # set svn branch defaults
420 420 branches = ["/branches/*", "/trunk"]
421 421 tags = ["/tags/*"]
422 422
423 423 for branch in branches:
424 424 settings_model.create_ui_section_value(
425 425 RhodeCodeUi.SVN_BRANCH_ID, branch)
426 426
427 427 for tag in tags:
428 428 settings_model.create_ui_section_value(RhodeCodeUi.SVN_TAG_ID, tag)
429 429
430 430 def create_auth_plugin_options(self, skip_existing=False):
431 431 """
432 432 Create default auth plugin settings, and make them active
433 433
434 434 :param skip_existing:
435 435 """
436 436 defaults = [
437 437 ('auth_plugins',
438 438 'egg:rhodecode-enterprise-ce#token,egg:rhodecode-enterprise-ce#rhodecode',
439 439 'list'),
440 440
441 441 ('auth_authtoken_enabled',
442 442 'True',
443 443 'bool'),
444 444
445 445 ('auth_rhodecode_enabled',
446 446 'True',
447 447 'bool'),
448 448 ]
449 449 for k, v, t in defaults:
450 450 if (skip_existing and
451 451 SettingsModel().get_setting_by_name(k) is not None):
452 452 log.debug('Skipping option %s', k)
453 453 continue
454 454 setting = RhodeCodeSetting(k, v, t)
455 455 self.sa.add(setting)
456 456
457 457 def create_default_options(self, skip_existing=False):
458 458 """Creates default settings"""
459 459
460 460 for k, v, t in [
461 461 ('default_repo_enable_locking', False, 'bool'),
462 462 ('default_repo_enable_downloads', False, 'bool'),
463 463 ('default_repo_enable_statistics', False, 'bool'),
464 464 ('default_repo_private', False, 'bool'),
465 465 ('default_repo_type', 'hg', 'unicode')]:
466 466
467 467 if (skip_existing and
468 468 SettingsModel().get_setting_by_name(k) is not None):
469 469 log.debug('Skipping option %s', k)
470 470 continue
471 471 setting = RhodeCodeSetting(k, v, t)
472 472 self.sa.add(setting)
473 473
474 474 def fixup_groups(self):
475 475 def_usr = User.get_default_user()
476 476 for g in RepoGroup.query().all():
477 477 g.group_name = g.get_new_name(g.name)
478 478 self.sa.add(g)
479 479 # get default perm
480 480 default = UserRepoGroupToPerm.query()\
481 481 .filter(UserRepoGroupToPerm.group == g)\
482 482 .filter(UserRepoGroupToPerm.user == def_usr)\
483 483 .scalar()
484 484
485 485 if default is None:
486 486 log.debug('missing default permission for group %s, adding it', g)
487 487 perm_obj = RepoGroupModel()._create_default_perms(g)
488 488 self.sa.add(perm_obj)
489 489
490 490 def reset_permissions(self, username):
491 491 """
492 492 Resets permissions to the default state; useful when old systems had
493 493 bad permissions that we must clean up
494 494
495 495 :param username:
496 496 """
497 497 default_user = User.get_by_username(username)
498 498 if not default_user:
499 499 return
500 500
501 501 u2p = UserToPerm.query()\
502 502 .filter(UserToPerm.user == default_user).all()
503 503 fixed = False
504 504 if len(u2p) != len(Permission.DEFAULT_USER_PERMISSIONS):
505 505 for p in u2p:
506 506 Session().delete(p)
507 507 fixed = True
508 508 self.populate_default_permissions()
509 509 return fixed
510 510
511 511 def config_prompt(self, test_repo_path='', retries=3):
512 512 defaults = self.cli_args
513 513 _path = defaults.get('repos_location')
514 514 if retries == 3:
515 515 log.info('Setting up repositories config')
516 516
517 517 if _path is not None:
518 518 path = _path
519 519 elif not self.tests and not test_repo_path:
520 520 path = input(
521 521 'Enter a valid absolute path to store repositories. '
522 522 'All repositories in that path will be added automatically:'
523 523 )
524 524 else:
525 525 path = test_repo_path
526 526 path_ok = True
527 527
528 528 # check proper dir
529 529 if not os.path.isdir(path):
530 530 path_ok = False
531 531 log.error('Given path %s is not a valid directory', path)
532 532
533 533 elif not os.path.isabs(path):
534 534 path_ok = False
535 535 log.error('Given path %s is not an absolute path', path)
536 536
537 537 # check if path is at least readable.
538 538 if not os.access(path, os.R_OK):
539 539 path_ok = False
540 540 log.error('Given path %s is not readable', path)
541 541
542 542 # check write access, warn user about non writeable paths
543 543 elif not os.access(path, os.W_OK) and path_ok:
544 544 log.warning('No write permission to given path %s', path)
545 545
546 546 q = ('Given path %s is not writeable, do you want to '
547 547 'continue with read-only mode? [y/n]' % (path,))
548 548 if not self.ask_ok(q):
549 549 log.error('Canceled by user')
550 550 sys.exit(-1)
551 551
552 552 if retries == 0:
553 553 sys.exit('max retries reached')
554 554 if not path_ok:
555 555 retries -= 1
556 556 return self.config_prompt(test_repo_path, retries)
557 557
558 558 real_path = os.path.normpath(os.path.realpath(path))
559 559
560 560 if real_path != os.path.normpath(path):
561 561 q = ('Path looks like a symlink, RhodeCode Enterprise will store '
562 562 'the given path as %s? [y/n]') % (real_path,)
563 563 if not self.ask_ok(q):
564 564 log.error('Canceled by user')
565 565 sys.exit(-1)
566 566
567 567 return real_path
568 568
569 569 def create_settings(self, path):
570 570
571 571 self.create_ui_settings(path)
572 572
573 573 ui_config = [
574 574 ('web', 'push_ssl', 'False'),
575 575 ('web', 'allow_archive', 'gz zip bz2'),
576 576 ('web', 'allow_push', '*'),
577 577 ('web', 'baseurl', '/'),
578 578 ('paths', '/', path),
579 579 ('phases', 'publish', 'True')
580 580 ]
581 581 for section, key, value in ui_config:
582 582 ui_conf = RhodeCodeUi()
583 583 ui_conf.ui_section = section
584 584 ui_conf.ui_key = key
585 585 ui_conf.ui_value = value
586 586 self.sa.add(ui_conf)
587 587
588 588 # rhodecode app settings
589 589 settings = [
590 590 ('realm', 'RhodeCode', 'unicode'),
591 591 ('title', '', 'unicode'),
592 592 ('pre_code', '', 'unicode'),
593 593 ('post_code', '', 'unicode'),
594 594
595 595 # Visual
596 596 ('show_public_icon', True, 'bool'),
597 597 ('show_private_icon', True, 'bool'),
598 598 ('stylify_metatags', True, 'bool'),
599 599 ('dashboard_items', 100, 'int'),
600 600 ('admin_grid_items', 25, 'int'),
601 601
602 602 ('markup_renderer', 'markdown', 'unicode'),
603 603
604 604 ('repository_fields', True, 'bool'),
605 605 ('show_version', True, 'bool'),
606 606 ('show_revision_number', True, 'bool'),
607 607 ('show_sha_length', 12, 'int'),
608 608
609 609 ('use_gravatar', False, 'bool'),
610 610 ('gravatar_url', User.DEFAULT_GRAVATAR_URL, 'unicode'),
611 611
612 612 ('clone_uri_tmpl', Repository.DEFAULT_CLONE_URI, 'unicode'),
613 613 ('clone_uri_id_tmpl', Repository.DEFAULT_CLONE_URI_ID, 'unicode'),
614 614 ('clone_uri_ssh_tmpl', Repository.DEFAULT_CLONE_URI_SSH, 'unicode'),
615 615 ('support_url', '', 'unicode'),
616 616 ('update_url', RhodeCodeSetting.DEFAULT_UPDATE_URL, 'unicode'),
617 617
618 618 # VCS Settings
619 619 ('pr_merge_enabled', True, 'bool'),
620 620 ('use_outdated_comments', True, 'bool'),
621 621 ('diff_cache', True, 'bool'),
622 622 ]
623 623
624 624 for key, val, type_ in settings:
625 625 sett = RhodeCodeSetting(key, val, type_)
626 626 self.sa.add(sett)
627 627
628 628 self.create_auth_plugin_options()
629 629 self.create_default_options()
630 630
631 631 log.info('created ui config')
632 632
633 633 def create_user(self, username, password, email='', admin=False,
634 634 strict_creation_check=True, api_key=None):
635 635 log.info('creating user `%s`', username)
636 636 user = UserModel().create_or_update(
637 username, password, email, firstname=u'RhodeCode', lastname=u'Admin',
637 username, password, email, firstname='RhodeCode', lastname='Admin',
638 638 active=True, admin=admin, extern_type="rhodecode",
639 639 strict_creation_check=strict_creation_check)
640 640
641 641 if api_key:
642 642 log.info('setting a new default auth token for user `%s`', username)
643 643 UserModel().add_auth_token(
644 644 user=user, lifetime_minutes=-1,
645 645 role=UserModel.auth_token_role.ROLE_ALL,
646 description=u'BUILTIN TOKEN')
646 description='BUILTIN TOKEN')
647 647
648 648 def create_default_user(self):
649 649 log.info('creating default user')
650 650 # create default user for handling default permissions.
651 651 user = UserModel().create_or_update(username=User.DEFAULT_USER,
652 652 password=str(uuid.uuid1())[:20],
653 653 email=User.DEFAULT_USER_EMAIL,
654 firstname=u'Anonymous',
655 lastname=u'User',
654 firstname='Anonymous',
655 lastname='User',
656 656 strict_creation_check=False)
657 657 # based on configuration options activate/de-activate this user which
658 # controlls anonymous access
658 # controls anonymous access
659 659 if self.cli_args.get('public_access') is False:
660 660 log.info('Public access disabled')
661 661 user.active = False
662 662 Session().add(user)
663 663 Session().commit()
664 664
665 665 def create_permissions(self):
666 666 """
667 667 Creates all permissions defined in the system
668 668 """
669 669 # module.(access|create|change|delete)_[name]
670 670 # module.(none|read|write|admin)
671 671 log.info('creating permissions')
672 672 PermissionModel(self.sa).create_permissions()
673 673
674 674 def populate_default_permissions(self):
675 675 """
676 676 Populate default permissions. It will create only the default
677 677 permissions that are missing, and not alter already defined ones
678 678 """
679 679 log.info('creating default user permissions')
680 680 PermissionModel(self.sa).create_default_user_permissions(user=User.DEFAULT_USER)
@@ -1,2031 +1,2031 b''
1 1 """Diff Match and Patch
2 2
3 3 Copyright 2006 Google Inc.
4 4 http://code.google.com/p/google-diff-match-patch/
5 5
6 6 Licensed under the Apache License, Version 2.0 (the "License");
7 7 you may not use this file except in compliance with the License.
8 8 You may obtain a copy of the License at
9 9
10 10 http://www.apache.org/licenses/LICENSE-2.0
11 11
12 12 Unless required by applicable law or agreed to in writing, software
13 13 distributed under the License is distributed on an "AS IS" BASIS,
14 14 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 15 See the License for the specific language governing permissions and
16 16 limitations under the License.
17 17 """
18 18
19 19 """Functions for diff, match and patch.
20 20
21 21 Computes the difference between two texts to create a patch.
22 22 Applies the patch onto another text, allowing for errors.
23 23 """
24 24
25 25 __author__ = "fraser@google.com (Neil Fraser)"
26 26
27 27 import math
28 28 import re
29 29 import sys
30 30 import time
31 31 import urllib.request, urllib.parse, urllib.error
32 32
33 33
34 34 class diff_match_patch:
35 35 """Class containing the diff, match and patch methods.
36 36
37 37 Also contains the behaviour settings.
38 38 """
39 39
40 40 def __init__(self):
41 41 """Inits a diff_match_patch object with default settings.
42 42 Redefine these in your program to override the defaults.
43 43 """
44 44
45 45 # Number of seconds to map a diff before giving up (0 for infinity).
46 46 self.Diff_Timeout = 1.0
47 47 # Cost of an empty edit operation in terms of edit characters.
48 48 self.Diff_EditCost = 4
49 49 # At what point is no match declared (0.0 = perfection, 1.0 = very loose).
50 50 self.Match_Threshold = 0.5
51 51 # How far to search for a match (0 = exact location, 1000+ = broad match).
52 52 # A match this many characters away from the expected location will add
53 53 # 1.0 to the score (0.0 is a perfect match).
54 54 self.Match_Distance = 1000
55 55 # When deleting a large block of text (over ~64 characters), how close do
56 56 # the contents have to be to match the expected contents. (0.0 = perfection,
57 57 # 1.0 = very loose). Note that Match_Threshold controls how closely the
58 58 # end points of a delete need to match.
59 59 self.Patch_DeleteThreshold = 0.5
60 60 # Chunk size for context length.
61 61 self.Patch_Margin = 4
62 62
63 63 # The number of bits in an int.
64 64 # Python has no maximum, thus to disable patch splitting set to 0.
65 65 # However to avoid long patches in certain pathological cases, use 32.
66 66 # Multiple short patches (using native ints) are much faster than long ones.
67 67 self.Match_MaxBits = 32
68 68
69 69 # DIFF FUNCTIONS
70 70
71 71 # The data structure representing a diff is an array of tuples:
72 72 # [(DIFF_DELETE, "Hello"), (DIFF_INSERT, "Goodbye"), (DIFF_EQUAL, " world.")]
73 73 # which means: delete "Hello", add "Goodbye" and keep " world."
74 74 DIFF_DELETE = -1
75 75 DIFF_INSERT = 1
76 76 DIFF_EQUAL = 0
77 77
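
As a quick usage sketch of the tuple format described above (a minimal example; the common suffix "ouse" is trimmed before the remaining one-character edit is diffed):

dmp = diff_match_patch()
diffs = dmp.diff_main("mouse", "house")
# -> [(-1, "m"), (1, "h"), (0, "ouse")]
print(diffs)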
78 78 def diff_main(self, text1, text2, checklines=True, deadline=None):
79 79 """Find the differences between two texts. Simplifies the problem by
80 80 stripping any common prefix or suffix off the texts before diffing.
81 81
82 82 Args:
83 83 text1: Old string to be diffed.
84 84 text2: New string to be diffed.
85 85 checklines: Optional speedup flag. If present and false, then don't run
86 86 a line-level diff first to identify the changed areas.
87 87 Defaults to true, which does a faster, slightly less optimal diff.
88 88 deadline: Optional time when the diff should be complete by. Used
89 89 internally for recursive calls. Users should set Diff_Timeout instead.
90 90
91 91 Returns:
92 92 Array of changes.
93 93 """
94 94 # Set a deadline by which time the diff must be complete.
95 95 if deadline is None:
96 96 # Unlike in most languages, Python counts time in seconds.
97 97 if self.Diff_Timeout <= 0:
98 98 deadline = sys.maxsize
99 99 else:
100 100 deadline = time.time() + self.Diff_Timeout
101 101
102 102 # Check for null inputs.
103 103 if text1 is None or text2 is None:
104 104 raise ValueError("Null inputs. (diff_main)")
105 105
106 106 # Check for equality (speedup).
107 107 if text1 == text2:
108 108 if text1:
109 109 return [(self.DIFF_EQUAL, text1)]
110 110 return []
111 111
112 112 # Trim off common prefix (speedup).
113 113 commonlength = self.diff_commonPrefix(text1, text2)
114 114 commonprefix = text1[:commonlength]
115 115 text1 = text1[commonlength:]
116 116 text2 = text2[commonlength:]
117 117
118 118 # Trim off common suffix (speedup).
119 119 commonlength = self.diff_commonSuffix(text1, text2)
120 120 if commonlength == 0:
121 121 commonsuffix = ""
122 122 else:
123 123 commonsuffix = text1[-commonlength:]
124 124 text1 = text1[:-commonlength]
125 125 text2 = text2[:-commonlength]
126 126
127 127 # Compute the diff on the middle block.
128 128 diffs = self.diff_compute(text1, text2, checklines, deadline)
129 129
130 130 # Restore the prefix and suffix.
131 131 if commonprefix:
132 132 diffs[:0] = [(self.DIFF_EQUAL, commonprefix)]
133 133 if commonsuffix:
134 134 diffs.append((self.DIFF_EQUAL, commonsuffix))
135 135 self.diff_cleanupMerge(diffs)
136 136 return diffs
137 137
138 138 def diff_compute(self, text1, text2, checklines, deadline):
139 139 """Find the differences between two texts. Assumes that the texts do not
140 140 have any common prefix or suffix.
141 141
142 142 Args:
143 143 text1: Old string to be diffed.
144 144 text2: New string to be diffed.
145 145 checklines: Speedup flag. If false, then don't run a line-level diff
146 146 first to identify the changed areas.
147 147 If true, then run a faster, slightly less optimal diff.
148 148 deadline: Time when the diff should be complete by.
149 149
150 150 Returns:
151 151 Array of changes.
152 152 """
153 153 if not text1:
154 154 # Just add some text (speedup).
155 155 return [(self.DIFF_INSERT, text2)]
156 156
157 157 if not text2:
158 158 # Just delete some text (speedup).
159 159 return [(self.DIFF_DELETE, text1)]
160 160
161 161 if len(text1) > len(text2):
162 162 (longtext, shorttext) = (text1, text2)
163 163 else:
164 164 (shorttext, longtext) = (text1, text2)
165 165 i = longtext.find(shorttext)
166 166 if i != -1:
167 167 # Shorter text is inside the longer text (speedup).
168 168 diffs = [
169 169 (self.DIFF_INSERT, longtext[:i]),
170 170 (self.DIFF_EQUAL, shorttext),
171 171 (self.DIFF_INSERT, longtext[i + len(shorttext) :]),
172 172 ]
173 173 # Swap insertions for deletions if diff is reversed.
174 174 if len(text1) > len(text2):
175 175 diffs[0] = (self.DIFF_DELETE, diffs[0][1])
176 176 diffs[2] = (self.DIFF_DELETE, diffs[2][1])
177 177 return diffs
178 178
179 179 if len(shorttext) == 1:
180 180 # Single character string.
181 181 # After the previous speedup, the character can't be an equality.
182 182 return [(self.DIFF_DELETE, text1), (self.DIFF_INSERT, text2)]
183 183
184 184 # Check to see if the problem can be split in two.
185 185 hm = self.diff_halfMatch(text1, text2)
186 186 if hm:
187 187 # A half-match was found, sort out the return data.
188 188 (text1_a, text1_b, text2_a, text2_b, mid_common) = hm
189 189 # Send both pairs off for separate processing.
190 190 diffs_a = self.diff_main(text1_a, text2_a, checklines, deadline)
191 191 diffs_b = self.diff_main(text1_b, text2_b, checklines, deadline)
192 192 # Merge the results.
193 193 return diffs_a + [(self.DIFF_EQUAL, mid_common)] + diffs_b
194 194
195 195 if checklines and len(text1) > 100 and len(text2) > 100:
196 196 return self.diff_lineMode(text1, text2, deadline)
197 197
198 198 return self.diff_bisect(text1, text2, deadline)
199 199
200 200 def diff_lineMode(self, text1, text2, deadline):
201 201 """Do a quick line-level diff on both strings, then rediff the parts for
202 202 greater accuracy.
203 203 This speedup can produce non-minimal diffs.
204 204
205 205 Args:
206 206 text1: Old string to be diffed.
207 207 text2: New string to be diffed.
208 208 deadline: Time when the diff should be complete by.
209 209
210 210 Returns:
211 211 Array of changes.
212 212 """
213 213
214 214 # Scan the text on a line-by-line basis first.
215 215 (text1, text2, linearray) = self.diff_linesToChars(text1, text2)
216 216
217 217 diffs = self.diff_main(text1, text2, False, deadline)
218 218
219 219 # Convert the diff back to original text.
220 220 self.diff_charsToLines(diffs, linearray)
221 221 # Eliminate freak matches (e.g. blank lines)
222 222 self.diff_cleanupSemantic(diffs)
223 223
224 224 # Rediff any replacement blocks, this time character-by-character.
225 225 # Add a dummy entry at the end.
226 226 diffs.append((self.DIFF_EQUAL, ""))
227 227 pointer = 0
228 228 count_delete = 0
229 229 count_insert = 0
230 230 text_delete = ""
231 231 text_insert = ""
232 232 while pointer < len(diffs):
233 233 if diffs[pointer][0] == self.DIFF_INSERT:
234 234 count_insert += 1
235 235 text_insert += diffs[pointer][1]
236 236 elif diffs[pointer][0] == self.DIFF_DELETE:
237 237 count_delete += 1
238 238 text_delete += diffs[pointer][1]
239 239 elif diffs[pointer][0] == self.DIFF_EQUAL:
240 240 # Upon reaching an equality, check for prior redundancies.
241 241 if count_delete >= 1 and count_insert >= 1:
242 242 # Delete the offending records and add the merged ones.
243 243 a = self.diff_main(text_delete, text_insert, False, deadline)
244 244 diffs[pointer - count_delete - count_insert : pointer] = a
245 245 pointer = pointer - count_delete - count_insert + len(a)
246 246 count_insert = 0
247 247 count_delete = 0
248 248 text_delete = ""
249 249 text_insert = ""
250 250
251 251 pointer += 1
252 252
253 253 diffs.pop() # Remove the dummy entry at the end.
254 254
255 255 return diffs
256 256
257 257 def diff_bisect(self, text1, text2, deadline):
258 258 """Find the 'middle snake' of a diff, split the problem in two
259 259 and return the recursively constructed diff.
260 260 See Myers 1986 paper: An O(ND) Difference Algorithm and Its Variations.
261 261
262 262 Args:
263 263 text1: Old string to be diffed.
264 264 text2: New string to be diffed.
265 265 deadline: Time at which to bail if not yet complete.
266 266
267 267 Returns:
268 268 Array of diff tuples.
269 269 """
270 270
271 271 # Cache the text lengths to prevent multiple calls.
272 272 text1_length = len(text1)
273 273 text2_length = len(text2)
274 274 max_d = (text1_length + text2_length + 1) // 2
275 275 v_offset = max_d
276 276 v_length = 2 * max_d
277 277 v1 = [-1] * v_length
278 278 v1[v_offset + 1] = 0
279 279 v2 = v1[:]
280 280 delta = text1_length - text2_length
281 281 # If the total number of characters is odd, then the front path will
282 282 # collide with the reverse path.
283 283 front = delta % 2 != 0
284 284 # Offsets for start and end of k loop.
285 285 # Prevents mapping of space beyond the grid.
286 286 k1start = 0
287 287 k1end = 0
288 288 k2start = 0
289 289 k2end = 0
290 290 for d in range(max_d):
291 291 # Bail out if deadline is reached.
292 292 if time.time() > deadline:
293 293 break
294 294
295 295 # Walk the front path one step.
296 296 for k1 in range(-d + k1start, d + 1 - k1end, 2):
297 297 k1_offset = v_offset + k1
298 298 if k1 == -d or (k1 != d and v1[k1_offset - 1] < v1[k1_offset + 1]):
299 299 x1 = v1[k1_offset + 1]
300 300 else:
301 301 x1 = v1[k1_offset - 1] + 1
302 302 y1 = x1 - k1
303 303 while (
304 304 x1 < text1_length and y1 < text2_length and text1[x1] == text2[y1]
305 305 ):
306 306 x1 += 1
307 307 y1 += 1
308 308 v1[k1_offset] = x1
309 309 if x1 > text1_length:
310 310 # Ran off the right of the graph.
311 311 k1end += 2
312 312 elif y1 > text2_length:
313 313 # Ran off the bottom of the graph.
314 314 k1start += 2
315 315 elif front:
316 316 k2_offset = v_offset + delta - k1
317 317 if k2_offset >= 0 and k2_offset < v_length and v2[k2_offset] != -1:
318 318 # Mirror x2 onto top-left coordinate system.
319 319 x2 = text1_length - v2[k2_offset]
320 320 if x1 >= x2:
321 321 # Overlap detected.
322 322 return self.diff_bisectSplit(text1, text2, x1, y1, deadline)
323 323
324 324 # Walk the reverse path one step.
325 325 for k2 in range(-d + k2start, d + 1 - k2end, 2):
326 326 k2_offset = v_offset + k2
327 327 if k2 == -d or (k2 != d and v2[k2_offset - 1] < v2[k2_offset + 1]):
328 328 x2 = v2[k2_offset + 1]
329 329 else:
330 330 x2 = v2[k2_offset - 1] + 1
331 331 y2 = x2 - k2
332 332 while (
333 333 x2 < text1_length
334 334 and y2 < text2_length
335 335 and text1[-x2 - 1] == text2[-y2 - 1]
336 336 ):
337 337 x2 += 1
338 338 y2 += 1
339 339 v2[k2_offset] = x2
340 340 if x2 > text1_length:
341 341 # Ran off the left of the graph.
342 342 k2end += 2
343 343 elif y2 > text2_length:
344 344 # Ran off the top of the graph.
345 345 k2start += 2
346 346 elif not front:
347 347 k1_offset = v_offset + delta - k2
348 348 if k1_offset >= 0 and k1_offset < v_length and v1[k1_offset] != -1:
349 349 x1 = v1[k1_offset]
350 350 y1 = v_offset + x1 - k1_offset
351 351 # Mirror x2 onto top-left coordinate system.
352 352 x2 = text1_length - x2
353 353 if x1 >= x2:
354 354 # Overlap detected.
355 355 return self.diff_bisectSplit(text1, text2, x1, y1, deadline)
356 356
357 357 # Diff took too long and hit the deadline or
358 358 # number of diffs equals number of characters, no commonality at all.
359 359 return [(self.DIFF_DELETE, text1), (self.DIFF_INSERT, text2)]
360 360
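
A hedged sketch calling the bisect step directly (normally diff_main reaches it only after the earlier speedups fail); Myers' middle-snake split yields the minimal diff here:

import sys
dmp = diff_match_patch()
diffs = dmp.diff_bisect("cat", "map", sys.maxsize)  # effectively no deadline
# -> [(-1, "c"), (1, "m"), (0, "a"), (-1, "t"), (1, "p")]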
361 361 def diff_bisectSplit(self, text1, text2, x, y, deadline):
362 362 """Given the location of the 'middle snake', split the diff in two parts
363 363 and recurse.
364 364
365 365 Args:
366 366 text1: Old string to be diffed.
367 367 text2: New string to be diffed.
368 368 x: Index of split point in text1.
369 369 y: Index of split point in text2.
370 370 deadline: Time at which to bail if not yet complete.
371 371
372 372 Returns:
373 373 Array of diff tuples.
374 374 """
375 375 text1a = text1[:x]
376 376 text2a = text2[:y]
377 377 text1b = text1[x:]
378 378 text2b = text2[y:]
379 379
380 380 # Compute both diffs serially.
381 381 diffs = self.diff_main(text1a, text2a, False, deadline)
382 382 diffsb = self.diff_main(text1b, text2b, False, deadline)
383 383
384 384 return diffs + diffsb
385 385
386 386 def diff_linesToChars(self, text1, text2):
387 387 """Split two texts into an array of strings. Reduce the texts to a string
388 388 of hashes where each Unicode character represents one line.
389 389
390 390 Args:
391 391 text1: First string.
392 392 text2: Second string.
393 393
394 394 Returns:
395 395 Three element tuple, containing the encoded text1, the encoded text2 and
396 396 the array of unique strings. The zeroth element of the array of unique
397 397 strings is intentionally blank.
398 398 """
399 399 lineArray = [] # e.g. lineArray[4] == "Hello\n"
400 400 lineHash = {} # e.g. lineHash["Hello\n"] == 4
401 401
402 402 # "\x00" is a valid character, but various debuggers don't like it.
403 403 # So we'll insert a junk entry to avoid generating a null character.
404 404 lineArray.append("")
405 405
406 406 def diff_linesToCharsMunge(text):
407 407 """Split a text into an array of strings. Reduce the texts to a string
408 408 of hashes where each Unicode character represents one line.
409 409 Modifies linearray and linehash through being a closure.
410 410
411 411 Args:
412 412 text: String to encode.
413 413
414 414 Returns:
415 415 Encoded string.
416 416 """
417 417 chars = []
418 418 # Walk the text, pulling out a substring for each line.
419 419 # text.split('\n') would temporarily double our memory footprint.
420 420 # Modifying text would create many large strings to garbage collect.
421 421 lineStart = 0
422 422 lineEnd = -1
423 423 while lineEnd < len(text) - 1:
424 424 lineEnd = text.find("\n", lineStart)
425 425 if lineEnd == -1:
426 426 lineEnd = len(text) - 1
427 427 line = text[lineStart : lineEnd + 1]
428 428 lineStart = lineEnd + 1
429 429
430 430 if line in lineHash:
431 431 chars.append(chr(lineHash[line]))
432 432 else:
433 433 lineArray.append(line)
434 434 lineHash[line] = len(lineArray) - 1
435 435 chars.append(chr(len(lineArray) - 1))
436 436 return "".join(chars)
437 437
438 438 chars1 = diff_linesToCharsMunge(text1)
439 439 chars2 = diff_linesToCharsMunge(text2)
440 440 return (chars1, chars2, lineArray)
441 441
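
A small sketch of the line-to-char encoding; the values follow from the closure above (index 0 of the array is the intentional blank entry):

dmp = diff_match_patch()
chars1, chars2, line_array = dmp.diff_linesToChars("a\nb\na\n", "b\na\n")
# chars1 == "\x01\x02\x01", chars2 == "\x02\x01"
# line_array == ["", "a\n", "b\n"]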
442 442 def diff_charsToLines(self, diffs, lineArray):
443 443 """Rehydrate the text in a diff from a string of line hashes to real lines
444 444 of text.
445 445
446 446 Args:
447 447 diffs: Array of diff tuples.
448 448 lineArray: Array of unique strings.
449 449 """
450 450 for x in range(len(diffs)):
451 451 text = []
452 452 for char in diffs[x][1]:
453 453 text.append(lineArray[ord(char)])
454 454 diffs[x] = (diffs[x][0], "".join(text))
455 455
456 456 def diff_commonPrefix(self, text1, text2):
457 457 """Determine the common prefix of two strings.
458 458
459 459 Args:
460 460 text1: First string.
461 461 text2: Second string.
462 462
463 463 Returns:
464 464 The number of characters common to the start of each string.
465 465 """
466 466 # Quick check for common null cases.
467 467 if not text1 or not text2 or text1[0] != text2[0]:
468 468 return 0
469 469 # Binary search.
470 470 # Performance analysis: http://neil.fraser.name/news/2007/10/09/
471 471 pointermin = 0
472 472 pointermax = min(len(text1), len(text2))
473 473 pointermid = pointermax
474 474 pointerstart = 0
475 475 while pointermin < pointermid:
476 476 if text1[pointerstart:pointermid] == text2[pointerstart:pointermid]:
477 477 pointermin = pointermid
478 478 pointerstart = pointermin
479 479 else:
480 480 pointermax = pointermid
481 481 pointermid = (pointermax - pointermin) // 2 + pointermin
482 482 return pointermid
483 483
484 484 def diff_commonSuffix(self, text1, text2):
485 485 """Determine the common suffix of two strings.
486 486
487 487 Args:
488 488 text1: First string.
489 489 text2: Second string.
490 490
491 491 Returns:
492 492 The number of characters common to the end of each string.
493 493 """
494 494 # Quick check for common null cases.
495 495 if not text1 or not text2 or text1[-1] != text2[-1]:
496 496 return 0
497 497 # Binary search.
498 498 # Performance analysis: http://neil.fraser.name/news/2007/10/09/
499 499 pointermin = 0
500 500 pointermax = min(len(text1), len(text2))
501 501 pointermid = pointermax
502 502 pointerend = 0
503 503 while pointermin < pointermid:
504 504 if (
505 505 text1[-pointermid : len(text1) - pointerend]
506 506 == text2[-pointermid : len(text2) - pointerend]
507 507 ):
508 508 pointermin = pointermid
509 509 pointerend = pointermin
510 510 else:
511 511 pointermax = pointermid
512 512 pointermid = (pointermax - pointermin) // 2 + pointermin
513 513 return pointermid
514 514
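
A quick sketch of both binary-search helpers above:

dmp = diff_match_patch()
dmp.diff_commonPrefix("1234abcdef", "1234xyz")   # -> 4 ("1234")
dmp.diff_commonSuffix("abcdef1234", "xyz1234")   # -> 4 ("1234")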
515 515 def diff_commonOverlap(self, text1, text2):
516 516 """Determine if the suffix of one string is the prefix of another.
517 517
518 518 Args:
519 519 text1: First string.
520 520 text2: Second string.
521 521
522 522 Returns:
523 523 The number of characters common to the end of the first
524 524 string and the start of the second string.
525 525 """
526 526 # Cache the text lengths to prevent multiple calls.
527 527 text1_length = len(text1)
528 528 text2_length = len(text2)
529 529 # Eliminate the null case.
530 530 if text1_length == 0 or text2_length == 0:
531 531 return 0
532 532 # Truncate the longer string.
533 533 if text1_length > text2_length:
534 534 text1 = text1[-text2_length:]
535 535 elif text1_length < text2_length:
536 536 text2 = text2[:text1_length]
537 537 text_length = min(text1_length, text2_length)
538 538 # Quick check for the worst case.
539 539 if text1 == text2:
540 540 return text_length
541 541
542 542 # Start by looking for a single character match
543 543 # and increase length until no match is found.
544 544 # Performance analysis: http://neil.fraser.name/news/2010/11/04/
545 545 best = 0
546 546 length = 1
547 547 while True:
548 548 pattern = text1[-length:]
549 549 found = text2.find(pattern)
550 550 if found == -1:
551 551 return best
552 552 length += found
553 553 if found == 0 or text1[-length:] == text2[:length]:
554 554 best = length
555 555 length += 1
556 556
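
A sketch of the overlap helper: how much of the suffix of text1 is a prefix of text2:

dmp = diff_match_patch()
dmp.diff_commonOverlap("abc", "abcd")     # -> 3 ("abc")
dmp.diff_commonOverlap("123456", "abcd")  # -> 0 (no overlap)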
557 557 def diff_halfMatch(self, text1, text2):
558 558 """Do the two texts share a substring which is at least half the length of
559 559 the longer text?
560 560 This speedup can produce non-minimal diffs.
561 561
562 562 Args:
563 563 text1: First string.
564 564 text2: Second string.
565 565
566 566 Returns:
567 567 Five element Array, containing the prefix of text1, the suffix of text1,
568 568 the prefix of text2, the suffix of text2 and the common middle. Or None
569 569 if there was no match.
570 570 """
571 571 if self.Diff_Timeout <= 0:
572 572 # Don't risk returning a non-optimal diff if we have unlimited time.
573 573 return None
574 574 if len(text1) > len(text2):
575 575 (longtext, shorttext) = (text1, text2)
576 576 else:
577 577 (shorttext, longtext) = (text1, text2)
578 578 if len(longtext) < 4 or len(shorttext) * 2 < len(longtext):
579 579 return None # Pointless.
580 580
581 581 def diff_halfMatchI(longtext, shorttext, i):
582 582 """Does a substring of shorttext exist within longtext such that the
583 583 substring is at least half the length of longtext?
584 584 Closure, but does not reference any external variables.
585 585
586 586 Args:
587 587 longtext: Longer string.
588 588 shorttext: Shorter string.
589 589 i: Start index of quarter length substring within longtext.
590 590
591 591 Returns:
592 592 Five element Array, containing the prefix of longtext, the suffix of
593 593 longtext, the prefix of shorttext, the suffix of shorttext and the
594 594 common middle. Or None if there was no match.
595 595 """
596 596 seed = longtext[i : i + len(longtext) // 4]
597 597 best_common = ""
598 598 j = shorttext.find(seed)
599 599 while j != -1:
600 600 prefixLength = self.diff_commonPrefix(longtext[i:], shorttext[j:])
601 601 suffixLength = self.diff_commonSuffix(longtext[:i], shorttext[:j])
602 602 if len(best_common) < suffixLength + prefixLength:
603 603 best_common = (
604 604 shorttext[j - suffixLength : j]
605 605 + shorttext[j : j + prefixLength]
606 606 )
607 607 best_longtext_a = longtext[: i - suffixLength]
608 608 best_longtext_b = longtext[i + prefixLength :]
609 609 best_shorttext_a = shorttext[: j - suffixLength]
610 610 best_shorttext_b = shorttext[j + prefixLength :]
611 611 j = shorttext.find(seed, j + 1)
612 612
613 613 if len(best_common) * 2 >= len(longtext):
614 614 return (
615 615 best_longtext_a,
616 616 best_longtext_b,
617 617 best_shorttext_a,
618 618 best_shorttext_b,
619 619 best_common,
620 620 )
621 621 else:
622 622 return None
623 623
624 624 # First check if the second quarter is the seed for a half-match.
625 625 hm1 = diff_halfMatchI(longtext, shorttext, (len(longtext) + 3) // 4)
626 626 # Check again based on the third quarter.
627 627 hm2 = diff_halfMatchI(longtext, shorttext, (len(longtext) + 1) // 2)
628 628 if not hm1 and not hm2:
629 629 return None
630 630 elif not hm2:
631 631 hm = hm1
632 632 elif not hm1:
633 633 hm = hm2
634 634 else:
635 635 # Both matched. Select the longest.
636 636 if len(hm1[4]) > len(hm2[4]):
637 637 hm = hm1
638 638 else:
639 639 hm = hm2
640 640
641 641 # A half-match was found, sort out the return data.
642 642 if len(text1) > len(text2):
643 643 (text1_a, text1_b, text2_a, text2_b, mid_common) = hm
644 644 else:
645 645 (text2_a, text2_b, text1_a, text1_b, mid_common) = hm
646 646 return (text1_a, text1_b, text2_a, text2_b, mid_common)
647 647
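# Usage sketch (assuming the default, positive Diff_Timeout, so the speedup
# is allowed to run): "345678" is the common middle.
dmp = diff_match_patch()
dmp.diff_halfMatch("1234567890", "a345678z")
# -> ('12', '90', 'a', 'z', '345678')
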
648 648 def diff_cleanupSemantic(self, diffs):
649 649 """Reduce the number of edits by eliminating semantically trivial
650 650 equalities.
651 651
652 652 Args:
653 653 diffs: Array of diff tuples.
654 654 """
655 655 changes = False
656 656 equalities = [] # Stack of indices where equalities are found.
657 657 lastequality = None # Always equal to diffs[equalities[-1]][1]
658 658 pointer = 0 # Index of current position.
659 659 # Number of chars that changed prior to the equality.
660 660 length_insertions1, length_deletions1 = 0, 0
661 661 # Number of chars that changed after the equality.
662 662 length_insertions2, length_deletions2 = 0, 0
663 663 while pointer < len(diffs):
664 664 if diffs[pointer][0] == self.DIFF_EQUAL: # Equality found.
665 665 equalities.append(pointer)
666 666 length_insertions1, length_insertions2 = length_insertions2, 0
667 667 length_deletions1, length_deletions2 = length_deletions2, 0
668 668 lastequality = diffs[pointer][1]
669 669 else: # An insertion or deletion.
670 670 if diffs[pointer][0] == self.DIFF_INSERT:
671 671 length_insertions2 += len(diffs[pointer][1])
672 672 else:
673 673 length_deletions2 += len(diffs[pointer][1])
674 674 # Eliminate an equality that is smaller or equal to the edits on both
675 675 # sides of it.
676 676 if (
677 677 lastequality
678 678 and (
679 679 len(lastequality) <= max(length_insertions1, length_deletions1)
680 680 )
681 681 and (
682 682 len(lastequality) <= max(length_insertions2, length_deletions2)
683 683 )
684 684 ):
685 685 # Duplicate record.
686 686 diffs.insert(equalities[-1], (self.DIFF_DELETE, lastequality))
687 687 # Change second copy to insert.
688 688 diffs[equalities[-1] + 1] = (
689 689 self.DIFF_INSERT,
690 690 diffs[equalities[-1] + 1][1],
691 691 )
692 692 # Throw away the equality we just deleted.
693 693 equalities.pop()
694 694 # Throw away the previous equality (it needs to be reevaluated).
695 695 if len(equalities):
696 696 equalities.pop()
697 697 if len(equalities):
698 698 pointer = equalities[-1]
699 699 else:
700 700 pointer = -1
701 701 # Reset the counters.
702 702 length_insertions1, length_deletions1 = 0, 0
703 703 length_insertions2, length_deletions2 = 0, 0
704 704 lastequality = None
705 705 changes = True
706 706 pointer += 1
707 707
708 708 # Normalize the diff.
709 709 if changes:
710 710 self.diff_cleanupMerge(diffs)
711 711 self.diff_cleanupSemanticLossless(diffs)
712 712
713 713 # Find any overlaps between deletions and insertions.
714 714 # e.g: <del>abcxxx</del><ins>xxxdef</ins>
715 715 # -> <del>abc</del>xxx<ins>def</ins>
716 716 # e.g: <del>xxxabc</del><ins>defxxx</ins>
717 717 # -> <ins>def</ins>xxx<del>abc</del>
718 718 # Only extract an overlap if it is as big as the edit ahead or behind it.
719 719 pointer = 1
720 720 while pointer < len(diffs):
721 721 if (
722 722 diffs[pointer - 1][0] == self.DIFF_DELETE
723 723 and diffs[pointer][0] == self.DIFF_INSERT
724 724 ):
725 725 deletion = diffs[pointer - 1][1]
726 726 insertion = diffs[pointer][1]
727 727 overlap_length1 = self.diff_commonOverlap(deletion, insertion)
728 728 overlap_length2 = self.diff_commonOverlap(insertion, deletion)
729 729 if overlap_length1 >= overlap_length2:
730 730 if (
731 731 overlap_length1 >= len(deletion) / 2.0
732 732 or overlap_length1 >= len(insertion) / 2.0
733 733 ):
734 734 # Overlap found. Insert an equality and trim the surrounding edits.
735 735 diffs.insert(
736 736 pointer, (self.DIFF_EQUAL, insertion[:overlap_length1])
737 737 )
738 738 diffs[pointer - 1] = (
739 739 self.DIFF_DELETE,
740 740 deletion[: len(deletion) - overlap_length1],
741 741 )
742 742 diffs[pointer + 1] = (
743 743 self.DIFF_INSERT,
744 744 insertion[overlap_length1:],
745 745 )
746 746 pointer += 1
747 747 else:
748 748 if (
749 749 overlap_length2 >= len(deletion) / 2.0
750 750 or overlap_length2 >= len(insertion) / 2.0
751 751 ):
752 752 # Reverse overlap found.
753 753 # Insert an equality and swap and trim the surrounding edits.
754 754 diffs.insert(
755 755 pointer, (self.DIFF_EQUAL, deletion[:overlap_length2])
756 756 )
757 757 diffs[pointer - 1] = (
758 758 self.DIFF_INSERT,
759 759 insertion[: len(insertion) - overlap_length2],
760 760 )
761 761 diffs[pointer + 1] = (
762 762 self.DIFF_DELETE,
763 763 deletion[overlap_length2:],
764 764 )
765 765 pointer += 1
766 766 pointer += 1
767 767 pointer += 1
768 768
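# Usage sketch: the one-char equality "b" is no bigger than the edits around
# it, so it is cheaper to re-edit it than to keep it.
dmp = diff_match_patch()
diffs = [(dmp.DIFF_DELETE, "a"), (dmp.DIFF_EQUAL, "b"), (dmp.DIFF_DELETE, "c")]
dmp.diff_cleanupSemantic(diffs)
assert diffs == [(dmp.DIFF_DELETE, "abc"), (dmp.DIFF_INSERT, "b")]
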
769 769 def diff_cleanupSemanticLossless(self, diffs):
770 770 """Look for single edits surrounded on both sides by equalities
771 771 which can be shifted sideways to align the edit to a word boundary.
772 772 e.g: The c<ins>at c</ins>ame. -> The <ins>cat </ins>came.
773 773
774 774 Args:
775 775 diffs: Array of diff tuples.
776 776 """
777 777
778 778 def diff_cleanupSemanticScore(one, two):
779 779 """Given two strings, compute a score representing whether the
780 780 internal boundary falls on logical boundaries.
781 781 Scores range from 6 (best) to 0 (worst).
782 782 Closure, but does not reference any external variables.
783 783
784 784 Args:
785 785 one: First string.
786 786 two: Second string.
787 787
788 788 Returns:
789 789 The score.
790 790 """
791 791 if not one or not two:
792 792 # Edges are the best.
793 793 return 6
794 794
795 795 # Each port of this function behaves slightly differently due to
796 796 # subtle differences in each language's definition of things like
797 797 # 'whitespace'. Since this function's purpose is largely cosmetic,
798 798 # the choice has been made to use each language's native features
799 799 # rather than force total conformity.
800 800 char1 = one[-1]
801 801 char2 = two[0]
802 802 nonAlphaNumeric1 = not char1.isalnum()
803 803 nonAlphaNumeric2 = not char2.isalnum()
804 804 whitespace1 = nonAlphaNumeric1 and char1.isspace()
805 805 whitespace2 = nonAlphaNumeric2 and char2.isspace()
806 806 lineBreak1 = whitespace1 and (char1 == "\r" or char1 == "\n")
807 807 lineBreak2 = whitespace2 and (char2 == "\r" or char2 == "\n")
808 808 blankLine1 = lineBreak1 and self.BLANKLINEEND.search(one)
809 809 blankLine2 = lineBreak2 and self.BLANKLINESTART.match(two)
810 810
811 811 if blankLine1 or blankLine2:
812 812 # Five points for blank lines.
813 813 return 5
814 814 elif lineBreak1 or lineBreak2:
815 815 # Four points for line breaks.
816 816 return 4
817 817 elif nonAlphaNumeric1 and not whitespace1 and whitespace2:
818 818 # Three points for end of sentences.
819 819 return 3
820 820 elif whitespace1 or whitespace2:
821 821 # Two points for whitespace.
822 822 return 2
823 823 elif nonAlphaNumeric1 or nonAlphaNumeric2:
824 824 # One point for non-alphanumeric.
825 825 return 1
826 826 return 0
827 827
828 828 pointer = 1
829 829 # Intentionally ignore the first and last element (don't need checking).
830 830 while pointer < len(diffs) - 1:
831 831 if (
832 832 diffs[pointer - 1][0] == self.DIFF_EQUAL
833 833 and diffs[pointer + 1][0] == self.DIFF_EQUAL
834 834 ):
835 835 # This is a single edit surrounded by equalities.
836 836 equality1 = diffs[pointer - 1][1]
837 837 edit = diffs[pointer][1]
838 838 equality2 = diffs[pointer + 1][1]
839 839
840 840 # First, shift the edit as far left as possible.
841 841 commonOffset = self.diff_commonSuffix(equality1, edit)
842 842 if commonOffset:
843 843 commonString = edit[-commonOffset:]
844 844 equality1 = equality1[:-commonOffset]
845 845 edit = commonString + edit[:-commonOffset]
846 846 equality2 = commonString + equality2
847 847
848 848 # Second, step character by character right, looking for the best fit.
849 849 bestEquality1 = equality1
850 850 bestEdit = edit
851 851 bestEquality2 = equality2
852 852 bestScore = diff_cleanupSemanticScore(
853 853 equality1, edit
854 854 ) + diff_cleanupSemanticScore(edit, equality2)
855 855 while edit and equality2 and edit[0] == equality2[0]:
856 856 equality1 += edit[0]
857 857 edit = edit[1:] + equality2[0]
858 858 equality2 = equality2[1:]
859 859 score = diff_cleanupSemanticScore(
860 860 equality1, edit
861 861 ) + diff_cleanupSemanticScore(edit, equality2)
862 862 # The >= encourages trailing rather than leading whitespace on edits.
863 863 if score >= bestScore:
864 864 bestScore = score
865 865 bestEquality1 = equality1
866 866 bestEdit = edit
867 867 bestEquality2 = equality2
868 868
869 869 if diffs[pointer - 1][1] != bestEquality1:
870 870 # We have an improvement, save it back to the diff.
871 871 if bestEquality1:
872 872 diffs[pointer - 1] = (diffs[pointer - 1][0], bestEquality1)
873 873 else:
874 874 del diffs[pointer - 1]
875 875 pointer -= 1
876 876 diffs[pointer] = (diffs[pointer][0], bestEdit)
877 877 if bestEquality2:
878 878 diffs[pointer + 1] = (diffs[pointer + 1][0], bestEquality2)
879 879 else:
880 880 del diffs[pointer + 1]
881 881 pointer -= 1
882 882 pointer += 1
883 883
884 884 # Define some regex patterns for matching boundaries.
885 885 BLANKLINEEND = re.compile(r"\n\r?\n$")
886 886 BLANKLINESTART = re.compile(r"^\r?\n\r?\n")
887 887
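# Usage sketch, mirroring the docstring example above: the insertion is slid
# rightwards until it sits on a word boundary.
dmp = diff_match_patch()
diffs = [(dmp.DIFF_EQUAL, "The c"), (dmp.DIFF_INSERT, "at c"), (dmp.DIFF_EQUAL, "ame.")]
dmp.diff_cleanupSemanticLossless(diffs)
assert diffs == [(dmp.DIFF_EQUAL, "The "), (dmp.DIFF_INSERT, "cat "), (dmp.DIFF_EQUAL, "came.")]
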
888 888 def diff_cleanupEfficiency(self, diffs):
889 889 """Reduce the number of edits by eliminating operationally trivial
890 890 equalities.
891 891
892 892 Args:
893 893 diffs: Array of diff tuples.
894 894 """
895 895 changes = False
896 896 equalities = [] # Stack of indices where equalities are found.
897 897 lastequality = None # Always equal to diffs[equalities[-1]][1]
898 898 pointer = 0 # Index of current position.
899 899 pre_ins = False # Is there an insertion operation before the last equality.
900 900 pre_del = False # Is there a deletion operation before the last equality.
901 901 post_ins = False # Is there an insertion operation after the last equality.
902 902 post_del = False # Is there a deletion operation after the last equality.
903 903 while pointer < len(diffs):
904 904 if diffs[pointer][0] == self.DIFF_EQUAL: # Equality found.
905 905 if len(diffs[pointer][1]) < self.Diff_EditCost and (
906 906 post_ins or post_del
907 907 ):
908 908 # Candidate found.
909 909 equalities.append(pointer)
910 910 pre_ins = post_ins
911 911 pre_del = post_del
912 912 lastequality = diffs[pointer][1]
913 913 else:
914 914 # Not a candidate, and can never become one.
915 915 equalities = []
916 916 lastequality = None
917 917
918 918 post_ins = post_del = False
919 919 else: # An insertion or deletion.
920 920 if diffs[pointer][0] == self.DIFF_DELETE:
921 921 post_del = True
922 922 else:
923 923 post_ins = True
924 924
925 925 # Five types to be split:
926 926 # <ins>A</ins><del>B</del>XY<ins>C</ins><del>D</del>
927 927 # <ins>A</ins>X<ins>C</ins><del>D</del>
928 928 # <ins>A</ins><del>B</del>X<ins>C</ins>
929 929 # <ins>A</ins>X<ins>C</ins><del>D</del>
930 930 # <ins>A</ins><del>B</del>X<del>C</del>
931 931
932 932 if lastequality and (
933 933 (pre_ins and pre_del and post_ins and post_del)
934 934 or (
935 935 (len(lastequality) < self.Diff_EditCost / 2)
936 936 and (pre_ins + pre_del + post_ins + post_del) == 3
937 937 )
938 938 ):
939 939 # Duplicate record.
940 940 diffs.insert(equalities[-1], (self.DIFF_DELETE, lastequality))
941 941 # Change second copy to insert.
942 942 diffs[equalities[-1] + 1] = (
943 943 self.DIFF_INSERT,
944 944 diffs[equalities[-1] + 1][1],
945 945 )
946 946 equalities.pop() # Throw away the equality we just deleted.
947 947 lastequality = None
948 948 if pre_ins and pre_del:
949 949 # No changes made which could affect previous entry, keep going.
950 950 post_ins = post_del = True
951 951 equalities = []
952 952 else:
953 953 if len(equalities):
954 954 equalities.pop() # Throw away the previous equality.
955 955 if len(equalities):
956 956 pointer = equalities[-1]
957 957 else:
958 958 pointer = -1
959 959 post_ins = post_del = False
960 960 changes = True
961 961 pointer += 1
962 962
963 963 if changes:
964 964 self.diff_cleanupMerge(diffs)
965 965
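# Usage sketch (assuming the default Diff_EditCost of 4): keeping the short
# equality "xyz" costs more than folding it into the surrounding edits.
dmp = diff_match_patch()
diffs = [(dmp.DIFF_DELETE, "ab"), (dmp.DIFF_INSERT, "12"), (dmp.DIFF_EQUAL, "xyz"),
         (dmp.DIFF_DELETE, "cd"), (dmp.DIFF_INSERT, "34")]
dmp.diff_cleanupEfficiency(diffs)
assert diffs == [(dmp.DIFF_DELETE, "abxyzcd"), (dmp.DIFF_INSERT, "12xyz34")]
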
966 966 def diff_cleanupMerge(self, diffs):
967 967 """Reorder and merge like edit sections. Merge equalities.
968 968 Any edit section can move as long as it doesn't cross an equality.
969 969
970 970 Args:
971 971 diffs: Array of diff tuples.
972 972 """
973 973 diffs.append((self.DIFF_EQUAL, "")) # Add a dummy entry at the end.
974 974 pointer = 0
975 975 count_delete = 0
976 976 count_insert = 0
977 977 text_delete = ""
978 978 text_insert = ""
979 979 while pointer < len(diffs):
980 980 if diffs[pointer][0] == self.DIFF_INSERT:
981 981 count_insert += 1
982 982 text_insert += diffs[pointer][1]
983 983 pointer += 1
984 984 elif diffs[pointer][0] == self.DIFF_DELETE:
985 985 count_delete += 1
986 986 text_delete += diffs[pointer][1]
987 987 pointer += 1
988 988 elif diffs[pointer][0] == self.DIFF_EQUAL:
989 989 # Upon reaching an equality, check for prior redundancies.
990 990 if count_delete + count_insert > 1:
991 991 if count_delete != 0 and count_insert != 0:
992 992 # Factor out any common prefixes.
993 993 commonlength = self.diff_commonPrefix(text_insert, text_delete)
994 994 if commonlength != 0:
995 995 x = pointer - count_delete - count_insert - 1
996 996 if x >= 0 and diffs[x][0] == self.DIFF_EQUAL:
997 997 diffs[x] = (
998 998 diffs[x][0],
999 999 diffs[x][1] + text_insert[:commonlength],
1000 1000 )
1001 1001 else:
1002 1002 diffs.insert(
1003 1003 0, (self.DIFF_EQUAL, text_insert[:commonlength])
1004 1004 )
1005 1005 pointer += 1
1006 1006 text_insert = text_insert[commonlength:]
1007 1007 text_delete = text_delete[commonlength:]
1008 1008 # Factor out any common suffixes.
1009 1009 commonlength = self.diff_commonSuffix(text_insert, text_delete)
1010 1010 if commonlength != 0:
1011 1011 diffs[pointer] = (
1012 1012 diffs[pointer][0],
1013 1013 text_insert[-commonlength:] + diffs[pointer][1],
1014 1014 )
1015 1015 text_insert = text_insert[:-commonlength]
1016 1016 text_delete = text_delete[:-commonlength]
1017 1017 # Delete the offending records and add the merged ones.
1018 1018 if count_delete == 0:
1019 1019 diffs[pointer - count_insert : pointer] = [
1020 1020 (self.DIFF_INSERT, text_insert)
1021 1021 ]
1022 1022 elif count_insert == 0:
1023 1023 diffs[pointer - count_delete : pointer] = [
1024 1024 (self.DIFF_DELETE, text_delete)
1025 1025 ]
1026 1026 else:
1027 1027 diffs[pointer - count_delete - count_insert : pointer] = [
1028 1028 (self.DIFF_DELETE, text_delete),
1029 1029 (self.DIFF_INSERT, text_insert),
1030 1030 ]
1031 1031 pointer = pointer - count_delete - count_insert + 1
1032 1032 if count_delete != 0:
1033 1033 pointer += 1
1034 1034 if count_insert != 0:
1035 1035 pointer += 1
1036 1036 elif pointer != 0 and diffs[pointer - 1][0] == self.DIFF_EQUAL:
1037 1037 # Merge this equality with the previous one.
1038 1038 diffs[pointer - 1] = (
1039 1039 diffs[pointer - 1][0],
1040 1040 diffs[pointer - 1][1] + diffs[pointer][1],
1041 1041 )
1042 1042 del diffs[pointer]
1043 1043 else:
1044 1044 pointer += 1
1045 1045
1046 1046 count_insert = 0
1047 1047 count_delete = 0
1048 1048 text_delete = ""
1049 1049 text_insert = ""
1050 1050
1051 1051 if diffs[-1][1] == "":
1052 1052 diffs.pop() # Remove the dummy entry at the end.
1053 1053
1054 1054 # Second pass: look for single edits surrounded on both sides by equalities
1055 1055 # which can be shifted sideways to eliminate an equality.
1056 1056 # e.g: A<ins>BA</ins>C -> <ins>AB</ins>AC
1057 1057 changes = False
1058 1058 pointer = 1
1059 1059 # Intentionally ignore the first and last element (don't need checking).
1060 1060 while pointer < len(diffs) - 1:
1061 1061 if (
1062 1062 diffs[pointer - 1][0] == self.DIFF_EQUAL
1063 1063 and diffs[pointer + 1][0] == self.DIFF_EQUAL
1064 1064 ):
1065 1065 # This is a single edit surrounded by equalities.
1066 1066 if diffs[pointer][1].endswith(diffs[pointer - 1][1]):
1067 1067 # Shift the edit over the previous equality.
1068 1068 diffs[pointer] = (
1069 1069 diffs[pointer][0],
1070 1070 diffs[pointer - 1][1]
1071 1071 + diffs[pointer][1][: -len(diffs[pointer - 1][1])],
1072 1072 )
1073 1073 diffs[pointer + 1] = (
1074 1074 diffs[pointer + 1][0],
1075 1075 diffs[pointer - 1][1] + diffs[pointer + 1][1],
1076 1076 )
1077 1077 del diffs[pointer - 1]
1078 1078 changes = True
1079 1079 elif diffs[pointer][1].startswith(diffs[pointer + 1][1]):
1080 1080 # Shift the edit over the next equality.
1081 1081 diffs[pointer - 1] = (
1082 1082 diffs[pointer - 1][0],
1083 1083 diffs[pointer - 1][1] + diffs[pointer + 1][1],
1084 1084 )
1085 1085 diffs[pointer] = (
1086 1086 diffs[pointer][0],
1087 1087 diffs[pointer][1][len(diffs[pointer + 1][1]) :]
1088 1088 + diffs[pointer + 1][1],
1089 1089 )
1090 1090 del diffs[pointer + 1]
1091 1091 changes = True
1092 1092 pointer += 1
1093 1093
1094 1094 # If shifts were made, the diff needs reordering and another shift sweep.
1095 1095 if changes:
1096 1096 self.diff_cleanupMerge(diffs)
1097 1097
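# Usage sketch: interleaved edits are regrouped and the equalities merged.
dmp = diff_match_patch()
diffs = [(dmp.DIFF_DELETE, "a"), (dmp.DIFF_INSERT, "b"), (dmp.DIFF_DELETE, "c"),
         (dmp.DIFF_INSERT, "d"), (dmp.DIFF_EQUAL, "e"), (dmp.DIFF_EQUAL, "f")]
dmp.diff_cleanupMerge(diffs)
assert diffs == [(dmp.DIFF_DELETE, "ac"), (dmp.DIFF_INSERT, "bd"), (dmp.DIFF_EQUAL, "ef")]
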
1098 1098 def diff_xIndex(self, diffs, loc):
1099 1099 """loc is a location in text1, compute and return the equivalent location
1100 1100 in text2. e.g. "The cat" vs "The big cat", 1->1, 5->8
1101 1101
1102 1102 Args:
1103 1103 diffs: Array of diff tuples.
1104 1104 loc: Location within text1.
1105 1105
1106 1106 Returns:
1107 1107 Location within text2.
1108 1108 """
1109 1109 chars1 = 0
1110 1110 chars2 = 0
1111 1111 last_chars1 = 0
1112 1112 last_chars2 = 0
1113 1113 for x in range(len(diffs)):
1114 1114 (op, text) = diffs[x]
1115 1115 if op != self.DIFF_INSERT: # Equality or deletion.
1116 1116 chars1 += len(text)
1117 1117 if op != self.DIFF_DELETE: # Equality or insertion.
1118 1118 chars2 += len(text)
1119 1119 if chars1 > loc: # Overshot the location.
1120 1120 break
1121 1121 last_chars1 = chars1
1122 1122 last_chars2 = chars2
1123 1123
1124 1124 if len(diffs) != x and diffs[x][0] == self.DIFF_DELETE:
1125 1125 # The location was deleted.
1126 1126 return last_chars2
1127 1127 # Add the remaining character length.
1128 1128 return last_chars2 + (loc - last_chars1)
1129 1129
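# Usage sketch: a location inside an insertion maps past it, while one
# inside a deletion collapses to the deletion's start.
dmp = diff_match_patch()
assert dmp.diff_xIndex([(dmp.DIFF_DELETE, "a"), (dmp.DIFF_INSERT, "1234"),
                        (dmp.DIFF_EQUAL, "xyz")], 2) == 5
assert dmp.diff_xIndex([(dmp.DIFF_EQUAL, "a"), (dmp.DIFF_DELETE, "1234"),
                        (dmp.DIFF_EQUAL, "xyz")], 3) == 1
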
1130 1130 def diff_prettyHtml(self, diffs):
1131 1131 """Convert a diff array into a pretty HTML report.
1132 1132
1133 1133 Args:
1134 1134 diffs: Array of diff tuples.
1135 1135
1136 1136 Returns:
1137 1137 HTML representation.
1138 1138 """
1139 1139 html = []
1140 1140 for op, data in diffs:
1141 1141 text = (
1142 1142 data.replace("&", "&amp;")
1143 1143 .replace("<", "&lt;")
1144 1144 .replace(">", "&gt;")
1145 1145 .replace("\n", "&para;<br>")
1146 1146 )
1147 1147 if op == self.DIFF_INSERT:
1148 1148 html.append('<ins style="background:#e6ffe6;">%s</ins>' % text)
1149 1149 elif op == self.DIFF_DELETE:
1150 1150 html.append('<del style="background:#ffe6e6;">%s</del>' % text)
1151 1151 elif op == self.DIFF_EQUAL:
1152 1152 html.append("<span>%s</span>" % text)
1153 1153 return "".join(html)
1154 1154
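# Usage sketch: HTML metacharacters are escaped and each op gets its own tag.
dmp = diff_match_patch()
diffs = [(dmp.DIFF_EQUAL, "a\n"), (dmp.DIFF_DELETE, "<B>b</B>"), (dmp.DIFF_INSERT, "c&d")]
print(dmp.diff_prettyHtml(diffs))
# <span>a&para;<br></span><del ...>&lt;B&gt;b&lt;/B&gt;</del><ins ...>c&amp;d</ins>
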
1155 1155 def diff_text1(self, diffs):
1156 1156 """Compute and return the source text (all equalities and deletions).
1157 1157
1158 1158 Args:
1159 1159 diffs: Array of diff tuples.
1160 1160
1161 1161 Returns:
1162 1162 Source text.
1163 1163 """
1164 1164 text = []
1165 1165 for op, data in diffs:
1166 1166 if op != self.DIFF_INSERT:
1167 1167 text.append(data)
1168 1168 return "".join(text)
1169 1169
1170 1170 def diff_text2(self, diffs):
1171 1171 """Compute and return the destination text (all equalities and insertions).
1172 1172
1173 1173 Args:
1174 1174 diffs: Array of diff tuples.
1175 1175
1176 1176 Returns:
1177 1177 Destination text.
1178 1178 """
1179 1179 text = []
1180 1180 for op, data in diffs:
1181 1181 if op != self.DIFF_DELETE:
1182 1182 text.append(data)
1183 1183 return "".join(text)
1184 1184
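# Usage sketch: rebuild either side of a diff from its tuples.
dmp = diff_match_patch()
diffs = [(dmp.DIFF_EQUAL, "jump"), (dmp.DIFF_DELETE, "s"), (dmp.DIFF_INSERT, "ed")]
assert dmp.diff_text1(diffs) == "jumps"
assert dmp.diff_text2(diffs) == "jumped"
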
1185 1185 def diff_levenshtein(self, diffs):
1186 1186 """Compute the Levenshtein distance; the number of inserted, deleted or
1187 1187 substituted characters.
1188 1188
1189 1189 Args:
1190 1190 diffs: Array of diff tuples.
1191 1191
1192 1192 Returns:
1193 1193 Number of changes.
1194 1194 """
1195 1195 levenshtein = 0
1196 1196 insertions = 0
1197 1197 deletions = 0
1198 1198 for op, data in diffs:
1199 1199 if op == self.DIFF_INSERT:
1200 1200 insertions += len(data)
1201 1201 elif op == self.DIFF_DELETE:
1202 1202 deletions += len(data)
1203 1203 elif op == self.DIFF_EQUAL:
1204 1204 # A deletion and an insertion is one substitution.
1205 1205 levenshtein += max(insertions, deletions)
1206 1206 insertions = 0
1207 1207 deletions = 0
1208 1208 levenshtein += max(insertions, deletions)
1209 1209 return levenshtein
1210 1210
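# Usage sketch: a 3-char delete overlapping a 4-char insert counts as 4,
# since each overlapping delete/insert pair is a single substitution.
dmp = diff_match_patch()
assert dmp.diff_levenshtein([(dmp.DIFF_DELETE, "abc"), (dmp.DIFF_INSERT, "1234"),
                             (dmp.DIFF_EQUAL, "xyz")]) == 4
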
1211 1211 def diff_toDelta(self, diffs):
1212 1212 """Crush the diff into an encoded string which describes the operations
1213 1213 required to transform text1 into text2.
1214 1214 E.g. =3\t-2\t+ing -> Keep 3 chars, delete 2 chars, insert 'ing'.
1215 1215 Operations are tab-separated. Inserted text is escaped using %xx notation.
1216 1216
1217 1217 Args:
1218 1218 diffs: Array of diff tuples.
1219 1219
1220 1220 Returns:
1221 1221 Delta text.
1222 1222 """
1223 1223 text = []
1224 1224 for op, data in diffs:
1225 1225 if op == self.DIFF_INSERT:
1226 1226 # High ascii will raise UnicodeDecodeError. Use Unicode instead.
1227 1227 data = data.encode("utf-8")
1228 1228 text.append("+" + urllib.parse.quote(data, "!~*'();/?:@&=+$,# "))
1229 1229 elif op == self.DIFF_DELETE:
1230 1230 text.append("-%d" % len(data))
1231 1231 elif op == self.DIFF_EQUAL:
1232 1232 text.append("=%d" % len(data))
1233 1233 return "\t".join(text)
1234 1234
1235 1235 def diff_fromDelta(self, text1, delta):
1236 1236 """Given the original text1, and an encoded string which describes the
1237 1237 operations required to transform text1 into text2, compute the full diff.
1238 1238
1239 1239 Args:
1240 1240 text1: Source string for the diff.
1241 1241 delta: Delta text.
1242 1242
1243 1243 Returns:
1244 1244 Array of diff tuples.
1245 1245
1246 1246 Raises:
1247 1247 ValueError: If invalid input.
1248 1248 """
1249 1249 if type(delta) == str:
1250 1250 # Deltas should be composed of a subset of ascii chars, Unicode not
1251 1251 # required. If this encode raises UnicodeEncodeError, delta is invalid.
1252 1252 delta = delta.encode("ascii")
1253 1253 diffs = []
1254 1254 pointer = 0 # Cursor in text1
1255 1255 tokens = delta.split("\t")
1256 1256 for token in tokens:
1257 1257 if token == "":
1258 1258 # Blank tokens are ok (from a trailing \t).
1259 1259 continue
1260 1260 # Each token begins with a one character parameter which specifies the
1261 1261 # operation of this token (delete, insert, equality).
1262 1262 param = token[1:]
1263 1263 if token[0] == "+":
1264 1264 param = urllib.parse.unquote(param)
1265 1265 diffs.append((self.DIFF_INSERT, param))
1266 1266 elif token[0] == "-" or token[0] == "=":
1267 1267 try:
1268 1268 n = int(param)
1269 1269 except ValueError:
1270 1270 raise ValueError("Invalid number in diff_fromDelta: " + param)
1271 1271 if n < 0:
1272 1272 raise ValueError("Negative number in diff_fromDelta: " + param)
1273 1273 text = text1[pointer : pointer + n]
1274 1274 pointer += n
1275 1275 if token[0] == "=":
1276 1276 diffs.append((self.DIFF_EQUAL, text))
1277 1277 else:
1278 1278 diffs.append((self.DIFF_DELETE, text))
1279 1279 else:
1280 1280 # Anything else is an error.
1281 1281 raise ValueError(
1282 1282 "Invalid diff operation in diff_fromDelta: " + token[0]
1283 1283 )
1284 1284 if pointer != len(text1):
1285 1285 raise ValueError(
1286 1286 "Delta length (%d) does not equal source text length (%d)."
1287 1287 % (pointer, len(text1))
1288 1288 )
1289 1289 return diffs
1290 1290
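# Usage sketch: a diff round-trips through its compact delta encoding.
dmp = diff_match_patch()
diffs = [(dmp.DIFF_EQUAL, "jump"), (dmp.DIFF_DELETE, "s"), (dmp.DIFF_INSERT, "ed")]
delta = dmp.diff_toDelta(diffs)            # "=4\t-1\t+ed"
assert dmp.diff_fromDelta("jumps", delta) == diffs
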
1291 1291 # MATCH FUNCTIONS
1292 1292
1293 1293 def match_main(self, text, pattern, loc):
1294 1294 """Locate the best instance of 'pattern' in 'text' near 'loc'.
1295 1295
1296 1296 Args:
1297 1297 text: The text to search.
1298 1298 pattern: The pattern to search for.
1299 1299 loc: The location to search around.
1300 1300
1301 1301 Returns:
1302 1302 Best match index or -1.
1303 1303 """
1304 1304 # Check for null inputs.
1305 1305 if text is None or pattern is None:
1306 1306 raise ValueError("Null inputs. (match_main)")
1307 1307
1308 1308 loc = max(0, min(loc, len(text)))
1309 1309 if text == pattern:
1310 1310 # Shortcut (potentially not guaranteed by the algorithm)
1311 1311 return 0
1312 1312 elif not text:
1313 1313 # Nothing to match.
1314 1314 return -1
1315 1315 elif text[loc : loc + len(pattern)] == pattern:
1316 1316 # Perfect match at the perfect spot! (Includes case of null pattern)
1317 1317 return loc
1318 1318 else:
1319 1319 # Do a fuzzy compare.
1320 1320 match = self.match_bitap(text, pattern, loc)
1321 1321 return match
1322 1322
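# Usage sketch (assuming default match settings): exact hits short-circuit,
# anything else falls through to the fuzzy Bitap search.
dmp = diff_match_patch()
assert dmp.match_main("abcdef", "de", 3) == 3    # perfect match at loc
assert dmp.match_main("abcdef", "defy", 4) == 3  # fuzzy match past the end
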
1323 1323 def match_bitap(self, text, pattern, loc):
1324 1324 """Locate the best instance of 'pattern' in 'text' near 'loc' using the
1325 1325 Bitap algorithm.
1326 1326
1327 1327 Args:
1328 1328 text: The text to search.
1329 1329 pattern: The pattern to search for.
1330 1330 loc: The location to search around.
1331 1331
1332 1332 Returns:
1333 1333 Best match index or -1.
1334 1334 """
1335 1335 # Python doesn't have a maxint limit, so ignore this check.
1336 1336 # if self.Match_MaxBits != 0 and len(pattern) > self.Match_MaxBits:
1337 1337 # raise ValueError("Pattern too long for this application.")
1338 1338
1339 1339 # Initialise the alphabet.
1340 1340 s = self.match_alphabet(pattern)
1341 1341
1342 1342 def match_bitapScore(e, x):
1343 1343 """Compute and return the score for a match with e errors and x location.
1344 1344 Accesses loc and pattern through being a closure.
1345 1345
1346 1346 Args:
1347 1347 e: Number of errors in match.
1348 1348 x: Location of match.
1349 1349
1350 1350 Returns:
1351 1351 Overall score for match (0.0 = good, 1.0 = bad).
1352 1352 """
1353 1353 accuracy = float(e) / len(pattern)
1354 1354 proximity = abs(loc - x)
1355 1355 if not self.Match_Distance:
1356 1356 # Dodge divide by zero error.
1357 1357 return proximity and 1.0 or accuracy
1358 1358 return accuracy + (proximity / float(self.Match_Distance))
1359 1359
1360 1360 # Highest score beyond which we give up.
1361 1361 score_threshold = self.Match_Threshold
1362 1362 # Is there a nearby exact match? (speedup)
1363 1363 best_loc = text.find(pattern, loc)
1364 1364 if best_loc != -1:
1365 1365 score_threshold = min(match_bitapScore(0, best_loc), score_threshold)
1366 1366 # What about in the other direction? (speedup)
1367 1367 best_loc = text.rfind(pattern, loc + len(pattern))
1368 1368 if best_loc != -1:
1369 1369 score_threshold = min(match_bitapScore(0, best_loc), score_threshold)
1370 1370
1371 1371 # Initialise the bit arrays.
1372 1372 matchmask = 1 << (len(pattern) - 1)
1373 1373 best_loc = -1
1374 1374
1375 1375 bin_max = len(pattern) + len(text)
1376 1376 # Empty initialization added to appease pychecker.
1377 1377 last_rd = None
1378 1378 for d in range(len(pattern)):
1379 1379 # Scan for the best match each iteration allows for one more error.
1380 1380 # Run a binary search to determine how far from 'loc' we can stray at
1381 1381 # this error level.
1382 1382 bin_min = 0
1383 1383 bin_mid = bin_max
1384 1384 while bin_min < bin_mid:
1385 1385 if match_bitapScore(d, loc + bin_mid) <= score_threshold:
1386 1386 bin_min = bin_mid
1387 1387 else:
1388 1388 bin_max = bin_mid
1389 1389 bin_mid = (bin_max - bin_min) // 2 + bin_min
1390 1390
1391 1391 # Use the result from this iteration as the maximum for the next.
1392 1392 bin_max = bin_mid
1393 1393 start = max(1, loc - bin_mid + 1)
1394 1394 finish = min(loc + bin_mid, len(text)) + len(pattern)
1395 1395
1396 1396 rd = [0] * (finish + 2)
1397 1397 rd[finish + 1] = (1 << d) - 1
1398 1398 for j in range(finish, start - 1, -1):
1399 1399 if len(text) <= j - 1:
1400 1400 # Out of range.
1401 1401 charMatch = 0
1402 1402 else:
1403 1403 charMatch = s.get(text[j - 1], 0)
1404 1404 if d == 0: # First pass: exact match.
1405 1405 rd[j] = ((rd[j + 1] << 1) | 1) & charMatch
1406 1406 else: # Subsequent passes: fuzzy match.
1407 1407 rd[j] = (
1408 1408 (((rd[j + 1] << 1) | 1) & charMatch)
1409 1409 | (((last_rd[j + 1] | last_rd[j]) << 1) | 1)
1410 1410 | last_rd[j + 1]
1411 1411 )
1412 1412 if rd[j] & matchmask:
1413 1413 score = match_bitapScore(d, j - 1)
1414 1414 # This match will almost certainly be better than any existing match.
1415 1415 # But check anyway.
1416 1416 if score <= score_threshold:
1417 1417 # Told you so.
1418 1418 score_threshold = score
1419 1419 best_loc = j - 1
1420 1420 if best_loc > loc:
1421 1421 # When passing loc, don't exceed our current distance from loc.
1422 1422 start = max(1, 2 * loc - best_loc)
1423 1423 else:
1424 1424 # Already passed loc, downhill from here on in.
1425 1425 break
1426 1426 # No hope for a (better) match at greater error levels.
1427 1427 if match_bitapScore(d + 1, loc) > score_threshold:
1428 1428 break
1429 1429 last_rd = rd
1430 1430 return best_loc
1431 1431
1432 1432 def match_alphabet(self, pattern):
1433 1433 """Initialise the alphabet for the Bitap algorithm.
1434 1434
1435 1435 Args:
1436 1436 pattern: The text to encode.
1437 1437
1438 1438 Returns:
1439 1439 Hash of character locations.
1440 1440 """
1441 1441 s = {}
1442 1442 for char in pattern:
1443 1443 s[char] = 0
1444 1444 for i in range(len(pattern)):
1445 1445 s[pattern[i]] |= 1 << (len(pattern) - i - 1)
1446 1446 return s
1447 1447
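# Usage sketch: each char maps to a bitmask of its positions, with the
# first char in the highest bit.
dmp = diff_match_patch()
assert dmp.match_alphabet("abc") == {"a": 4, "b": 2, "c": 1}
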
1448 1448 # PATCH FUNCTIONS
1449 1449
1450 1450 def patch_addContext(self, patch, text):
1451 1451 """Increase the context until it is unique,
1452 1452 but don't let the pattern expand beyond Match_MaxBits.
1453 1453
1454 1454 Args:
1455 1455 patch: The patch to grow.
1456 1456 text: Source text.
1457 1457 """
1458 1458 if len(text) == 0:
1459 1459 return
1460 1460 pattern = text[patch.start2 : patch.start2 + patch.length1]
1461 1461 padding = 0
1462 1462
1463 1463 # Look for the first and last matches of pattern in text. If two different
1464 1464 # matches are found, increase the pattern length.
1465 1465 while text.find(pattern) != text.rfind(pattern) and (
1466 1466 self.Match_MaxBits == 0
1467 1467 or len(pattern) < self.Match_MaxBits - self.Patch_Margin - self.Patch_Margin
1468 1468 ):
1469 1469 padding += self.Patch_Margin
1470 1470 pattern = text[
1471 1471 max(0, patch.start2 - padding) : patch.start2 + patch.length1 + padding
1472 1472 ]
1473 1473 # Add one chunk for good luck.
1474 1474 padding += self.Patch_Margin
1475 1475
1476 1476 # Add the prefix.
1477 1477 prefix = text[max(0, patch.start2 - padding) : patch.start2]
1478 1478 if prefix:
1479 1479 patch.diffs[:0] = [(self.DIFF_EQUAL, prefix)]
1480 1480 # Add the suffix.
1481 1481 suffix = text[
1482 1482 patch.start2 + patch.length1 : patch.start2 + patch.length1 + padding
1483 1483 ]
1484 1484 if suffix:
1485 1485 patch.diffs.append((self.DIFF_EQUAL, suffix))
1486 1486
1487 1487 # Roll back the start points.
1488 1488 patch.start1 -= len(prefix)
1489 1489 patch.start2 -= len(prefix)
1490 1490 # Extend lengths.
1491 1491 patch.length1 += len(prefix) + len(suffix)
1492 1492 patch.length2 += len(prefix) + len(suffix)
1493 1493
1494 1494 def patch_make(self, a, b=None, c=None):
1495 1495 """Compute a list of patches to turn text1 into text2.
1496 1496 Use diffs if provided, otherwise compute it ourselves.
1497 1497 There are four ways to call this function, depending on what data is
1498 1498 available to the caller:
1499 1499 Method 1:
1500 1500 a = text1, b = text2
1501 1501 Method 2:
1502 1502 a = diffs
1503 1503 Method 3 (optimal):
1504 1504 a = text1, b = diffs
1505 1505 Method 4 (deprecated, use method 3):
1506 1506 a = text1, b = text2, c = diffs
1507 1507
1508 1508 Args:
1509 1509 a: text1 (methods 1,3,4) or Array of diff tuples for text1 to
1510 1510 text2 (method 2).
1511 1511 b: text2 (methods 1,4) or Array of diff tuples for text1 to
1512 1512 text2 (method 3) or undefined (method 2).
1513 1513 c: Array of diff tuples for text1 to text2 (method 4) or
1514 1514 undefined (methods 1,2,3).
1515 1515
1516 1516 Returns:
1517 1517 Array of Patch objects.
1518 1518 """
1519 1519 text1 = None
1520 1520 diffs = None
1521 1521 # Note that texts may arrive as 'str' or 'unicode'.
1522 1522 if isinstance(a, str) and isinstance(b, str) and c is None:
1523 1523 # Method 1: text1, text2
1524 1524 # Compute diffs from text1 and text2.
1525 1525 text1 = a
1526 1526 diffs = self.diff_main(text1, b, True)
1527 1527 if len(diffs) > 2:
1528 1528 self.diff_cleanupSemantic(diffs)
1529 1529 self.diff_cleanupEfficiency(diffs)
1530 1530 elif isinstance(a, list) and b is None and c is None:
1531 1531 # Method 2: diffs
1532 1532 # Compute text1 from diffs.
1533 1533 diffs = a
1534 1534 text1 = self.diff_text1(diffs)
1535 1535 elif isinstance(a, str) and isinstance(b, list) and c is None:
1536 1536 # Method 3: text1, diffs
1537 1537 text1 = a
1538 1538 diffs = b
1539 1539 elif isinstance(a, str) and isinstance(b, str) and isinstance(c, list):
1540 1540 # Method 4: text1, text2, diffs
1541 1541 # text2 is not used.
1542 1542 text1 = a
1543 1543 diffs = c
1544 1544 else:
1545 1545 raise ValueError("Unknown call format to patch_make.")
1546 1546
1547 1547 if not diffs:
1548 1548 return [] # Get rid of the None case.
1549 1549 patches = []
1550 1550 patch = patch_obj()
1551 1551 char_count1 = 0 # Number of characters into the text1 string.
1552 1552 char_count2 = 0 # Number of characters into the text2 string.
1553 1553 prepatch_text = text1 # Recreate the patches to determine context info.
1554 1554 postpatch_text = text1
1555 1555 for x in range(len(diffs)):
1556 1556 (diff_type, diff_text) = diffs[x]
1557 1557 if len(patch.diffs) == 0 and diff_type != self.DIFF_EQUAL:
1558 1558 # A new patch starts here.
1559 1559 patch.start1 = char_count1
1560 1560 patch.start2 = char_count2
1561 1561 if diff_type == self.DIFF_INSERT:
1562 1562 # Insertion
1563 1563 patch.diffs.append(diffs[x])
1564 1564 patch.length2 += len(diff_text)
1565 1565 postpatch_text = (
1566 1566 postpatch_text[:char_count2]
1567 1567 + diff_text
1568 1568 + postpatch_text[char_count2:]
1569 1569 )
1570 1570 elif diff_type == self.DIFF_DELETE:
1571 1571 # Deletion.
1572 1572 patch.length1 += len(diff_text)
1573 1573 patch.diffs.append(diffs[x])
1574 1574 postpatch_text = (
1575 1575 postpatch_text[:char_count2]
1576 1576 + postpatch_text[char_count2 + len(diff_text) :]
1577 1577 )
1578 1578 elif (
1579 1579 diff_type == self.DIFF_EQUAL
1580 1580 and len(diff_text) <= 2 * self.Patch_Margin
1581 1581 and len(patch.diffs) != 0
1582 1582 and len(diffs) != x + 1
1583 1583 ):
1584 1584 # Small equality inside a patch.
1585 1585 patch.diffs.append(diffs[x])
1586 1586 patch.length1 += len(diff_text)
1587 1587 patch.length2 += len(diff_text)
1588 1588
1589 1589 if diff_type == self.DIFF_EQUAL and len(diff_text) >= 2 * self.Patch_Margin:
1590 1590 # Time for a new patch.
1591 1591 if len(patch.diffs) != 0:
1592 1592 self.patch_addContext(patch, prepatch_text)
1593 1593 patches.append(patch)
1594 1594 patch = patch_obj()
1595 1595 # Unlike Unidiff, our patch lists have a rolling context.
1596 1596 # http://code.google.com/p/google-diff-match-patch/wiki/Unidiff
1597 1597 # Update prepatch text & pos to reflect the application of the
1598 1598 # just completed patch.
1599 1599 prepatch_text = postpatch_text
1600 1600 char_count1 = char_count2
1601 1601
1602 1602 # Update the current character count.
1603 1603 if diff_type != self.DIFF_INSERT:
1604 1604 char_count1 += len(diff_text)
1605 1605 if diff_type != self.DIFF_DELETE:
1606 1606 char_count2 += len(diff_text)
1607 1607
1608 1608 # Pick up the leftover patch if not empty.
1609 1609 if len(patch.diffs) != 0:
1610 1610 self.patch_addContext(patch, prepatch_text)
1611 1611 patches.append(patch)
1612 1612 return patches
1613 1613
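# Usage sketch: method 1 (two texts) is the usual entry point.
dmp = diff_match_patch()
patches = dmp.patch_make("The quick brown fox.", "The slow brown fox.")
print(dmp.patch_toText(patches))
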
1614 1614 def patch_deepCopy(self, patches):
1615 1615 """Given an array of patches, return another array that is identical.
1616 1616
1617 1617 Args:
1618 1618 patches: Array of Patch objects.
1619 1619
1620 1620 Returns:
1621 1621 Array of Patch objects.
1622 1622 """
1623 1623 patchesCopy = []
1624 1624 for patch in patches:
1625 1625 patchCopy = patch_obj()
1626 1626 # No need to deep copy the tuples since they are immutable.
1627 1627 patchCopy.diffs = patch.diffs[:]
1628 1628 patchCopy.start1 = patch.start1
1629 1629 patchCopy.start2 = patch.start2
1630 1630 patchCopy.length1 = patch.length1
1631 1631 patchCopy.length2 = patch.length2
1632 1632 patchesCopy.append(patchCopy)
1633 1633 return patchesCopy
1634 1634
1635 1635 def patch_apply(self, patches, text):
1636 1636 """Merge a set of patches onto the text. Return a patched text, as well
1637 1637 as a list of true/false values indicating which patches were applied.
1638 1638
1639 1639 Args:
1640 1640 patches: Array of Patch objects.
1641 1641 text: Old text.
1642 1642
1643 1643 Returns:
1644 1644 Two element Array, containing the new text and an array of boolean values.
1645 1645 """
1646 1646 if not patches:
1647 1647 return (text, [])
1648 1648
1649 1649 # Deep copy the patches so that no changes are made to originals.
1650 1650 patches = self.patch_deepCopy(patches)
1651 1651
1652 1652 nullPadding = self.patch_addPadding(patches)
1653 1653 text = nullPadding + text + nullPadding
1654 1654 self.patch_splitMax(patches)
1655 1655
1656 1656 # delta keeps track of the offset between the expected and actual location
1657 1657 # of the previous patch. If there are patches expected at positions 10 and
1658 1658 # 20, but the first patch was found at 12, delta is 2 and the second patch
1659 1659 # has an effective expected position of 22.
1660 1660 delta = 0
1661 1661 results = []
1662 1662 for patch in patches:
1663 1663 expected_loc = patch.start2 + delta
1664 1664 text1 = self.diff_text1(patch.diffs)
1665 1665 end_loc = -1
1666 1666 if len(text1) > self.Match_MaxBits:
1667 1667 # patch_splitMax will only provide an oversized pattern in the case of
1668 1668 # a monster delete.
1669 1669 start_loc = self.match_main(
1670 1670 text, text1[: self.Match_MaxBits], expected_loc
1671 1671 )
1672 1672 if start_loc != -1:
1673 1673 end_loc = self.match_main(
1674 1674 text,
1675 1675 text1[-self.Match_MaxBits :],
1676 1676 expected_loc + len(text1) - self.Match_MaxBits,
1677 1677 )
1678 1678 if end_loc == -1 or start_loc >= end_loc:
1679 1679 # Can't find valid trailing context. Drop this patch.
1680 1680 start_loc = -1
1681 1681 else:
1682 1682 start_loc = self.match_main(text, text1, expected_loc)
1683 1683 if start_loc == -1:
1684 1684 # No match found. :(
1685 1685 results.append(False)
1686 1686 # Subtract the delta for this failed patch from subsequent patches.
1687 1687 delta -= patch.length2 - patch.length1
1688 1688 else:
1689 1689 # Found a match. :)
1690 1690 results.append(True)
1691 1691 delta = start_loc - expected_loc
1692 1692 if end_loc == -1:
1693 1693 text2 = text[start_loc : start_loc + len(text1)]
1694 1694 else:
1695 1695 text2 = text[start_loc : end_loc + self.Match_MaxBits]
1696 1696 if text1 == text2:
1697 1697 # Perfect match, just shove the replacement text in.
1698 1698 text = (
1699 1699 text[:start_loc]
1700 1700 + self.diff_text2(patch.diffs)
1701 1701 + text[start_loc + len(text1) :]
1702 1702 )
1703 1703 else:
1704 1704 # Imperfect match.
1705 1705 # Run a diff to get a framework of equivalent indices.
1706 1706 diffs = self.diff_main(text1, text2, False)
1707 1707 if (
1708 1708 len(text1) > self.Match_MaxBits
1709 1709 and self.diff_levenshtein(diffs) / float(len(text1))
1710 1710 > self.Patch_DeleteThreshold
1711 1711 ):
1712 1712 # The end points match, but the content is unacceptably bad.
1713 1713 results[-1] = False
1714 1714 else:
1715 1715 self.diff_cleanupSemanticLossless(diffs)
1716 1716 index1 = 0
1717 1717 for op, data in patch.diffs:
1718 1718 if op != self.DIFF_EQUAL:
1719 1719 index2 = self.diff_xIndex(diffs, index1)
1720 1720 if op == self.DIFF_INSERT: # Insertion
1721 1721 text = (
1722 1722 text[: start_loc + index2]
1723 1723 + data
1724 1724 + text[start_loc + index2 :]
1725 1725 )
1726 1726 elif op == self.DIFF_DELETE: # Deletion
1727 1727 text = (
1728 1728 text[: start_loc + index2]
1729 1729 + text[
1730 1730 start_loc
1731 1731 + self.diff_xIndex(diffs, index1 + len(data)) :
1732 1732 ]
1733 1733 )
1734 1734 if op != self.DIFF_DELETE:
1735 1735 index1 += len(data)
1736 1736 # Strip the padding off.
1737 1737 text = text[len(nullPadding) : -len(nullPadding)]
1738 1738 return (text, results)
1739 1739
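# Usage sketch: the patch should still apply when the target text has drifted.
dmp = diff_match_patch()
patches = dmp.patch_make("The quick brown fox.", "The slow brown fox.")
new_text, results = dmp.patch_apply(patches, "The quick brown fox jumps.")
# expected: new_text == "The slow brown fox jumps.", results == [True]
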
1740 1740 def patch_addPadding(self, patches):
1741 1741 """Add some padding on text start and end so that edges can match
1742 1742 something. Intended to be called only from within patch_apply.
1743 1743
1744 1744 Args:
1745 1745 patches: Array of Patch objects.
1746 1746
1747 1747 Returns:
1748 1748 The padding string added to each side.
1749 1749 """
1750 1750 paddingLength = self.Patch_Margin
1751 1751 nullPadding = ""
1752 1752 for x in range(1, paddingLength + 1):
1753 1753 nullPadding += chr(x)
1754 1754
1755 1755 # Bump all the patches forward.
1756 1756 for patch in patches:
1757 1757 patch.start1 += paddingLength
1758 1758 patch.start2 += paddingLength
1759 1759
1760 1760 # Add some padding on start of first diff.
1761 1761 patch = patches[0]
1762 1762 diffs = patch.diffs
1763 1763 if not diffs or diffs[0][0] != self.DIFF_EQUAL:
1764 1764 # Add nullPadding equality.
1765 1765 diffs.insert(0, (self.DIFF_EQUAL, nullPadding))
1766 1766 patch.start1 -= paddingLength # Should be 0.
1767 1767 patch.start2 -= paddingLength # Should be 0.
1768 1768 patch.length1 += paddingLength
1769 1769 patch.length2 += paddingLength
1770 1770 elif paddingLength > len(diffs[0][1]):
1771 1771 # Grow first equality.
1772 1772 extraLength = paddingLength - len(diffs[0][1])
1773 1773 newText = nullPadding[len(diffs[0][1]) :] + diffs[0][1]
1774 1774 diffs[0] = (diffs[0][0], newText)
1775 1775 patch.start1 -= extraLength
1776 1776 patch.start2 -= extraLength
1777 1777 patch.length1 += extraLength
1778 1778 patch.length2 += extraLength
1779 1779
1780 1780 # Add some padding on end of last diff.
1781 1781 patch = patches[-1]
1782 1782 diffs = patch.diffs
1783 1783 if not diffs or diffs[-1][0] != self.DIFF_EQUAL:
1784 1784 # Add nullPadding equality.
1785 1785 diffs.append((self.DIFF_EQUAL, nullPadding))
1786 1786 patch.length1 += paddingLength
1787 1787 patch.length2 += paddingLength
1788 1788 elif paddingLength > len(diffs[-1][1]):
1789 1789 # Grow last equality.
1790 1790 extraLength = paddingLength - len(diffs[-1][1])
1791 1791 newText = diffs[-1][1] + nullPadding[:extraLength]
1792 1792 diffs[-1] = (diffs[-1][0], newText)
1793 1793 patch.length1 += extraLength
1794 1794 patch.length2 += extraLength
1795 1795
1796 1796 return nullPadding
1797 1797
1798 1798 def patch_splitMax(self, patches):
1799 1799 """Look through the patches and break up any which are longer than the
1800 1800 maximum limit of the match algorithm.
1801 1801 Intended to be called only from within patch_apply.
1802 1802
1803 1803 Args:
1804 1804 patches: Array of Patch objects.
1805 1805 """
1806 1806 patch_size = self.Match_MaxBits
1807 1807 if patch_size == 0:
1808 1808 # Python has the option of not splitting strings due to its ability
1809 1809 # to handle integers of arbitrary precision.
1810 1810 return
1811 1811 for x in range(len(patches)):
1812 1812 if patches[x].length1 <= patch_size:
1813 1813 continue
1814 1814 bigpatch = patches[x]
1815 1815 # Remove the big old patch.
1816 1816 del patches[x]
1817 1817 x -= 1
1818 1818 start1 = bigpatch.start1
1819 1819 start2 = bigpatch.start2
1820 1820 precontext = ""
1821 1821 while len(bigpatch.diffs) != 0:
1822 1822 # Create one of several smaller patches.
1823 1823 patch = patch_obj()
1824 1824 empty = True
1825 1825 patch.start1 = start1 - len(precontext)
1826 1826 patch.start2 = start2 - len(precontext)
1827 1827 if precontext:
1828 1828 patch.length1 = patch.length2 = len(precontext)
1829 1829 patch.diffs.append((self.DIFF_EQUAL, precontext))
1830 1830
1831 1831 while (
1832 1832 len(bigpatch.diffs) != 0
1833 1833 and patch.length1 < patch_size - self.Patch_Margin
1834 1834 ):
1835 1835 (diff_type, diff_text) = bigpatch.diffs[0]
1836 1836 if diff_type == self.DIFF_INSERT:
1837 1837 # Insertions are harmless.
1838 1838 patch.length2 += len(diff_text)
1839 1839 start2 += len(diff_text)
1840 1840 patch.diffs.append(bigpatch.diffs.pop(0))
1841 1841 empty = False
1842 1842 elif (
1843 1843 diff_type == self.DIFF_DELETE
1844 1844 and len(patch.diffs) == 1
1845 1845 and patch.diffs[0][0] == self.DIFF_EQUAL
1846 1846 and len(diff_text) > 2 * patch_size
1847 1847 ):
1848 1848 # This is a large deletion. Let it pass in one chunk.
1849 1849 patch.length1 += len(diff_text)
1850 1850 start1 += len(diff_text)
1851 1851 empty = False
1852 1852 patch.diffs.append((diff_type, diff_text))
1853 1853 del bigpatch.diffs[0]
1854 1854 else:
1855 1855 # Deletion or equality. Only take as much as we can stomach.
1856 1856 diff_text = diff_text[
1857 1857 : patch_size - patch.length1 - self.Patch_Margin
1858 1858 ]
1859 1859 patch.length1 += len(diff_text)
1860 1860 start1 += len(diff_text)
1861 1861 if diff_type == self.DIFF_EQUAL:
1862 1862 patch.length2 += len(diff_text)
1863 1863 start2 += len(diff_text)
1864 1864 else:
1865 1865 empty = False
1866 1866
1867 1867 patch.diffs.append((diff_type, diff_text))
1868 1868 if diff_text == bigpatch.diffs[0][1]:
1869 1869 del bigpatch.diffs[0]
1870 1870 else:
1871 1871 bigpatch.diffs[0] = (
1872 1872 bigpatch.diffs[0][0],
1873 1873 bigpatch.diffs[0][1][len(diff_text) :],
1874 1874 )
1875 1875
1876 1876 # Compute the head context for the next patch.
1877 1877 precontext = self.diff_text2(patch.diffs)
1878 1878 precontext = precontext[-self.Patch_Margin :]
1879 1879 # Append the end context for this patch.
1880 1880 postcontext = self.diff_text1(bigpatch.diffs)[: self.Patch_Margin]
1881 1881 if postcontext:
1882 1882 patch.length1 += len(postcontext)
1883 1883 patch.length2 += len(postcontext)
1884 1884 if len(patch.diffs) != 0 and patch.diffs[-1][0] == self.DIFF_EQUAL:
1885 1885 patch.diffs[-1] = (
1886 1886 self.DIFF_EQUAL,
1887 1887 patch.diffs[-1][1] + postcontext,
1888 1888 )
1889 1889 else:
1890 1890 patch.diffs.append((self.DIFF_EQUAL, postcontext))
1891 1891
1892 1892 if not empty:
1893 1893 x += 1
1894 1894 patches.insert(x, patch)
1895 1895
1896 1896 def patch_toText(self, patches):
1897 1897 """Take a list of patches and return a textual representation.
1898 1898
1899 1899 Args:
1900 1900 patches: Array of Patch objects.
1901 1901
1902 1902 Returns:
1903 1903 Text representation of patches.
1904 1904 """
1905 1905 text = []
1906 1906 for patch in patches:
1907 1907 text.append(str(patch))
1908 1908 return "".join(text)
1909 1909
1910 1910 def patch_fromText(self, textline):
1911 1911 """Parse a textual representation of patches and return a list of patch
1912 1912 objects.
1913 1913
1914 1914 Args:
1915 1915 textline: Text representation of patches.
1916 1916
1917 1917 Returns:
1918 1918 Array of Patch objects.
1919 1919
1920 1920 Raises:
1921 1921 ValueError: If invalid input.
1922 1922 """
1923 if type(textline) == unicode:
1923 if type(textline) == str:
1924 1924 # Patches should be composed of a subset of ascii chars, Unicode not
1925 1925 # required. If this encode raises UnicodeEncodeError, patch is invalid.
1926 1926 textline = textline.encode("ascii")
1927 1927 patches = []
1928 1928 if not textline:
1929 1929 return patches
1930 1930 text = textline.split("\n")
1931 1931 while len(text) != 0:
1932 1932 m = re.match(r"^@@ -(\d+),?(\d*) \+(\d+),?(\d*) @@$", text[0])
1933 1933 if not m:
1934 1934 raise ValueError("Invalid patch string: " + text[0])
1935 1935 patch = patch_obj()
1936 1936 patches.append(patch)
1937 1937 patch.start1 = int(m.group(1))
1938 1938 if m.group(2) == "":
1939 1939 patch.start1 -= 1
1940 1940 patch.length1 = 1
1941 1941 elif m.group(2) == "0":
1942 1942 patch.length1 = 0
1943 1943 else:
1944 1944 patch.start1 -= 1
1945 1945 patch.length1 = int(m.group(2))
1946 1946
1947 1947 patch.start2 = int(m.group(3))
1948 1948 if m.group(4) == "":
1949 1949 patch.start2 -= 1
1950 1950 patch.length2 = 1
1951 1951 elif m.group(4) == "0":
1952 1952 patch.length2 = 0
1953 1953 else:
1954 1954 patch.start2 -= 1
1955 1955 patch.length2 = int(m.group(4))
1956 1956
1957 1957 del text[0]
1958 1958
1959 1959 while len(text) != 0:
1960 1960 if text[0]:
1961 1961 sign = text[0][0]
1962 1962 else:
1963 1963 sign = ""
1964 1964 line = urllib.parse.unquote(text[0][1:])
1965 1965 # unquote() already returns str under Python 3; no decode is needed.
1966 1966 if sign == "+":
1967 1967 # Insertion.
1968 1968 patch.diffs.append((self.DIFF_INSERT, line))
1969 1969 elif sign == "-":
1970 1970 # Deletion.
1971 1971 patch.diffs.append((self.DIFF_DELETE, line))
1972 1972 elif sign == " ":
1973 1973 # Minor equality.
1974 1974 patch.diffs.append((self.DIFF_EQUAL, line))
1975 1975 elif sign == "@":
1976 1976 # Start of next patch.
1977 1977 break
1978 1978 elif sign == "":
1979 1979 # Blank line? Whatever.
1980 1980 pass
1981 1981 else:
1982 1982 # WTF?
1983 1983 raise ValueError("Invalid patch mode: '%s'\n%s" % (sign, line))
1984 1984 del text[0]
1985 1985 return patches
1986 1986
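# Usage sketch: the textual form parses back and re-serializes unchanged.
dmp = diff_match_patch()
text = "@@ -21,18 +22,17 @@\n jump\n-s\n+ed\n  over \n-the\n+a\n  laz\n"
assert dmp.patch_toText(dmp.patch_fromText(text)) == text
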
1987 1987
1988 1988 class patch_obj:
1989 1989 """Class representing one patch operation."""
1990 1990
1991 1991 def __init__(self):
1992 1992 """Initializes with an empty list of diffs."""
1993 1993 self.diffs = []
1994 1994 self.start1 = None
1995 1995 self.start2 = None
1996 1996 self.length1 = 0
1997 1997 self.length2 = 0
1998 1998
1999 1999 def __str__(self):
2000 2000 """Emmulate GNU diff's format.
2001 2001 Header: @@ -382,8 +481,9 @@
2002 2002 Indices are printed as 1-based, not 0-based.
2003 2003
2004 2004 Returns:
2005 2005 The GNU diff string.
2006 2006 """
2007 2007 if self.length1 == 0:
2008 2008 coords1 = str(self.start1) + ",0"
2009 2009 elif self.length1 == 1:
2010 2010 coords1 = str(self.start1 + 1)
2011 2011 else:
2012 2012 coords1 = str(self.start1 + 1) + "," + str(self.length1)
2013 2013 if self.length2 == 0:
2014 2014 coords2 = str(self.start2) + ",0"
2015 2015 elif self.length2 == 1:
2016 2016 coords2 = str(self.start2 + 1)
2017 2017 else:
2018 2018 coords2 = str(self.start2 + 1) + "," + str(self.length2)
2019 2019 text = ["@@ -", coords1, " +", coords2, " @@\n"]
2020 2020 # Escape the body of the patch with %xx notation.
2021 2021 for op, data in self.diffs:
2022 2022 if op == diff_match_patch.DIFF_INSERT:
2023 2023 text.append("+")
2024 2024 elif op == diff_match_patch.DIFF_DELETE:
2025 2025 text.append("-")
2026 2026 elif op == diff_match_patch.DIFF_EQUAL:
2027 2027 text.append(" ")
2028 2028 # High ascii will raise UnicodeDecodeError. Use Unicode instead.
2029 2029 data = data.encode("utf-8")
2030 2030 text.append(urllib.parse.quote(data, "!~*'();/?:@&=+$,# ") + "\n")
2031 2031 return "".join(text)
@@ -1,1272 +1,1271 b''
1 1 # -*- coding: utf-8 -*-
2 2
3 3 # Copyright (C) 2011-2020 RhodeCode GmbH
4 4 #
5 5 # This program is free software: you can redistribute it and/or modify
6 6 # it under the terms of the GNU Affero General Public License, version 3
7 7 # (only), as published by the Free Software Foundation.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU Affero General Public License
15 15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 16 #
17 17 # This program is dual-licensed. If you wish to learn more about the
18 18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20 20
21 21
22 22 """
23 23 Set of diffing helpers, previously part of vcs
24 24 """
25 25
26 26 import os
27 27 import re
28 28 import bz2
29 29 import gzip
30 30 import time
31 31
32 32 import collections
33 33 import difflib
34 34 import logging
35 35 import pickle
36 36 from itertools import tee
37 37
38 38 from rhodecode.lib.vcs.exceptions import VCSError
39 39 from rhodecode.lib.vcs.nodes import FileNode, SubModuleNode
40 40 from rhodecode.lib.utils2 import safe_unicode, safe_str
41 41
42 42 log = logging.getLogger(__name__)
43 43
44 44 # define max context, a file with more than this numbers of lines is unusable
45 45 # in browser anyway
46 46 MAX_CONTEXT = 20 * 1024
47 47 DEFAULT_CONTEXT = 3
48 48
49 49
50 50 def get_diff_context(request):
51 51 return MAX_CONTEXT if request.GET.get('fullcontext', '') == '1' else DEFAULT_CONTEXT
52 52
53 53
54 54 def get_diff_whitespace_flag(request):
55 55 return request.GET.get('ignorews', '') == '1'
56 56
57 57
58 58 class OPS(object):
59 59 ADD = 'A'
60 60 MOD = 'M'
61 61 DEL = 'D'
62 62
63 63
64 64 def get_gitdiff(filenode_old, filenode_new, ignore_whitespace=True, context=3):
65 65 """
66 66 Returns git style diff between given ``filenode_old`` and ``filenode_new``.
67 67
68 68 :param ignore_whitespace: ignore whitespaces in diff
69 69 """
70 70 # make sure we pass in default context
71 71 context = context or 3
72 72 # protect against IntOverflow when passing HUGE context
73 73 if context > MAX_CONTEXT:
74 74 context = MAX_CONTEXT
75 75
76 submodules = filter(lambda o: isinstance(o, SubModuleNode),
77 [filenode_new, filenode_old])
76 submodules = [o for o in [filenode_new, filenode_old] if isinstance(o, SubModuleNode)]
78 77 if submodules:
79 78 return ''
80 79
81 80 for filenode in (filenode_old, filenode_new):
82 81 if not isinstance(filenode, FileNode):
83 82 raise VCSError(
84 83 "Given object should be FileNode object, not %s"
85 84 % filenode.__class__)
86 85
87 86 repo = filenode_new.commit.repository
88 87 old_commit = filenode_old.commit or repo.EMPTY_COMMIT
89 88 new_commit = filenode_new.commit
90 89
91 90 vcs_gitdiff = repo.get_diff(
92 91 old_commit, new_commit, filenode_new.path,
93 92 ignore_whitespace, context, path1=filenode_old.path)
94 93 return vcs_gitdiff
95 94
96 95 NEW_FILENODE = 1
97 96 DEL_FILENODE = 2
98 97 MOD_FILENODE = 3
99 98 RENAMED_FILENODE = 4
100 99 COPIED_FILENODE = 5
101 100 CHMOD_FILENODE = 6
102 101 BIN_FILENODE = 7
103 102
104 103
105 104 class LimitedDiffContainer(object):
106 105
107 106 def __init__(self, diff_limit, cur_diff_size, diff):
108 107 self.diff = diff
109 108 self.diff_limit = diff_limit
110 109 self.cur_diff_size = cur_diff_size
111 110
112 111 def __getitem__(self, key):
113 112 return self.diff.__getitem__(key)
114 113
115 114 def __iter__(self):
116 115 for l in self.diff:
117 116 yield l
118 117
119 118
120 119 class Action(object):
121 120 """
122 121 Contains constants for the action value of the lines in a parsed diff.
123 122 """
124 123
125 124 ADD = 'add'
126 125 DELETE = 'del'
127 126 UNMODIFIED = 'unmod'
128 127
129 128 CONTEXT = 'context'
130 129 OLD_NO_NL = 'old-no-nl'
131 130 NEW_NO_NL = 'new-no-nl'
132 131
133 132
134 133 class DiffProcessor(object):
135 134 """
136 135 Give it a unified or git diff and it returns a list of the files that were
137 136 mentioned in the diff together with a dict of meta information that
138 137 can be used to render it in an HTML template.
139 138
140 139 .. note:: Unicode handling
141 140
142 141 The original diffs are a byte sequence and can contain filenames
143 142 in mixed encodings. This class generally returns `str` objects
144 143 since the result is intended for presentation to the user.
145 144
146 145 """
147 146 _chunk_re = re.compile(r'^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@(.*)')
148 147 _newline_marker = re.compile(r'^\\ No newline at end of file')
149 148
150 149 # used for inline highlighter word split
151 150 _token_re = re.compile(r'()(&gt;|&lt;|&amp;|\W+?)')
152 151
153 152 # collapse ranges of commits over given number
154 153 _collapse_commits_over = 5
155 154
156 155 def __init__(self, diff, format='gitdiff', diff_limit=None,
157 156 file_limit=None, show_full_diff=True):
158 157 """
159 158 :param diff: A `Diff` object representing a diff from a vcs backend
160 159 :param format: format of diff passed, `udiff` or `gitdiff`
161 160 :param diff_limit: defines the diff size that is considered "big";
162 161 when that limit is exceeded the cut-off is triggered. Set to None
163 162 to show the full diff
164 163 """
165 164 self._diff = diff
166 165 self._format = format
167 166 self.adds = 0
168 167 self.removes = 0
169 168 # calculate diff size
170 169 self.diff_limit = diff_limit
171 170 self.file_limit = file_limit
172 171 self.show_full_diff = show_full_diff
173 172 self.cur_diff_size = 0
174 173 self.parsed = False
175 174 self.parsed_diff = []
176 175
177 176 log.debug('Initialized DiffProcessor with %s mode', format)
178 177 if format == 'gitdiff':
179 178 self.differ = self._highlight_line_difflib
180 179 self._parser = self._parse_gitdiff
181 180 else:
182 181 self.differ = self._highlight_line_udiff
183 182 self._parser = self._new_parse_gitdiff
184 183
185 184 def _copy_iterator(self):
186 185 """
187 186 make a fresh copy of the generator; we should not iterate through
188 187 the original as it's needed for repeating operations on
189 188 this instance of DiffProcessor
190 189 """
191 190 self.__udiff, iterator_copy = tee(self.__udiff)
192 191 return iterator_copy
193 192
194 193 def _escaper(self, string):
195 194 """
196 195 Escaper for diffs: escapes special chars and checks the diff limit
197 196
198 197 :param string:
199 198 """
200 199 self.cur_diff_size += len(string)
201 200
202 201 if not self.show_full_diff and (self.cur_diff_size > self.diff_limit):
203 202 raise DiffLimitExceeded('Diff Limit Exceeded')
204 203
205 204 return string \
206 205 .replace('&', '&amp;')\
207 206 .replace('<', '&lt;')\
208 207 .replace('>', '&gt;')
209 208
210 209 def _line_counter(self, l):
211 210 """
212 211 Checks each line and bumps total adds/removes for this diff
213 212
214 213 :param l:
215 214 """
216 215 if l.startswith('+') and not l.startswith('+++'):
217 216 self.adds += 1
218 217 elif l.startswith('-') and not l.startswith('---'):
219 218 self.removes += 1
220 219 return safe_unicode(l)
221 220
222 221 def _highlight_line_difflib(self, line, next_):
223 222 """
224 223 Highlight inline changes in both lines.
225 224 """
226 225
227 226 if line['action'] == Action.DELETE:
228 227 old, new = line, next_
229 228 else:
230 229 old, new = next_, line
231 230
232 231 oldwords = self._token_re.split(old['line'])
233 232 newwords = self._token_re.split(new['line'])
234 233 sequence = difflib.SequenceMatcher(None, oldwords, newwords)
235 234
236 235 oldfragments, newfragments = [], []
237 236 for tag, i1, i2, j1, j2 in sequence.get_opcodes():
238 237 oldfrag = ''.join(oldwords[i1:i2])
239 238 newfrag = ''.join(newwords[j1:j2])
240 239 if tag != 'equal':
241 240 if oldfrag:
242 241 oldfrag = '<del>%s</del>' % oldfrag
243 242 if newfrag:
244 243 newfrag = '<ins>%s</ins>' % newfrag
245 244 oldfragments.append(oldfrag)
246 245 newfragments.append(newfrag)
247 246
248 247 old['line'] = "".join(oldfragments)
249 248 new['line'] = "".join(newfragments)
250 249
251 250 def _highlight_line_udiff(self, line, next_):
252 251 """
253 252 Highlight inline changes in both lines.
254 253 """
255 254 start = 0
256 255 limit = min(len(line['line']), len(next_['line']))
257 256 while start < limit and line['line'][start] == next_['line'][start]:
258 257 start += 1
259 258 end = -1
260 259 limit -= start
261 260 while -end <= limit and line['line'][end] == next_['line'][end]:
262 261 end -= 1
263 262 end += 1
264 263 if start or end:
265 264 def do(l):
266 265 last = end + len(l['line'])
267 266 if l['action'] == Action.ADD:
268 267 tag = 'ins'
269 268 else:
270 269 tag = 'del'
271 270 l['line'] = '%s<%s>%s</%s>%s' % (
272 271 l['line'][:start],
273 272 tag,
274 273 l['line'][start:last],
275 274 tag,
276 275 l['line'][last:]
277 276 )
278 277 do(line)
279 278 do(next_)
280 279
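A hedged illustration of the udiff inline highlighter above: only the span between the common prefix and common suffix is wrapped in <del>/<ins> tags.

old = {'action': Action.DELETE, 'line': 'foo bar'}
new = {'action': Action.ADD, 'line': 'foo baz'}
DiffProcessor(diff=None, format='udiff')._highlight_line_udiff(old, new)
assert old['line'] == 'foo ba<del>r</del>'
assert new['line'] == 'foo ba<ins>z</ins>'
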
281 280 def _clean_line(self, line, command):
282 281 if command in ['+', '-', ' ']:
283 282 # only modify the line if it's actually a diff thing
284 283 line = line[1:]
285 284 return line
286 285
287 286 def _parse_gitdiff(self, inline_diff=True):
288 287 _files = []
289 288 diff_container = lambda arg: arg
290 289
291 290 for chunk in self._diff.chunks():
292 291 head = chunk.header
293 292
294 293 diff = map(self._escaper, self.diff_splitter(chunk.diff))
295 294 raw_diff = chunk.raw
296 295 limited_diff = False
297 296 exceeds_limit = False
298 297
299 298 op = None
300 299 stats = {
301 300 'added': 0,
302 301 'deleted': 0,
303 302 'binary': False,
304 303 'ops': {},
305 304 }
306 305
307 306 if head['deleted_file_mode']:
308 307 op = OPS.DEL
309 308 stats['binary'] = True
310 309 stats['ops'][DEL_FILENODE] = 'deleted file'
311 310
312 311 elif head['new_file_mode']:
313 312 op = OPS.ADD
314 313 stats['binary'] = True
315 314 stats['ops'][NEW_FILENODE] = 'new file %s' % head['new_file_mode']
316 315 else: # modify operation, can be copy, rename or chmod
317 316
318 317 # CHMOD
319 318 if head['new_mode'] and head['old_mode']:
320 319 op = OPS.MOD
321 320 stats['binary'] = True
322 321 stats['ops'][CHMOD_FILENODE] = (
323 322 'modified file chmod %s => %s' % (
324 323 head['old_mode'], head['new_mode']))
325 324 # RENAME
326 325 if head['rename_from'] != head['rename_to']:
327 326 op = OPS.MOD
328 327 stats['binary'] = True
329 328 stats['ops'][RENAMED_FILENODE] = (
330 329 'file renamed from %s to %s' % (
331 330 head['rename_from'], head['rename_to']))
332 331 # COPY
333 332 if head.get('copy_from') and head.get('copy_to'):
334 333 op = OPS.MOD
335 334 stats['binary'] = True
336 335 stats['ops'][COPIED_FILENODE] = (
337 336 'file copied from %s to %s' % (
338 337 head['copy_from'], head['copy_to']))
339 338
340 339 # If our new parsed headers didn't match anything fallback to
341 340 # old style detection
342 341 if op is None:
343 342 if not head['a_file'] and head['b_file']:
344 343 op = OPS.ADD
345 344 stats['binary'] = True
346 345 stats['ops'][NEW_FILENODE] = 'new file'
347 346
348 347 elif head['a_file'] and not head['b_file']:
349 348 op = OPS.DEL
350 349 stats['binary'] = True
351 350 stats['ops'][DEL_FILENODE] = 'deleted file'
352 351
353 352 # it's neither ADD nor DELETE
354 353 if op is None:
355 354 op = OPS.MOD
356 355 stats['binary'] = True
357 356 stats['ops'][MOD_FILENODE] = 'modified file'
358 357
359 358 # a real non-binary diff
360 359 if head['a_file'] or head['b_file']:
361 360 try:
362 361 raw_diff, chunks, _stats = self._parse_lines(diff)
363 362 stats['binary'] = False
364 363 stats['added'] = _stats[0]
365 364 stats['deleted'] = _stats[1]
366 365 # explicitly mark that it's a modified file
367 366 if op == OPS.MOD:
368 367 stats['ops'][MOD_FILENODE] = 'modified file'
369 368 exceeds_limit = len(raw_diff) > self.file_limit
370 369
371 370 # changed from the _escaper function so we validate the size of
372 371 # each file instead of the whole diff.
373 372 # The diff will hide big files but still show small ones;
374 373 # from my tests, big files are fairly safe to parse,
375 374 # but the browser is the bottleneck
376 375 if not self.show_full_diff and exceeds_limit:
377 376 raise DiffLimitExceeded('File Limit Exceeded')
378 377
379 378 except DiffLimitExceeded:
380 379 diff_container = lambda _diff: \
381 380 LimitedDiffContainer(
382 381 self.diff_limit, self.cur_diff_size, _diff)
383 382
384 383 exceeds_limit = len(raw_diff) > self.file_limit
385 384 limited_diff = True
386 385 chunks = []
387 386
388 387 else: # GIT format binary patch, or possibly empty diff
389 388 if head['bin_patch']:
390 389 # we have the operation already extracted; we simply mark
391 390 # that it's a diff we won't show for binary files
392 391 stats['ops'][BIN_FILENODE] = 'binary diff hidden'
393 392 chunks = []
394 393
395 394 if chunks and not self.show_full_diff and op == OPS.DEL:
396 395 # if not full diff mode show deleted file contents
397 396 # TODO: anderson: if the view is not too big, there is no way
398 397 # to see the content of the file
399 398 chunks = []
400 399
401 400 chunks.insert(0, [{
402 401 'old_lineno': '',
403 402 'new_lineno': '',
404 403 'action': Action.CONTEXT,
405 404 'line': msg,
406 405 } for _op, msg in stats['ops'].items()
407 406 if _op not in [MOD_FILENODE]])
408 407
409 408 _files.append({
410 409 'filename': safe_unicode(head['b_path']),
411 410 'old_revision': head['a_blob_id'],
412 411 'new_revision': head['b_blob_id'],
413 412 'chunks': chunks,
414 413 'raw_diff': safe_unicode(raw_diff),
415 414 'operation': op,
416 415 'stats': stats,
417 416 'exceeds_limit': exceeds_limit,
418 417 'is_limited_diff': limited_diff,
419 418 })
420 419
421 420 sorter = lambda info: {OPS.ADD: 0, OPS.MOD: 1,
422 421 OPS.DEL: 2}.get(info['operation'])
423 422
424 423 if not inline_diff:
425 424 return diff_container(sorted(_files, key=sorter))
426 425
427 426 # highlight inline changes
428 427 for diff_data in _files:
429 428 for chunk in diff_data['chunks']:
430 429 lineiter = iter(chunk)
431 430 try:
432 431 while 1:
433 432 line = next(lineiter)
434 433 if line['action'] not in (
435 434 Action.UNMODIFIED, Action.CONTEXT):
436 435 nextline = next(lineiter)
437 436 if nextline['action'] in ['unmod', 'context'] or \
438 437 nextline['action'] == line['action']:
439 438 continue
440 439 self.differ(line, nextline)
441 440 except StopIteration:
442 441 pass
443 442
444 443 return diff_container(sorted(_files, key=sorter))
445 444
446 445 def _check_large_diff(self):
447 446 if self.diff_limit:
448 447 log.debug('Checking if diff exceeds current diff_limit of %s', self.diff_limit)
449 448 if not self.show_full_diff and (self.cur_diff_size > self.diff_limit):
450 449 raise DiffLimitExceeded('Diff Limit `%s` Exceeded' % self.diff_limit)
451 450
452 451 # FIXME: NEWDIFFS: dan: this replaces _parse_gitdiff
453 452 def _new_parse_gitdiff(self, inline_diff=True):
454 453 _files = []
455 454
456 455 # this can be overridden later to a LimitedDiffContainer type
457 456 diff_container = lambda arg: arg
458 457
459 458 for chunk in self._diff.chunks():
460 459 head = chunk.header
461 460 log.debug('parsing diff %r', head)
462 461
463 462 raw_diff = chunk.raw
464 463 limited_diff = False
465 464 exceeds_limit = False
466 465
467 466 op = None
468 467 stats = {
469 468 'added': 0,
470 469 'deleted': 0,
471 470 'binary': False,
472 471 'old_mode': None,
473 472 'new_mode': None,
474 473 'ops': {},
475 474 }
476 475 if head['old_mode']:
477 476 stats['old_mode'] = head['old_mode']
478 477 if head['new_mode']:
479 478 stats['new_mode'] = head['new_mode']
480 479 if head['b_mode']:
481 480 stats['new_mode'] = head['b_mode']
482 481
483 482 # delete file
484 483 if head['deleted_file_mode']:
485 484 op = OPS.DEL
486 485 stats['binary'] = True
487 486 stats['ops'][DEL_FILENODE] = 'deleted file'
488 487
489 488 # new file
490 489 elif head['new_file_mode']:
491 490 op = OPS.ADD
492 491 stats['binary'] = True
493 492 stats['old_mode'] = None
494 493 stats['new_mode'] = head['new_file_mode']
495 494 stats['ops'][NEW_FILENODE] = 'new file %s' % head['new_file_mode']
496 495
497 496 # modify operation, can be copy, rename or chmod
498 497 else:
499 498 # CHMOD
500 499 if head['new_mode'] and head['old_mode']:
501 500 op = OPS.MOD
502 501 stats['binary'] = True
503 502 stats['ops'][CHMOD_FILENODE] = (
504 503 'modified file chmod %s => %s' % (
505 504 head['old_mode'], head['new_mode']))
506 505
507 506 # RENAME
508 507 if head['rename_from'] != head['rename_to']:
509 508 op = OPS.MOD
510 509 stats['binary'] = True
511 510 stats['renamed'] = (head['rename_from'], head['rename_to'])
512 511 stats['ops'][RENAMED_FILENODE] = (
513 512 'file renamed from %s to %s' % (
514 513 head['rename_from'], head['rename_to']))
515 514 # COPY
516 515 if head.get('copy_from') and head.get('copy_to'):
517 516 op = OPS.MOD
518 517 stats['binary'] = True
519 518 stats['copied'] = (head['copy_from'], head['copy_to'])
520 519 stats['ops'][COPIED_FILENODE] = (
521 520 'file copied from %s to %s' % (
522 521 head['copy_from'], head['copy_to']))
523 522
524 523 # If our new parsed headers didn't match anything fallback to
525 524 # old style detection
526 525 if op is None:
527 526 if not head['a_file'] and head['b_file']:
528 527 op = OPS.ADD
529 528 stats['binary'] = True
530 529 stats['new_file'] = True
531 530 stats['ops'][NEW_FILENODE] = 'new file'
532 531
533 532 elif head['a_file'] and not head['b_file']:
534 533 op = OPS.DEL
535 534 stats['binary'] = True
536 535 stats['ops'][DEL_FILENODE] = 'deleted file'
537 536
538 537 # it's neither ADD nor DELETE
539 538 if op is None:
540 539 op = OPS.MOD
541 540 stats['binary'] = True
542 541 stats['ops'][MOD_FILENODE] = 'modified file'
543 542
544 543 # a real non-binary diff
545 544 if head['a_file'] or head['b_file']:
546 545 # simulate splitlines, so we keep the line end part
547 546 diff = self.diff_splitter(chunk.diff)
548 547
549 548 # add each file's size to the cumulative diff size
550 549 raw_chunk_size = len(raw_diff)
551 550
552 551 exceeds_limit = raw_chunk_size > self.file_limit
553 552 self.cur_diff_size += raw_chunk_size
554 553
555 554 try:
556 555 # Check each file instead of the whole diff.
557 556 # Diff will hide big files but still show small ones.
558 557 # From the tests big files are fairly safe to be parsed
559 558 # but the browser is the bottleneck.
560 559 if not self.show_full_diff and exceeds_limit:
561 560 log.debug('File `%s` exceeds current file_limit of %s',
562 561 safe_unicode(head['b_path']), self.file_limit)
563 562 raise DiffLimitExceeded(
564 563 'File Limit %s Exceeded' % self.file_limit)
565 564
566 565 self._check_large_diff()
567 566
568 567 raw_diff, chunks, _stats = self._new_parse_lines(diff)
569 568 stats['binary'] = False
570 569 stats['added'] = _stats[0]
571 570 stats['deleted'] = _stats[1]
572 571 # explicitly mark that it's a modified file
573 572 if op == OPS.MOD:
574 573 stats['ops'][MOD_FILENODE] = 'modified file'
575 574
576 575 except DiffLimitExceeded:
577 576 diff_container = lambda _diff: \
578 577 LimitedDiffContainer(
579 578 self.diff_limit, self.cur_diff_size, _diff)
580 579
581 580 limited_diff = True
582 581 chunks = []
583 582
584 583 else: # GIT format binary patch, or possibly empty diff
585 584 if head['bin_patch']:
586 585 # we have the operation already extracted; we simply mark
587 586 # that it's a diff we won't show for binary files
588 587 stats['ops'][BIN_FILENODE] = 'binary diff hidden'
589 588 chunks = []
590 589
591 590 # Hide content of deleted node by setting empty chunks
592 591 if chunks and not self.show_full_diff and op == OPS.DEL:
593 592 # if not full diff mode show deleted file contents
594 593 # TODO: anderson: if the view is not too big, there is no way
595 594 # to see the content of the file
596 595 chunks = []
597 596
598 597 chunks.insert(
599 598 0, [{'old_lineno': '',
600 599 'new_lineno': '',
601 600 'action': Action.CONTEXT,
602 601 'line': msg,
603 602 } for _op, msg in stats['ops'].items()
604 603 if _op not in [MOD_FILENODE]])
605 604
606 605 original_filename = safe_unicode(head['a_path'])
607 606 _files.append({
608 607 'original_filename': original_filename,
609 608 'filename': safe_unicode(head['b_path']),
610 609 'old_revision': head['a_blob_id'],
611 610 'new_revision': head['b_blob_id'],
612 611 'chunks': chunks,
613 612 'raw_diff': safe_unicode(raw_diff),
614 613 'operation': op,
615 614 'stats': stats,
616 615 'exceeds_limit': exceeds_limit,
617 616 'is_limited_diff': limited_diff,
618 617 })
619 618
620 619 sorter = lambda info: {OPS.ADD: 0, OPS.MOD: 1,
621 620 OPS.DEL: 2}.get(info['operation'])
622 621
623 622 return diff_container(sorted(_files, key=sorter))
624 623
625 624 # FIXME: NEWDIFFS: dan: this gets replaced by _new_parse_lines
626 625 def _parse_lines(self, diff_iter):
627 626 """
628 627 Parse the diff and return data for the template.
629 628 """
630 629
631 630 stats = [0, 0]
632 631 chunks = []
633 632 raw_diff = []
634 633
635 634 try:
636 635 line = next(diff_iter)
637 636
638 637 while line:
639 638 raw_diff.append(line)
640 639 lines = []
641 640 chunks.append(lines)
642 641
643 642 match = self._chunk_re.match(line)
644 643
645 644 if not match:
646 645 break
647 646
648 647 gr = match.groups()
649 648 (old_line, old_end,
650 649 new_line, new_end) = [int(x or 1) for x in gr[:-1]]
651 650 old_line -= 1
652 651 new_line -= 1
653 652
654 653 context = len(gr) == 5
655 654 old_end += old_line
656 655 new_end += new_line
657 656
658 657 if context:
659 658 # add the context marker unless the hunk starts at the first line
660 659 if int(gr[0]) > 1:
661 660 lines.append({
662 661 'old_lineno': '...',
663 662 'new_lineno': '...',
664 663 'action': Action.CONTEXT,
665 664 'line': line,
666 665 })
667 666
668 667 line = next(diff_iter)
669 668
670 669 while old_line < old_end or new_line < new_end:
671 670 command = ' '
672 671 if line:
673 672 command = line[0]
674 673
675 674 affects_old = affects_new = False
676 675
677 676 # ignore those if we don't expect them
678 677 if command in '#@':
679 678 continue
680 679 elif command == '+':
681 680 affects_new = True
682 681 action = Action.ADD
683 682 stats[0] += 1
684 683 elif command == '-':
685 684 affects_old = True
686 685 action = Action.DELETE
687 686 stats[1] += 1
688 687 else:
689 688 affects_old = affects_new = True
690 689 action = Action.UNMODIFIED
691 690
692 691 if not self._newline_marker.match(line):
693 692 old_line += affects_old
694 693 new_line += affects_new
695 694 lines.append({
696 695 'old_lineno': affects_old and old_line or '',
697 696 'new_lineno': affects_new and new_line or '',
698 697 'action': action,
699 698 'line': self._clean_line(line, command)
700 699 })
701 700 raw_diff.append(line)
702 701
703 702 line = next(diff_iter)
704 703
705 704 if self._newline_marker.match(line):
706 705 # we need to append to lines, since this is not
707 706 # counted in the line specs of diff
708 707 lines.append({
709 708 'old_lineno': '...',
710 709 'new_lineno': '...',
711 710 'action': Action.CONTEXT,
712 711 'line': self._clean_line(line, command)
713 712 })
714 713
715 714 except StopIteration:
716 715 pass
717 716 return ''.join(raw_diff), chunks, stats
718 717
719 718 # FIXME: NEWDIFFS: dan: this replaces _parse_lines
720 719 def _new_parse_lines(self, diff_iter):
721 720 """
722 721 Parse the diff and return data for the template.
723 722 """
724 723
725 724 stats = [0, 0]
726 725 chunks = []
727 726 raw_diff = []
728 727
729 728 try:
730 729 line = next(diff_iter)
731 730
732 731 while line:
733 732 raw_diff.append(line)
734 733 # match a hunk header, e.g. '@@ -0,0 +1 @@\n'
735 734 match = self._chunk_re.match(line)
736 735
737 736 if not match:
738 737 break
739 738
740 739 gr = match.groups()
741 740 (old_line, old_end,
742 741 new_line, new_end) = [int(x or 1) for x in gr[:-1]]
743 742
744 743 lines = []
745 744 hunk = {
746 745 'section_header': gr[-1],
747 746 'source_start': old_line,
748 747 'source_length': old_end,
749 748 'target_start': new_line,
750 749 'target_length': new_end,
751 750 'lines': lines,
752 751 }
753 752 chunks.append(hunk)
754 753
755 754 old_line -= 1
756 755 new_line -= 1
757 756
758 757 context = len(gr) == 5
759 758 old_end += old_line
760 759 new_end += new_line
761 760
762 761 line = next(diff_iter)
763 762
764 763 while old_line < old_end or new_line < new_end:
765 764 command = ' '
766 765 if line:
767 766 command = line[0]
768 767
769 768 affects_old = affects_new = False
770 769
771 770 # ignore those if we don't expect them
772 771 if command in '#@':
773 772 continue
774 773 elif command == '+':
775 774 affects_new = True
776 775 action = Action.ADD
777 776 stats[0] += 1
778 777 elif command == '-':
779 778 affects_old = True
780 779 action = Action.DELETE
781 780 stats[1] += 1
782 781 else:
783 782 affects_old = affects_new = True
784 783 action = Action.UNMODIFIED
785 784
786 785 if not self._newline_marker.match(line):
787 786 old_line += affects_old
788 787 new_line += affects_new
789 788 lines.append({
790 789 'old_lineno': affects_old and old_line or '',
791 790 'new_lineno': affects_new and new_line or '',
792 791 'action': action,
793 792 'line': self._clean_line(line, command)
794 793 })
795 794 raw_diff.append(line)
796 795
797 796 line = next(diff_iter)
798 797
799 798 if self._newline_marker.match(line):
800 799 # we need to append to lines, since this is not
801 800 # counted in the line specs of diff
802 801 if affects_old:
803 802 action = Action.OLD_NO_NL
804 803 elif affects_new:
805 804 action = Action.NEW_NO_NL
806 805 else:
807 806 raise Exception('invalid context for no newline')
808 807
809 808 lines.append({
810 809 'old_lineno': None,
811 810 'new_lineno': None,
812 811 'action': action,
813 812 'line': self._clean_line(line, command)
814 813 })
815 814
816 815 except StopIteration:
817 816 pass
818 817
819 818 return ''.join(raw_diff), chunks, stats
820 819
821 820 def _safe_id(self, idstring):
822 821 """Make a string safe for including in an id attribute.
823 822
824 823 The HTML spec says that id attributes 'must begin with
825 824 a letter ([A-Za-z]) and may be followed by any number
826 825 of letters, digits ([0-9]), hyphens ("-"), underscores
827 826 ("_"), colons (":"), and periods (".")'. These regexps
828 827 are slightly over-zealous, in that they remove colons
829 828 and periods unnecessarily.
830 829
831 830 Whitespace is transformed into underscores, and then
832 831 anything which is not a hyphen or a character that
833 832 matches \w (alphanumerics and underscore) is removed.
834 833
835 834 """
836 835 # Transform all whitespace to underscore
837 836 idstring = re.sub(r'\s', "_", '%s' % idstring)
838 837 # Remove everything that is not a hyphen or a member of \w
839 838 idstring = re.sub(r'(?!-)\W', "", idstring).lower()
840 839 return idstring
841 840
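A small hedged example of _safe_id: whitespace becomes underscores, anything outside [\w-] (here the dot) is dropped, and the result is lower-cased.

assert DiffProcessor(diff=None)._safe_id('My File.txt') == 'my_filetxt'
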
842 841 @classmethod
843 842 def diff_splitter(cls, string):
844 843 """
845 844 Diff split that emulates .splitlines() but works only on \n
846 845 """
847 846 if not string:
848 847 return
849 848 elif string == '\n':
850 yield u'\n'
849 yield '\n'
851 850 else:
852 851
853 852 has_newline = string.endswith('\n')
854 853 elements = string.split('\n')
855 854 if has_newline:
856 855 # skip the last element as it's an empty string from the trailing newline
857 856 elements = elements[:-1]
858 857
859 858 len_elements = len(elements)
860 859
861 860 for cnt, line in enumerate(elements, start=1):
862 861 last_line = cnt == len_elements
863 862 if last_line and not has_newline:
864 863 yield safe_unicode(line)
865 864 else:
866 865 yield safe_unicode(line) + '\n'
867 866
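A short sketch of diff_splitter, assuming safe_unicode is a pass-through for str input: unlike str.splitlines(), the trailing newline is preserved on each element.

assert list(DiffProcessor.diff_splitter('a\nb\n')) == ['a\n', 'b\n']
assert list(DiffProcessor.diff_splitter('a\nb')) == ['a\n', 'b']
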
868 867 def prepare(self, inline_diff=True):
869 868 """
870 869 Prepare the passed udiff for HTML rendering.
871 870
872 871 :return: A list of dicts with diff information.
873 872 """
874 873 parsed = self._parser(inline_diff=inline_diff)
875 874 self.parsed = True
876 875 self.parsed_diff = parsed
877 876 return parsed
878 877
879 878 def as_raw(self, diff_lines=None):
880 879 """
881 880 Returns raw diff as a byte string
882 881 """
883 882 return self._diff.raw
884 883
885 884 def as_html(self, table_class='code-difftable', line_class='line',
886 885 old_lineno_class='lineno old', new_lineno_class='lineno new',
887 886 code_class='code', enable_comments=False, parsed_lines=None):
888 887 """
889 888 Return given diff as html table with customized css classes
890 889 """
891 890 # TODO(marcink): not sure how to pass in translator
892 891 # here in an efficient way, leave the _ for proper gettext extraction
893 892 _ = lambda s: s
894 893
895 894 def _link_to_if(condition, label, url):
896 895 """
897 896 Generates a link if the condition is met, or just the label if not.
898 897 """
899 898
900 899 if condition:
901 900 return '''<a href="%(url)s" class="tooltip"
902 901 title="%(title)s">%(label)s</a>''' % {
903 902 'title': _('Click to select line'),
904 903 'url': url,
905 904 'label': label
906 905 }
907 906 else:
908 907 return label
909 908 if not self.parsed:
910 909 self.prepare()
911 910
912 911 diff_lines = self.parsed_diff
913 912 if parsed_lines:
914 913 diff_lines = parsed_lines
915 914
916 915 _html_empty = True
917 916 _html = []
918 917 _html.append('''<table class="%(table_class)s">\n''' % {
919 918 'table_class': table_class
920 919 })
921 920
922 921 for diff in diff_lines:
923 922 for line in diff['chunks']:
924 923 _html_empty = False
925 924 for change in line:
926 925 _html.append('''<tr class="%(lc)s %(action)s">\n''' % {
927 926 'lc': line_class,
928 927 'action': change['action']
929 928 })
930 929 anchor_old_id = ''
931 930 anchor_new_id = ''
932 931 anchor_old = "%(filename)s_o%(oldline_no)s" % {
933 932 'filename': self._safe_id(diff['filename']),
934 933 'oldline_no': change['old_lineno']
935 934 }
936 935 anchor_new = "%(filename)s_n%(oldline_no)s" % {
937 936 'filename': self._safe_id(diff['filename']),
938 937 'oldline_no': change['new_lineno']
939 938 }
940 939 cond_old = (change['old_lineno'] != '...' and
941 940 change['old_lineno'])
942 941 cond_new = (change['new_lineno'] != '...' and
943 942 change['new_lineno'])
944 943 if cond_old:
945 944 anchor_old_id = 'id="%s"' % anchor_old
946 945 if cond_new:
947 946 anchor_new_id = 'id="%s"' % anchor_new
948 947
949 948 if change['action'] != Action.CONTEXT:
950 949 anchor_link = True
951 950 else:
952 951 anchor_link = False
953 952
954 953 ###########################################################
955 954 # COMMENT ICONS
956 955 ###########################################################
957 956 _html.append('''\t<td class="add-comment-line"><span class="add-comment-content">''')
958 957
959 958 if enable_comments and change['action'] != Action.CONTEXT:
960 959 _html.append('''<a href="#"><span class="icon-comment-add"></span></a>''')
961 960
962 961 _html.append('''</span></td><td class="comment-toggle tooltip" title="Toggle Comment Thread"><i class="icon-comment"></i></td>\n''')
963 962
964 963 ###########################################################
965 964 # OLD LINE NUMBER
966 965 ###########################################################
967 966 _html.append('''\t<td %(a_id)s class="%(olc)s">''' % {
968 967 'a_id': anchor_old_id,
969 968 'olc': old_lineno_class
970 969 })
971 970
972 971 _html.append('''%(link)s''' % {
973 972 'link': _link_to_if(anchor_link, change['old_lineno'],
974 973 '#%s' % anchor_old)
975 974 })
976 975 _html.append('''</td>\n''')
977 976 ###########################################################
978 977 # NEW LINE NUMBER
979 978 ###########################################################
980 979
981 980 _html.append('''\t<td %(a_id)s class="%(nlc)s">''' % {
982 981 'a_id': anchor_new_id,
983 982 'nlc': new_lineno_class
984 983 })
985 984
986 985 _html.append('''%(link)s''' % {
987 986 'link': _link_to_if(anchor_link, change['new_lineno'],
988 987 '#%s' % anchor_new)
989 988 })
990 989 _html.append('''</td>\n''')
991 990 ###########################################################
992 991 # CODE
993 992 ###########################################################
994 993 code_classes = [code_class]
995 994 if (not enable_comments or
996 995 change['action'] == Action.CONTEXT):
997 996 code_classes.append('no-comment')
998 997 _html.append('\t<td class="%s">' % ' '.join(code_classes))
999 998 _html.append('''\n\t\t<pre>%(code)s</pre>\n''' % {
1000 999 'code': change['line']
1001 1000 })
1002 1001
1003 1002 _html.append('''\t</td>''')
1004 1003 _html.append('''\n</tr>\n''')
1005 1004 _html.append('''</table>''')
1006 1005 if _html_empty:
1007 1006 return None
1008 1007 return ''.join(_html)
1009 1008
1010 1009 def stat(self):
1011 1010 """
1012 1011 Returns tuple of added, and removed lines for this instance
1013 1012 """
1014 1013 return self.adds, self.removes
1015 1014
1016 1015 def get_context_of_line(
1017 1016 self, path, diff_line=None, context_before=3, context_after=3):
1018 1017 """
1019 1018 Returns the context lines for the specified diff line.
1020 1019
1021 1020 :type diff_line: :class:`DiffLineNumber`
1022 1021 """
1023 1022 assert self.parsed, "DiffProcessor is not initialized."
1024 1023
1025 1024 if None not in diff_line:
1026 1025 raise ValueError(
1027 1026 "Cannot specify both line numbers: {}".format(diff_line))
1028 1027
1029 1028 file_diff = self._get_file_diff(path)
1030 1029 chunk, idx = self._find_chunk_line_index(file_diff, diff_line)
1031 1030
1032 1031 first_line_to_include = max(idx - context_before, 0)
1033 1032 first_line_after_context = idx + context_after + 1
1034 1033 context_lines = chunk[first_line_to_include:first_line_after_context]
1035 1034
1036 1035 line_contents = [
1037 1036 _context_line(line) for line in context_lines
1038 1037 if _is_diff_content(line)]
1039 1038 # TODO: johbo: Interim fixup, the diff chunks drop the final newline.
1040 1039 # Once they are fixed, we can drop this line here.
1041 1040 if line_contents:
1042 1041 line_contents[-1] = (
1043 1042 line_contents[-1][0], line_contents[-1][1].rstrip('\n') + '\n')
1044 1043 return line_contents
1045 1044
1046 1045 def find_context(self, path, context, offset=0):
1047 1046 """
1048 1047 Finds the given `context` inside of the diff.
1049 1048
1050 1049 Use the parameter `offset` to specify which offset the target line has
1051 1050 inside of the given `context`. This way the correct diff line will be
1052 1051 returned.
1053 1052
1054 1053 :param offset: Shall be used to specify the offset of the main line
1055 1054 within the given `context`.
1056 1055 """
1057 1056 if offset < 0 or offset >= len(context):
1058 1057 raise ValueError(
1059 1058 "Only positive values up to the length of the context "
1060 1059 "minus one are allowed.")
1061 1060
1062 1061 matches = []
1063 1062 file_diff = self._get_file_diff(path)
1064 1063
1065 1064 for chunk in file_diff['chunks']:
1066 1065 context_iter = iter(context)
1067 1066 for line_idx, line in enumerate(chunk):
1068 1067 try:
1069 1068 if _context_line(line) == next(context_iter):
1070 1069 continue
1071 1070 except StopIteration:
1072 1071 matches.append((line_idx, chunk))
1073 1072 context_iter = iter(context)
1074 1073
1075 1074 # Increment position and trigger StopIteration
1076 1075 # if we had a match at the end
1077 1076 line_idx += 1
1078 1077 try:
1079 1078 next(context_iter)
1080 1079 except StopIteration:
1081 1080 matches.append((line_idx, chunk))
1082 1081
1083 1082 effective_offset = len(context) - offset
1084 1083 found_at_diff_lines = [
1085 1084 _line_to_diff_line_number(chunk[idx - effective_offset])
1086 1085 for idx, chunk in matches]
1087 1086
1088 1087 return found_at_diff_lines
1089 1088
1090 1089 def _get_file_diff(self, path):
1091 1090 for file_diff in self.parsed_diff:
1092 1091 if file_diff['filename'] == path:
1093 1092 break
1094 1093 else:
1095 1094 raise FileNotInDiffException("File {} not in diff".format(path))
1096 1095 return file_diff
1097 1096
1098 1097 def _find_chunk_line_index(self, file_diff, diff_line):
1099 1098 for chunk in file_diff['chunks']:
1100 1099 for idx, line in enumerate(chunk):
1101 1100 if line['old_lineno'] == diff_line.old:
1102 1101 return chunk, idx
1103 1102 if line['new_lineno'] == diff_line.new:
1104 1103 return chunk, idx
1105 1104 raise LineNotInDiffException(
1106 1105 "The line {} is not part of the diff.".format(diff_line))
1107 1106
1108 1107
1109 1108 def _is_diff_content(line):
1110 1109 return line['action'] in (
1111 1110 Action.UNMODIFIED, Action.ADD, Action.DELETE)
1112 1111
1113 1112
1114 1113 def _context_line(line):
1115 1114 return (line['action'], line['line'])
1116 1115
1117 1116
1118 1117 DiffLineNumber = collections.namedtuple('DiffLineNumber', ['old', 'new'])
1119 1118
1120 1119
1121 1120 def _line_to_diff_line_number(line):
1122 1121 new_line_no = line['new_lineno'] or None
1123 1122 old_line_no = line['old_lineno'] or None
1124 1123 return DiffLineNumber(old=old_line_no, new=new_line_no)
1125 1124
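A hedged example of these helpers: the parser emits '' for a missing line number, which _line_to_diff_line_number normalizes to None.

deleted = {'old_lineno': 10, 'new_lineno': '', 'action': Action.DELETE, 'line': 'x'}
assert _is_diff_content(deleted)
assert _line_to_diff_line_number(deleted) == DiffLineNumber(old=10, new=None)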
1126 1125
1127 1126 class FileNotInDiffException(Exception):
1128 1127 """
1129 1128 Raised when the context for a missing file is requested.
1130 1129
1131 1130 If you request the context for a line in a file which is not part of the
1132 1131 given diff, then this exception is raised.
1133 1132 """
1134 1133
1135 1134
1136 1135 class LineNotInDiffException(Exception):
1137 1136 """
1138 1137 Raised when the context for a missing line is requested.
1139 1138
1140 1139 If you request the context for a line in a file and this line is not
1141 1140 part of the given diff, then this exception is raised.
1142 1141 """
1143 1142
1144 1143
1145 1144 class DiffLimitExceeded(Exception):
1146 1145 pass
1147 1146
1148 1147
1149 1148 # NOTE(marcink): if diffs.mako change, probably this
1150 1149 # needs a bump to next version
1151 1150 CURRENT_DIFF_VERSION = 'v5'
1152 1151
1153 1152
1154 1153 def _cleanup_cache_file(cached_diff_file):
1156 1155 # clean up the file so we don't store it "damaged"
1156 1155 try:
1157 1156 os.remove(cached_diff_file)
1158 1157 except Exception:
1159 1158 log.exception('Failed to cleanup path %s', cached_diff_file)
1160 1159
1161 1160
1162 1161 def _get_compression_mode(cached_diff_file):
1163 1162 mode = 'bz2'
1164 1163 if 'mode:plain' in cached_diff_file:
1165 1164 mode = 'plain'
1166 1165 elif 'mode:gzip' in cached_diff_file:
1167 1166 mode = 'gzip'
1168 1167 return mode
1169 1168
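A minimal sketch: the compression mode is inferred from a marker embedded in the cache file name, falling back to bz2 when no marker is present.

assert _get_compression_mode('diff_v5_mode:gzip_cachefile') == 'gzip'
assert _get_compression_mode('diff_v5_mode:plain_cachefile') == 'plain'
assert _get_compression_mode('diff_v5_cachefile') == 'bz2'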
1170 1169
1171 1170 def cache_diff(cached_diff_file, diff, commits):
1172 1171 compression_mode = _get_compression_mode(cached_diff_file)
1173 1172
1174 1173 struct = {
1175 1174 'version': CURRENT_DIFF_VERSION,
1176 1175 'diff': diff,
1177 1176 'commits': commits
1178 1177 }
1179 1178
1180 1179 start = time.time()
1181 1180 try:
1182 1181 if compression_mode == 'plain':
1183 1182 with open(cached_diff_file, 'wb') as f:
1184 1183 pickle.dump(struct, f)
1185 1184 elif compression_mode == 'gzip':
1186 1185 with gzip.GzipFile(cached_diff_file, 'wb') as f:
1187 1186 pickle.dump(struct, f)
1188 1187 else:
1189 1188 with bz2.BZ2File(cached_diff_file, 'wb') as f:
1190 1189 pickle.dump(struct, f)
1191 1190 except Exception:
1192 1191 log.warning('Failed to save cache', exc_info=True)
1193 1192 _cleanup_cache_file(cached_diff_file)
1194 1193
1195 1194 log.debug('Saved diff cache under %s in %.4fs', cached_diff_file, time.time() - start)
1196 1195
1197 1196
1198 1197 def load_cached_diff(cached_diff_file):
1199 1198 compression_mode = _get_compression_mode(cached_diff_file)
1200 1199
1201 1200 default_struct = {
1202 1201 'version': CURRENT_DIFF_VERSION,
1203 1202 'diff': None,
1204 1203 'commits': None
1205 1204 }
1206 1205
1207 1206 has_cache = os.path.isfile(cached_diff_file)
1208 1207 if not has_cache:
1209 1208 log.debug('Diff cache file %s does not exist', cached_diff_file)
1210 1209 return default_struct
1211 1210
1212 1211 data = None
1213 1212
1214 1213 start = time.time()
1215 1214 try:
1216 1215 if compression_mode == 'plain':
1217 1216 with open(cached_diff_file, 'rb') as f:
1218 1217 data = pickle.load(f)
1219 1218 elif compression_mode == 'gzip':
1220 1219 with gzip.GzipFile(cached_diff_file, 'rb') as f:
1221 1220 data = pickle.load(f)
1222 1221 else:
1223 1222 with bz2.BZ2File(cached_diff_file, 'rb') as f:
1224 1223 data = pickle.load(f)
1225 1224 except Exception:
1226 1225 log.warning('Failed to read diff cache file', exc_info=True)
1227 1226
1228 1227 if not data:
1229 1228 data = default_struct
1230 1229
1231 1230 if not isinstance(data, dict):
1232 1231 # old version of data ?
1233 1232 data = default_struct
1234 1233
1235 1234 # check version
1236 1235 if data.get('version') != CURRENT_DIFF_VERSION:
1237 1236 # purge cache
1238 1237 _cleanup_cache_file(cached_diff_file)
1239 1238 return default_struct
1240 1239
1241 1240 log.debug('Loaded diff cache from %s in %.4fs', cached_diff_file, time.time() - start)
1242 1241
1243 1242 return data
1244 1243
1245 1244
1246 1245 def generate_diff_cache_key(*args):
1247 1246 """
1248 1247 Helper to generate a cache key using arguments
1249 1248 """
1250 1249 def arg_mapper(input_param):
1251 1250 input_param = safe_str(input_param)
1252 1251 # we cannot allow '/' in arguments since it would allow
1253 1252 # subdirectory usage
1254 1253 input_param = input_param.replace('/', '_')
1255 1254 return input_param or None # prevent empty string arguments
1256 1255
1257 1256 return '_'.join([
1258 1257 '{}' for i in range(len(args))]).format(*map(arg_mapper, args))
1259 1258
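A hedged usage sketch: every argument is stringified, '/' is replaced with '_' to block subdirectory tricks, and the parts are joined with underscores.

assert generate_diff_cache_key('repo', 'abc123', 'def456') == 'repo_abc123_def456'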
1260 1259
1261 1260 def diff_cache_exist(cache_storage, *args):
1262 1261 """
1263 1262 Based on all generated arguments check and return a cache path
1264 1263 """
1265 1264 args = list(args) + ['mode:gzip']
1266 1265 cache_key = generate_diff_cache_key(*args)
1267 1266 cache_file_path = os.path.join(cache_storage, cache_key)
1268 1267 # prevent path traversal attacks via params that contain e.g. '../../'
1269 1268 if not os.path.abspath(cache_file_path).startswith(cache_storage):
1270 1269 raise ValueError('Final path must be within {}'.format(cache_storage))
1271 1270
1272 1271 return cache_file_path
@@ -1,444 +1,444 b''
1 1 # Copyright (c) Django Software Foundation and individual contributors.
2 2 # All rights reserved.
3 3 #
4 4 # Redistribution and use in source and binary forms, with or without modification,
5 5 # are permitted provided that the following conditions are met:
6 6 #
7 7 # 1. Redistributions of source code must retain the above copyright notice,
8 8 # this list of conditions and the following disclaimer.
9 9 #
10 10 # 2. Redistributions in binary form must reproduce the above copyright
11 11 # notice, this list of conditions and the following disclaimer in the
12 12 # documentation and/or other materials provided with the distribution.
13 13 #
14 14 # 3. Neither the name of Django nor the names of its contributors may be used
15 15 # to endorse or promote products derived from this software without
16 16 # specific prior written permission.
17 17 #
18 18 # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
19 19 # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
20 20 # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21 21 # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
22 22 # ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
23 23 # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
24 24 # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
25 25 # ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
26 26 # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
27 27 # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 28
29 29 """
30 30 For definitions of the different versions of RSS, see:
31 31 http://web.archive.org/web/20110718035220/http://diveintomark.org/archives/2004/02/04/incompatible-rss
32 32 """
33 33
34 34
35 35 import datetime
36 from io import StringIO
36 import io
37 37
38 38 import pytz
39 39 from six.moves.urllib import parse as urlparse
40 40
41 41 from rhodecode.lib.feedgenerator import datetime_safe
42 42 from rhodecode.lib.feedgenerator.utils import SimplerXMLGenerator, iri_to_uri, force_text
43 43
44 44
45 45 #### The following code comes from ``django.utils.feedgenerator`` ####
46 46
47 47
48 48 def rfc2822_date(date):
49 49 # We can't use strftime() because it produces locale-dependent results, so
50 50 # we have to map english month and day names manually
51 51 months = ('Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun', 'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec',)
52 52 days = ('Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun')
53 53 # Support datetime objects older than 1900
54 54 date = datetime_safe.new_datetime(date)
55 55 # We do this ourselves to be timezone aware, email.Utils is not tz aware.
56 56 dow = days[date.weekday()]
57 57 month = months[date.month - 1]
58 58 time_str = date.strftime('%s, %%d %s %%Y %%H:%%M:%%S ' % (dow, month))
59 59
60 60 offset = date.utcoffset()
61 61 # Historically, this function assumes that naive datetimes are in UTC.
62 62 if offset is None:
63 63 return time_str + '-0000'
64 64 else:
65 65 timezone = (offset.days * 24 * 60) + (offset.seconds // 60)
66 66 hour, minute = divmod(timezone, 60)
67 67 return time_str + '%+03d%02d' % (hour, minute)
68 68
69 69
70 70 def rfc3339_date(date):
71 71 # Support datetime objects older than 1900
72 72 date = datetime_safe.new_datetime(date)
73 73 time_str = date.strftime('%Y-%m-%dT%H:%M:%S')
74 74
75 75 offset = date.utcoffset()
76 76 # Historically, this function assumes that naive datetimes are in UTC.
77 77 if offset is None:
78 78 return time_str + 'Z'
79 79 else:
80 80 timezone = (offset.days * 24 * 60) + (offset.seconds // 60)
81 81 hour, minute = divmod(timezone, 60)
82 82 return time_str + '%+03d:%02d' % (hour, minute)
83 83
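A hedged usage sketch of the two date helpers, assuming datetime_safe.new_datetime is a pass-through for post-1900 dates; naive datetimes are treated as UTC by both.

d = datetime.datetime(2020, 1, 2, 3, 4, 5)
assert rfc3339_date(d) == '2020-01-02T03:04:05Z'
assert rfc2822_date(d) == 'Thu, 02 Jan 2020 03:04:05 -0000'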
84 84
85 85 def get_tag_uri(url, date):
86 86 """
87 87 Creates a TagURI.
88 88
89 89 See http://web.archive.org/web/20110514113830/http://diveintomark.org/archives/2004/05/28/howto-atom-id
90 90 """
91 91 bits = urlparse.urlparse(url)
92 92 d = ''
93 93 if date is not None:
94 94 d = ',%s' % datetime_safe.new_datetime(date).strftime('%Y-%m-%d')
95 95 return 'tag:%s%s:%s/%s' % (bits.hostname, d, bits.path, bits.fragment)
96 96
97 97
98 98 class SyndicationFeed(object):
99 99 """Base class for all syndication feeds. Subclasses should provide write()"""
100 100
101 101 def __init__(self, title, link, description, language=None, author_email=None,
102 102 author_name=None, author_link=None, subtitle=None, categories=None,
103 103 feed_url=None, feed_copyright=None, feed_guid=None, ttl=None, **kwargs):
104 104 def to_unicode(s):
105 105 return force_text(s, strings_only=True)
106 106 if categories:
107 107 categories = [force_text(c) for c in categories]
108 108 if ttl is not None:
109 109 # Force ints to unicode
110 110 ttl = force_text(ttl)
111 111 self.feed = {
112 112 'title': to_unicode(title),
113 113 'link': iri_to_uri(link),
114 114 'description': to_unicode(description),
115 115 'language': to_unicode(language),
116 116 'author_email': to_unicode(author_email),
117 117 'author_name': to_unicode(author_name),
118 118 'author_link': iri_to_uri(author_link),
119 119 'subtitle': to_unicode(subtitle),
120 120 'categories': categories or (),
121 121 'feed_url': iri_to_uri(feed_url),
122 122 'feed_copyright': to_unicode(feed_copyright),
123 123 'id': feed_guid or link,
124 124 'ttl': ttl,
125 125 }
126 126 self.feed.update(kwargs)
127 127 self.items = []
128 128
129 129 def add_item(self, title, link, description, author_email=None,
130 130 author_name=None, author_link=None, pubdate=None, comments=None,
131 131 unique_id=None, unique_id_is_permalink=None, enclosure=None,
132 132 categories=(), item_copyright=None, ttl=None, updateddate=None,
133 133 enclosures=None, **kwargs):
134 134 """
135 135 Adds an item to the feed. All args are expected to be Python Unicode
136 136 objects except pubdate and updateddate, which are datetime.datetime
137 137 objects, and enclosures, which is an iterable of instances of the
138 138 Enclosure class.
139 139 """
140 140 def to_unicode(s):
141 141 return force_text(s, strings_only=True)
142 142 if categories:
143 143 categories = [to_unicode(c) for c in categories]
144 144 if ttl is not None:
145 145 # Force ints to unicode
146 146 ttl = force_text(ttl)
147 147 if enclosure is None:
148 148 enclosures = [] if enclosures is None else enclosures
149 149
150 150 item = {
151 151 'title': to_unicode(title),
152 152 'link': iri_to_uri(link),
153 153 'description': to_unicode(description),
154 154 'author_email': to_unicode(author_email),
155 155 'author_name': to_unicode(author_name),
156 156 'author_link': iri_to_uri(author_link),
157 157 'pubdate': pubdate,
158 158 'updateddate': updateddate,
159 159 'comments': to_unicode(comments),
160 160 'unique_id': to_unicode(unique_id),
161 161 'unique_id_is_permalink': unique_id_is_permalink,
162 162 'enclosures': enclosures,
163 163 'categories': categories or (),
164 164 'item_copyright': to_unicode(item_copyright),
165 165 'ttl': ttl,
166 166 }
167 167 item.update(kwargs)
168 168 self.items.append(item)
169 169
170 170 def num_items(self):
171 171 return len(self.items)
172 172
173 173 def root_attributes(self):
174 174 """
175 175 Return extra attributes to place on the root (i.e. feed/channel) element.
176 176 Called from write().
177 177 """
178 178 return {}
179 179
180 180 def add_root_elements(self, handler):
181 181 """
182 182 Add elements in the root (i.e. feed/channel) element. Called
183 183 from write().
184 184 """
185 185 pass
186 186
187 187 def item_attributes(self, item):
188 188 """
189 189 Return extra attributes to place on each item (i.e. item/entry) element.
190 190 """
191 191 return {}
192 192
193 193 def add_item_elements(self, handler, item):
194 194 """
195 195 Add elements on each item (i.e. item/entry) element.
196 196 """
197 197 pass
198 198
199 199 def write(self, outfile, encoding):
200 200 """
201 201 Outputs the feed in the given encoding to outfile, which is a file-like
202 202 object. Subclasses should override this.
203 203 """
204 204 raise NotImplementedError('subclasses of SyndicationFeed must provide a write() method')
205 205
206 206 def writeString(self, encoding):
207 207 """
208 208 Returns the feed in the given encoding as a string.
209 209 """
210 s = StringIO()
210 s = io.StringIO()
211 211 self.write(s, encoding)
212 212 return s.getvalue()
213 213
214 214 def latest_post_date(self):
215 215 """
216 216 Returns the latest item's pubdate or updateddate. If no items
217 217 have either of these attributes this returns the current UTC date/time.
218 218 """
219 219 latest_date = None
220 220 date_keys = ('updateddate', 'pubdate')
221 221
222 222 for item in self.items:
223 223 for date_key in date_keys:
224 224 item_date = item.get(date_key)
225 225 if item_date:
226 226 if latest_date is None or item_date > latest_date:
227 227 latest_date = item_date
228 228
229 229 # datetime.now(tz=utc) is slower, as documented in django.utils.timezone.now
230 230 return latest_date or datetime.datetime.utcnow().replace(tzinfo=pytz.utc)
231 231
232 232
233 233 class Enclosure(object):
234 234 """Represents an RSS enclosure"""
235 235 def __init__(self, url, length, mime_type):
236 236 """All args are expected to be Python Unicode objects"""
237 237 self.length, self.mime_type = length, mime_type
238 238 self.url = iri_to_uri(url)
239 239
240 240
241 241 class RssFeed(SyndicationFeed):
242 242 content_type = 'application/rss+xml; charset=utf-8'
243 243
244 244 def write(self, outfile, encoding):
245 245 handler = SimplerXMLGenerator(outfile, encoding)
246 246 handler.startDocument()
247 247 handler.startElement("rss", self.rss_attributes())
248 248 handler.startElement("channel", self.root_attributes())
249 249 self.add_root_elements(handler)
250 250 self.write_items(handler)
251 251 self.endChannelElement(handler)
252 252 handler.endElement("rss")
253 253
254 254 def rss_attributes(self):
255 255 return {"version": self._version,
256 256 "xmlns:atom": "http://www.w3.org/2005/Atom"}
257 257
258 258 def write_items(self, handler):
259 259 for item in self.items:
260 260 handler.startElement('item', self.item_attributes(item))
261 261 self.add_item_elements(handler, item)
262 262 handler.endElement("item")
263 263
264 264 def add_root_elements(self, handler):
265 265 handler.addQuickElement("title", self.feed['title'])
266 266 handler.addQuickElement("link", self.feed['link'])
267 267 handler.addQuickElement("description", self.feed['description'])
268 268 if self.feed['feed_url'] is not None:
269 269 handler.addQuickElement("atom:link", None, {"rel": "self", "href": self.feed['feed_url']})
270 270 if self.feed['language'] is not None:
271 271 handler.addQuickElement("language", self.feed['language'])
272 272 for cat in self.feed['categories']:
273 273 handler.addQuickElement("category", cat)
274 274 if self.feed['feed_copyright'] is not None:
275 275 handler.addQuickElement("copyright", self.feed['feed_copyright'])
276 276 handler.addQuickElement("lastBuildDate", rfc2822_date(self.latest_post_date()))
277 277 if self.feed['ttl'] is not None:
278 278 handler.addQuickElement("ttl", self.feed['ttl'])
279 279
280 280 def endChannelElement(self, handler):
281 281 handler.endElement("channel")
282 282
283 283
284 284 class RssUserland091Feed(RssFeed):
285 285 _version = "0.91"
286 286
287 287 def add_item_elements(self, handler, item):
288 288 handler.addQuickElement("title", item['title'])
289 289 handler.addQuickElement("link", item['link'])
290 290 if item['description'] is not None:
291 291 handler.addQuickElement("description", item['description'])
292 292
293 293
294 294 class Rss201rev2Feed(RssFeed):
295 295 # Spec: http://blogs.law.harvard.edu/tech/rss
296 296 _version = "2.0"
297 297
298 298 def add_item_elements(self, handler, item):
299 299 handler.addQuickElement("title", item['title'])
300 300 handler.addQuickElement("link", item['link'])
301 301 if item['description'] is not None:
302 302 handler.addQuickElement("description", item['description'])
303 303
304 304 # Author information.
305 305 if item["author_name"] and item["author_email"]:
306 306 handler.addQuickElement("author", "%s (%s)" % (item['author_email'], item['author_name']))
307 307 elif item["author_email"]:
308 308 handler.addQuickElement("author", item["author_email"])
309 309 elif item["author_name"]:
310 310 handler.addQuickElement(
311 311 "dc:creator", item["author_name"], {"xmlns:dc": "http://purl.org/dc/elements/1.1/"}
312 312 )
313 313
314 314 if item['pubdate'] is not None:
315 315 handler.addQuickElement("pubDate", rfc2822_date(item['pubdate']))
316 316 if item['comments'] is not None:
317 317 handler.addQuickElement("comments", item['comments'])
318 318 if item['unique_id'] is not None:
319 319 guid_attrs = {}
320 320 if isinstance(item.get('unique_id_is_permalink'), bool):
321 321 guid_attrs['isPermaLink'] = str(item['unique_id_is_permalink']).lower()
322 322 handler.addQuickElement("guid", item['unique_id'], guid_attrs)
323 323 if item['ttl'] is not None:
324 324 handler.addQuickElement("ttl", item['ttl'])
325 325
326 326 # Enclosure.
327 327 if item['enclosures']:
328 328 enclosures = list(item['enclosures'])
329 329 if len(enclosures) > 1:
330 330 raise ValueError(
331 331 "RSS feed items may only have one enclosure, see "
332 332 "http://www.rssboard.org/rss-profile#element-channel-item-enclosure"
333 333 )
334 334 enclosure = enclosures[0]
335 335 handler.addQuickElement('enclosure', '', {
336 336 'url': enclosure.url,
337 337 'length': enclosure.length,
338 338 'type': enclosure.mime_type,
339 339 })
340 340
341 341 # Categories.
342 342 for cat in item['categories']:
343 343 handler.addQuickElement("category", cat)
344 344
345 345
346 346 class Atom1Feed(SyndicationFeed):
347 347 # Spec: https://tools.ietf.org/html/rfc4287
348 348 content_type = 'application/atom+xml; charset=utf-8'
349 349 ns = "http://www.w3.org/2005/Atom"
350 350
351 351 def write(self, outfile, encoding):
352 352 handler = SimplerXMLGenerator(outfile, encoding)
353 353 handler.startDocument()
354 354 handler.startElement('feed', self.root_attributes())
355 355 self.add_root_elements(handler)
356 356 self.write_items(handler)
357 357 handler.endElement("feed")
358 358
359 359 def root_attributes(self):
360 360 if self.feed['language'] is not None:
361 361 return {"xmlns": self.ns, "xml:lang": self.feed['language']}
362 362 else:
363 363 return {"xmlns": self.ns}
364 364
365 365 def add_root_elements(self, handler):
366 366 handler.addQuickElement("title", self.feed['title'])
367 367 handler.addQuickElement("link", "", {"rel": "alternate", "href": self.feed['link']})
368 368 if self.feed['feed_url'] is not None:
369 369 handler.addQuickElement("link", "", {"rel": "self", "href": self.feed['feed_url']})
370 370 handler.addQuickElement("id", self.feed['id'])
371 371 handler.addQuickElement("updated", rfc3339_date(self.latest_post_date()))
372 372 if self.feed['author_name'] is not None:
373 373 handler.startElement("author", {})
374 374 handler.addQuickElement("name", self.feed['author_name'])
375 375 if self.feed['author_email'] is not None:
376 376 handler.addQuickElement("email", self.feed['author_email'])
377 377 if self.feed['author_link'] is not None:
378 378 handler.addQuickElement("uri", self.feed['author_link'])
379 379 handler.endElement("author")
380 380 if self.feed['subtitle'] is not None:
381 381 handler.addQuickElement("subtitle", self.feed['subtitle'])
382 382 for cat in self.feed['categories']:
383 383 handler.addQuickElement("category", "", {"term": cat})
384 384 if self.feed['feed_copyright'] is not None:
385 385 handler.addQuickElement("rights", self.feed['feed_copyright'])
386 386
387 387 def write_items(self, handler):
388 388 for item in self.items:
389 389 handler.startElement("entry", self.item_attributes(item))
390 390 self.add_item_elements(handler, item)
391 391 handler.endElement("entry")
392 392
393 393 def add_item_elements(self, handler, item):
394 394 handler.addQuickElement("title", item['title'])
395 395 handler.addQuickElement("link", "", {"href": item['link'], "rel": "alternate"})
396 396
397 397 if item['pubdate'] is not None:
398 398 handler.addQuickElement('published', rfc3339_date(item['pubdate']))
399 399
400 400 if item['updateddate'] is not None:
401 401 handler.addQuickElement('updated', rfc3339_date(item['updateddate']))
402 402
403 403 # Author information.
404 404 if item['author_name'] is not None:
405 405 handler.startElement("author", {})
406 406 handler.addQuickElement("name", item['author_name'])
407 407 if item['author_email'] is not None:
408 408 handler.addQuickElement("email", item['author_email'])
409 409 if item['author_link'] is not None:
410 410 handler.addQuickElement("uri", item['author_link'])
411 411 handler.endElement("author")
412 412
413 413 # Unique ID.
414 414 if item['unique_id'] is not None:
415 415 unique_id = item['unique_id']
416 416 else:
417 417 unique_id = get_tag_uri(item['link'], item['pubdate'])
418 418 handler.addQuickElement("id", unique_id)
419 419
420 420 # Summary.
421 421 if item['description'] is not None:
422 422 handler.addQuickElement("summary", item['description'], {"type": "html"})
423 423
424 424 # Enclosures.
425 425 for enclosure in item['enclosures']:
426 426 handler.addQuickElement('link', '', {
427 427 'rel': 'enclosure',
428 428 'href': enclosure.url,
429 429 'length': enclosure.length,
430 430 'type': enclosure.mime_type,
431 431 })
432 432
433 433 # Categories.
434 434 for cat in item['categories']:
435 435 handler.addQuickElement("category", "", {"term": cat})
436 436
437 437 # Rights.
438 438 if item['item_copyright'] is not None:
439 439 handler.addQuickElement("rights", item['item_copyright'])
440 440
441 441
442 442 # This isolates the decision of what the system default is, so calling code can
443 443 # do "feedgenerator.DefaultFeed" instead of "feedgenerator.Rss201rev2Feed".
444 444 DefaultFeed = Rss201rev2Feed
\ No newline at end of file
@@ -1,538 +1,538 b''
1 1 # -*- coding: utf-8 -*-
2 2
3 3 # Copyright (C) 2013-2020 RhodeCode GmbH
4 4 #
5 5 # This program is free software: you can redistribute it and/or modify
6 6 # it under the terms of the GNU Affero General Public License, version 3
7 7 # (only), as published by the Free Software Foundation.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU Affero General Public License
15 15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 16 #
17 17 # This program is dual-licensed. If you wish to learn more about the
18 18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20 20
21 21
22 22 """
23 23 Set of hooks run by RhodeCode Enterprise
24 24 """
25 25
26 26 import os
27 27 import logging
28 28
29 29 import rhodecode
30 30 from rhodecode import events
31 31 from rhodecode.lib import helpers as h
32 32 from rhodecode.lib import audit_logger
33 33 from rhodecode.lib.utils2 import safe_str, user_agent_normalizer
34 34 from rhodecode.lib.exceptions import (
35 35 HTTPLockedRC, HTTPBranchProtected, UserCreationError)
36 36 from rhodecode.model.db import Repository, User
37 37 from rhodecode.lib.statsd_client import StatsdClient
38 38
39 39 log = logging.getLogger(__name__)
40 40
41 41
42 42 class HookResponse(object):
43 43 def __init__(self, status, output):
44 44 self.status = status
45 45 self.output = output
46 46
47 47 def __add__(self, other):
48 48 other_status = getattr(other, 'status', 0)
49 49 new_status = max(self.status, other_status)
50 50 other_output = getattr(other, 'output', '')
51 51 new_output = self.output + other_output
52 52
53 53 return HookResponse(new_status, new_output)
54 54
55 55 def __bool__(self):
56 56 return self.status == 0
57 57
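# Illustrative sketch, not part of this changeset: HookResponse combines via
# `+` by taking the worst (max) status and concatenating outputs, and is
# truthy only when status == 0.
ok = HookResponse(0, 'lock check passed\n')
blocked = HookResponse(1, 'branch protected\n')
combined = ok + blocked
assert combined.status == 1
assert combined.output == 'lock check passed\nbranch protected\n'
assert not combined  # a non-zero status evaluates as False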
58 58
59 59 def is_shadow_repo(extras):
60 60 """
61 61 Returns ``True`` if this is an action executed against a shadow repository.
62 62 """
63 63 return extras['is_shadow_repo']
64 64
65 65
66 66 def _get_scm_size(alias, root_path):
67 67
68 68 if not alias.startswith('.'):
69 69 alias += '.'
70 70
71 71 size_scm, size_root = 0, 0
72 72 for path, unused_dirs, files in os.walk(safe_str(root_path)):
73 73 if path.find(alias) != -1:
74 74 for f in files:
75 75 try:
76 76 size_scm += os.path.getsize(os.path.join(path, f))
77 77 except OSError:
78 78 pass
79 79 else:
80 80 for f in files:
81 81 try:
82 82 size_root += os.path.getsize(os.path.join(path, f))
83 83 except OSError:
84 84 pass
85 85
86 86 size_scm_f = h.format_byte_size_binary(size_scm)
87 87 size_root_f = h.format_byte_size_binary(size_root)
88 88 size_total_f = h.format_byte_size_binary(size_root + size_scm)
89 89
90 90 return size_scm_f, size_root_f, size_total_f
91 91
92 92
93 93 # actual hooks called by Mercurial internally, and GIT by our Python Hooks
94 94 def repo_size(extras):
95 95 """Present size of repository after push."""
96 96 repo = Repository.get_by_repo_name(extras.repository)
97 vcs_part = safe_str(u'.%s' % repo.repo_type)
97 vcs_part = safe_str('.%s' % repo.repo_type)
98 98 size_vcs, size_root, size_total = _get_scm_size(vcs_part,
99 99 repo.repo_full_path)
100 100 msg = ('Repository `%s` size summary %s:%s repo:%s total:%s\n'
101 101 % (repo.repo_name, vcs_part, size_vcs, size_root, size_total))
102 102 return HookResponse(0, msg)
103 103
104 104
105 105 def pre_push(extras):
106 106 """
107 107 Hook executed before pushing code.
108 108
109 109 It bans pushing when the repository is locked.
110 110 """
111 111
112 112 user = User.get_by_username(extras.username)
113 113 output = ''
114 114 if extras.locked_by[0] and user.user_id != int(extras.locked_by[0]):
115 115 locked_by = User.get(extras.locked_by[0]).username
116 116 reason = extras.locked_by[2]
117 117 # this exception is interpreted in git/hg middlewares and based
118 118 # on that, a proper return code is served to the client
119 119 _http_ret = HTTPLockedRC(
120 120 _locked_by_explanation(extras.repository, locked_by, reason))
121 121 if str(_http_ret.code).startswith('2'):
122 122 # 2xx Codes don't raise exceptions
123 123 output = _http_ret.title
124 124 else:
125 125 raise _http_ret
126 126
127 127 hook_response = ''
128 128 if not is_shadow_repo(extras):
129 129 if extras.commit_ids and extras.check_branch_perms:
130 130
131 131 auth_user = user.AuthUser()
132 132 repo = Repository.get_by_repo_name(extras.repository)
133 133 affected_branches = []
134 134 if repo.repo_type == 'hg':
135 135 for entry in extras.commit_ids:
136 136 if entry['type'] == 'branch':
137 137 is_forced = bool(entry['multiple_heads'])
138 138 affected_branches.append([entry['name'], is_forced])
139 139 elif repo.repo_type == 'git':
140 140 for entry in extras.commit_ids:
141 141 if entry['type'] == 'heads':
142 142 is_forced = bool(entry['pruned_sha'])
143 143 affected_branches.append([entry['name'], is_forced])
144 144
145 145 for branch_name, is_forced in affected_branches:
146 146
147 147 rule, branch_perm = auth_user.get_rule_and_branch_permission(
148 148 extras.repository, branch_name)
149 149 if not branch_perm:
150 150 # no branch permission found for this branch, just keep checking
151 151 continue
152 152
153 153 if branch_perm == 'branch.push_force':
154 154 continue
155 155 elif branch_perm == 'branch.push' and is_forced is False:
156 156 continue
157 157 elif branch_perm == 'branch.push' and is_forced is True:
158 158 halt_message = 'Branch `{}` changes rejected by rule {}. ' \
159 159 'FORCE PUSH FORBIDDEN.'.format(branch_name, rule)
160 160 else:
161 161 halt_message = 'Branch `{}` changes rejected by rule {}.'.format(
162 162 branch_name, rule)
163 163
164 164 if halt_message:
165 165 _http_ret = HTTPBranchProtected(halt_message)
166 166 raise _http_ret
167 167
168 168 # Propagate to external components. This is done after checking the
169 169 # lock, for consistent behavior.
170 170 hook_response = pre_push_extension(
171 171 repo_store_path=Repository.base_path(), **extras)
172 172 events.trigger(events.RepoPrePushEvent(
173 173 repo_name=extras.repository, extras=extras))
174 174
175 175 return HookResponse(0, output) + hook_response
176 176
177 177
178 178 def pre_pull(extras):
179 179 """
180 180 Hook executed before pulling the code.
181 181
182 182 It bans pulling when the repository is locked.
183 183 """
184 184
185 185 output = ''
186 186 if extras.locked_by[0]:
187 187 locked_by = User.get(extras.locked_by[0]).username
188 188 reason = extras.locked_by[2]
189 189 # this exception is interpreted in git/hg middlewares and based
190 190 # on that, a proper return code is served to the client
191 191 _http_ret = HTTPLockedRC(
192 192 _locked_by_explanation(extras.repository, locked_by, reason))
193 193 if str(_http_ret.code).startswith('2'):
194 194 # 2xx Codes don't raise exceptions
195 195 output = _http_ret.title
196 196 else:
197 197 raise _http_ret
198 198
199 199 # Propagate to external components. This is done after checking the
200 200 # lock, for consistent behavior.
201 201 hook_response = ''
202 202 if not is_shadow_repo(extras):
203 203 extras.hook_type = extras.hook_type or 'pre_pull'
204 204 hook_response = pre_pull_extension(
205 205 repo_store_path=Repository.base_path(), **extras)
206 206 events.trigger(events.RepoPrePullEvent(
207 207 repo_name=extras.repository, extras=extras))
208 208
209 209 return HookResponse(0, output) + hook_response
210 210
211 211
212 212 def post_pull(extras):
213 213 """Hook executed after client pulls the code."""
214 214
215 215 audit_user = audit_logger.UserWrap(
216 216 username=extras.username,
217 217 ip_addr=extras.ip)
218 218 repo = audit_logger.RepoWrap(repo_name=extras.repository)
219 219 audit_logger.store(
220 220 'user.pull', action_data={'user_agent': extras.user_agent},
221 221 user=audit_user, repo=repo, commit=True)
222 222
223 223 statsd = StatsdClient.statsd
224 224 if statsd:
225 225 statsd.incr('rhodecode_pull_total', tags=[
226 226 'user-agent:{}'.format(user_agent_normalizer(extras.user_agent)),
227 227 ])
228 228 output = ''
229 229 # make_lock is a tri-state value: False, True, None. We only make a lock on True
230 230 if extras.make_lock is True and not is_shadow_repo(extras):
231 231 user = User.get_by_username(extras.username)
232 232 Repository.lock(Repository.get_by_repo_name(extras.repository),
233 233 user.user_id,
234 234 lock_reason=Repository.LOCK_PULL)
235 235 msg = 'Made lock on repo `%s`' % (extras.repository,)
236 236 output += msg
237 237
238 238 if extras.locked_by[0]:
239 239 locked_by = User.get(extras.locked_by[0]).username
240 240 reason = extras.locked_by[2]
241 241 _http_ret = HTTPLockedRC(
242 242 _locked_by_explanation(extras.repository, locked_by, reason))
243 243 if str(_http_ret.code).startswith('2'):
244 244 # 2xx Codes don't raise exceptions
245 245 output += _http_ret.title
246 246
247 247 # Propagate to external components.
248 248 hook_response = ''
249 249 if not is_shadow_repo(extras):
250 250 extras.hook_type = extras.hook_type or 'post_pull'
251 251 hook_response = post_pull_extension(
252 252 repo_store_path=Repository.base_path(), **extras)
253 253 events.trigger(events.RepoPullEvent(
254 254 repo_name=extras.repository, extras=extras))
255 255
256 256 return HookResponse(0, output) + hook_response
257 257
258 258
259 259 def post_push(extras):
260 260 """Hook executed after user pushes to the repository."""
261 261 commit_ids = extras.commit_ids
262 262
263 263 # log the push call
264 264 audit_user = audit_logger.UserWrap(
265 265 username=extras.username, ip_addr=extras.ip)
266 266 repo = audit_logger.RepoWrap(repo_name=extras.repository)
267 267 audit_logger.store(
268 268 'user.push', action_data={
269 269 'user_agent': extras.user_agent,
270 270 'commit_ids': commit_ids[:400]},
271 271 user=audit_user, repo=repo, commit=True)
272 272
273 273 statsd = StatsdClient.statsd
274 274 if statsd:
275 275 statsd.incr('rhodecode_push_total', tags=[
276 276 'user-agent:{}'.format(user_agent_normalizer(extras.user_agent)),
277 277 ])
278 278
279 279 # Propagate to external components.
280 280 output = ''
281 281 # make_lock is a tri-state value: False, True, None. We only release a lock on False
282 282 if extras.make_lock is False and not is_shadow_repo(extras):
283 283 Repository.unlock(Repository.get_by_repo_name(extras.repository))
284 284 msg = 'Released lock on repo `{}`\n'.format(safe_str(extras.repository))
285 285 output += msg
286 286
287 287 if extras.locked_by[0]:
288 288 locked_by = User.get(extras.locked_by[0]).username
289 289 reason = extras.locked_by[2]
290 290 _http_ret = HTTPLockedRC(
291 291 _locked_by_explanation(extras.repository, locked_by, reason))
292 292 # TODO: johbo: if not?
293 293 if str(_http_ret.code).startswith('2'):
294 294 # 2xx Codes don't raise exceptions
295 295 output += _http_ret.title
296 296
297 297 if extras.new_refs:
298 298 tmpl = '{}/{}/pull-request/new?{{ref_type}}={{ref_name}}'.format(
299 299 safe_str(extras.server_url), safe_str(extras.repository))
300 300
301 301 for branch_name in extras.new_refs['branches']:
302 302 output += 'RhodeCode: open pull request link: {}\n'.format(
303 303 tmpl.format(ref_type='branch', ref_name=safe_str(branch_name)))
304 304
305 305 for book_name in extras.new_refs['bookmarks']:
306 306 output += 'RhodeCode: open pull request link: {}\n'.format(
307 307 tmpl.format(ref_type='bookmark', ref_name=safe_str(book_name)))
308 308
309 309 hook_response = ''
310 310 if not is_shadow_repo(extras):
311 311 hook_response = post_push_extension(
312 312 repo_store_path=Repository.base_path(),
313 313 **extras)
314 314 events.trigger(events.RepoPushEvent(
315 315 repo_name=extras.repository, pushed_commit_ids=commit_ids, extras=extras))
316 316
317 317 output += 'RhodeCode: push completed\n'
318 318 return HookResponse(0, output) + hook_response
319 319
320 320
321 321 def _locked_by_explanation(repo_name, user_name, reason):
322 322 message = (
323 323 'Repository `%s` locked by user `%s`. Reason:`%s`'
324 324 % (repo_name, user_name, reason))
325 325 return message
326 326
327 327
328 328 def check_allowed_create_user(user_dict, created_by, **kwargs):
329 329 # pre create hooks
330 330 if pre_create_user.is_active():
331 331 hook_result = pre_create_user(created_by=created_by, **user_dict)
332 332 allowed = hook_result.status == 0
333 333 if not allowed:
334 334 reason = hook_result.output
335 335 raise UserCreationError(reason)
336 336
337 337
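# A hedged sketch of a matching rcextensions PRE_CREATE_USER_HOOK callback;
# the function name and rejection rule are illustrative assumptions, but the
# status/output contract mirrors how check_allowed_create_user() above reads
# hook_result.status and hook_result.output.
def _pre_create_user_hook(username=None, email=None, created_by=None, **kwargs):
    if email and email.endswith('@blocked.example'):  # assumed policy, for illustration
        return HookResponse(1, 'accounts on blocked.example are not allowed')
    return HookResponse(0, '')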
338 338 class ExtensionCallback(object):
339 339 """
340 340 Forwards a given call to rcextensions, sanitizes keyword arguments.
341 341
342 342 Checks if there is an extension active for that hook. If there is,
343 343 it forwards all `kwargs_keys` keyword arguments to the
344 344 extension callback.
345 345 """
346 346
347 347 def __init__(self, hook_name, kwargs_keys):
348 348 self._hook_name = hook_name
349 349 self._kwargs_keys = set(kwargs_keys)
350 350
351 351 def __call__(self, *args, **kwargs):
352 352 log.debug('Calling extension callback for `%s`', self._hook_name)
353 353 callback = self._get_callback()
354 354 if not callback:
355 355 log.debug('extension callback `%s` not found, skipping...', self._hook_name)
356 356 return
357 357
358 358 kwargs_to_pass = {}
359 359 for key in self._kwargs_keys:
360 360 try:
361 361 kwargs_to_pass[key] = kwargs[key]
362 362 except KeyError:
363 363 log.error('Failed to fetch %s key from given kwargs. '
364 364 'Expected keys: %s', key, self._kwargs_keys)
365 365 raise
366 366
367 367 # backward compat for the removed api_key in old hooks. This way it works
368 368 # with older rcextensions that require api_key to be present
369 369 if self._hook_name in ['CREATE_USER_HOOK', 'DELETE_USER_HOOK']:
370 370 kwargs_to_pass['api_key'] = '_DEPRECATED_'
371 371 return callback(**kwargs_to_pass)
372 372
373 373 def is_active(self):
374 374 return hasattr(rhodecode.EXTENSIONS, self._hook_name)
375 375
376 376 def _get_callback(self):
377 377 return getattr(rhodecode.EXTENSIONS, self._hook_name, None)
378 378
379 379
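# Usage sketch, illustrative only: an ExtensionCallback is invoked with the
# full `extras` mapping but forwards just the declared kwargs_keys to the
# rcextensions callback; `MY_CUSTOM_HOOK` is an assumed attribute name.
my_hook = ExtensionCallback(
    hook_name='MY_CUSTOM_HOOK',
    kwargs_keys=('username', 'repository'))
if my_hook.is_active():
    # `ignored` is silently dropped, as it is not in kwargs_keys
    my_hook(username='admin', repository='some/repo', ignored='dropped')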
380 380 pre_pull_extension = ExtensionCallback(
381 381 hook_name='PRE_PULL_HOOK',
382 382 kwargs_keys=(
383 383 'server_url', 'config', 'scm', 'username', 'ip', 'action',
384 384 'repository', 'hook_type', 'user_agent', 'repo_store_path',))
385 385
386 386
387 387 post_pull_extension = ExtensionCallback(
388 388 hook_name='PULL_HOOK',
389 389 kwargs_keys=(
390 390 'server_url', 'config', 'scm', 'username', 'ip', 'action',
391 391 'repository', 'hook_type', 'user_agent', 'repo_store_path',))
392 392
393 393
394 394 pre_push_extension = ExtensionCallback(
395 395 hook_name='PRE_PUSH_HOOK',
396 396 kwargs_keys=(
397 397 'server_url', 'config', 'scm', 'username', 'ip', 'action',
398 398 'repository', 'repo_store_path', 'commit_ids', 'hook_type', 'user_agent',))
399 399
400 400
401 401 post_push_extension = ExtensionCallback(
402 402 hook_name='PUSH_HOOK',
403 403 kwargs_keys=(
404 404 'server_url', 'config', 'scm', 'username', 'ip', 'action',
405 405 'repository', 'repo_store_path', 'commit_ids', 'hook_type', 'user_agent',))
406 406
407 407
408 408 pre_create_user = ExtensionCallback(
409 409 hook_name='PRE_CREATE_USER_HOOK',
410 410 kwargs_keys=(
411 411 'username', 'password', 'email', 'firstname', 'lastname', 'active',
412 412 'admin', 'created_by'))
413 413
414 414
415 415 create_pull_request = ExtensionCallback(
416 416 hook_name='CREATE_PULL_REQUEST',
417 417 kwargs_keys=(
418 418 'server_url', 'config', 'scm', 'username', 'ip', 'action',
419 419 'repository', 'pull_request_id', 'url', 'title', 'description',
420 420 'status', 'created_on', 'updated_on', 'commit_ids', 'review_status',
421 421 'mergeable', 'source', 'target', 'author', 'reviewers'))
422 422
423 423
424 424 merge_pull_request = ExtensionCallback(
425 425 hook_name='MERGE_PULL_REQUEST',
426 426 kwargs_keys=(
427 427 'server_url', 'config', 'scm', 'username', 'ip', 'action',
428 428 'repository', 'pull_request_id', 'url', 'title', 'description',
429 429 'status', 'created_on', 'updated_on', 'commit_ids', 'review_status',
430 430 'mergeable', 'source', 'target', 'author', 'reviewers'))
431 431
432 432
433 433 close_pull_request = ExtensionCallback(
434 434 hook_name='CLOSE_PULL_REQUEST',
435 435 kwargs_keys=(
436 436 'server_url', 'config', 'scm', 'username', 'ip', 'action',
437 437 'repository', 'pull_request_id', 'url', 'title', 'description',
438 438 'status', 'created_on', 'updated_on', 'commit_ids', 'review_status',
439 439 'mergeable', 'source', 'target', 'author', 'reviewers'))
440 440
441 441
442 442 review_pull_request = ExtensionCallback(
443 443 hook_name='REVIEW_PULL_REQUEST',
444 444 kwargs_keys=(
445 445 'server_url', 'config', 'scm', 'username', 'ip', 'action',
446 446 'repository', 'pull_request_id', 'url', 'title', 'description',
447 447 'status', 'created_on', 'updated_on', 'commit_ids', 'review_status',
448 448 'mergeable', 'source', 'target', 'author', 'reviewers'))
449 449
450 450
451 451 comment_pull_request = ExtensionCallback(
452 452 hook_name='COMMENT_PULL_REQUEST',
453 453 kwargs_keys=(
454 454 'server_url', 'config', 'scm', 'username', 'ip', 'action',
455 455 'repository', 'pull_request_id', 'url', 'title', 'description',
456 456 'status', 'comment', 'created_on', 'updated_on', 'commit_ids', 'review_status',
457 457 'mergeable', 'source', 'target', 'author', 'reviewers'))
458 458
459 459
460 460 comment_edit_pull_request = ExtensionCallback(
461 461 hook_name='COMMENT_EDIT_PULL_REQUEST',
462 462 kwargs_keys=(
463 463 'server_url', 'config', 'scm', 'username', 'ip', 'action',
464 464 'repository', 'pull_request_id', 'url', 'title', 'description',
465 465 'status', 'comment', 'created_on', 'updated_on', 'commit_ids', 'review_status',
466 466 'mergeable', 'source', 'target', 'author', 'reviewers'))
467 467
468 468
469 469 update_pull_request = ExtensionCallback(
470 470 hook_name='UPDATE_PULL_REQUEST',
471 471 kwargs_keys=(
472 472 'server_url', 'config', 'scm', 'username', 'ip', 'action',
473 473 'repository', 'pull_request_id', 'url', 'title', 'description',
474 474 'status', 'created_on', 'updated_on', 'commit_ids', 'review_status',
475 475 'mergeable', 'source', 'target', 'author', 'reviewers'))
476 476
477 477
478 478 create_user = ExtensionCallback(
479 479 hook_name='CREATE_USER_HOOK',
480 480 kwargs_keys=(
481 481 'username', 'full_name_or_username', 'full_contact', 'user_id',
482 482 'name', 'firstname', 'short_contact', 'admin', 'lastname',
483 483 'ip_addresses', 'extern_type', 'extern_name',
484 484 'email', 'api_keys', 'last_login',
485 485 'full_name', 'active', 'password', 'emails',
486 486 'inherit_default_permissions', 'created_by', 'created_on'))
487 487
488 488
489 489 delete_user = ExtensionCallback(
490 490 hook_name='DELETE_USER_HOOK',
491 491 kwargs_keys=(
492 492 'username', 'full_name_or_username', 'full_contact', 'user_id',
493 493 'name', 'firstname', 'short_contact', 'admin', 'lastname',
494 494 'ip_addresses',
495 495 'email', 'last_login',
496 496 'full_name', 'active', 'password', 'emails',
497 497 'inherit_default_permissions', 'deleted_by'))
498 498
499 499
500 500 create_repository = ExtensionCallback(
501 501 hook_name='CREATE_REPO_HOOK',
502 502 kwargs_keys=(
503 503 'repo_name', 'repo_type', 'description', 'private', 'created_on',
504 504 'enable_downloads', 'repo_id', 'user_id', 'enable_statistics',
505 505 'clone_uri', 'fork_id', 'group_id', 'created_by'))
506 506
507 507
508 508 delete_repository = ExtensionCallback(
509 509 hook_name='DELETE_REPO_HOOK',
510 510 kwargs_keys=(
511 511 'repo_name', 'repo_type', 'description', 'private', 'created_on',
512 512 'enable_downloads', 'repo_id', 'user_id', 'enable_statistics',
513 513 'clone_uri', 'fork_id', 'group_id', 'deleted_by', 'deleted_on'))
514 514
515 515
516 516 comment_commit_repository = ExtensionCallback(
517 517 hook_name='COMMENT_COMMIT_REPO_HOOK',
518 518 kwargs_keys=(
519 519 'repo_name', 'repo_type', 'description', 'private', 'created_on',
520 520 'enable_downloads', 'repo_id', 'user_id', 'enable_statistics',
521 521 'clone_uri', 'fork_id', 'group_id',
522 522 'repository', 'created_by', 'comment', 'commit'))
523 523
524 524 comment_edit_commit_repository = ExtensionCallback(
525 525 hook_name='COMMENT_EDIT_COMMIT_REPO_HOOK',
526 526 kwargs_keys=(
527 527 'repo_name', 'repo_type', 'description', 'private', 'created_on',
528 528 'enable_downloads', 'repo_id', 'user_id', 'enable_statistics',
529 529 'clone_uri', 'fork_id', 'group_id',
530 530 'repository', 'created_by', 'comment', 'commit'))
531 531
532 532
533 533 create_repository_group = ExtensionCallback(
534 534 hook_name='CREATE_REPO_GROUP_HOOK',
535 535 kwargs_keys=(
536 536 'group_name', 'group_parent_id', 'group_description',
537 537 'group_id', 'user_id', 'created_by', 'created_on',
538 538 'enable_locking'))
@@ -1,187 +1,187 b''
1 1 # -*- coding: utf-8 -*-
2 2
3 3 # Copyright (C) 2010-2020 RhodeCode GmbH
4 4 #
5 5 # This program is free software: you can redistribute it and/or modify
6 6 # it under the terms of the GNU Affero General Public License, version 3
7 7 # (only), as published by the Free Software Foundation.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU Affero General Public License
15 15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 16 #
17 17 # This program is dual-licensed. If you wish to learn more about the
18 18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20 20
21 21 import sys
22 22 import logging
23 23
24 24
25 BLACK, RED, GREEN, YELLOW, BLUE, MAGENTA, CYAN, WHITE = range(30, 38)
25 BLACK, RED, GREEN, YELLOW, BLUE, MAGENTA, CYAN, WHITE = list(range(30, 38))
26 26
27 27 # Sequences
28 28 RESET_SEQ = "\033[0m"
29 29 COLOR_SEQ = "\033[0;%dm"
30 30 BOLD_SEQ = "\033[1m"
31 31
32 32 COLORS = {
33 33 'CRITICAL': MAGENTA,
34 34 'ERROR': RED,
35 35 'WARNING': CYAN,
36 36 'INFO': GREEN,
37 37 'DEBUG': BLUE,
38 38 'SQL': YELLOW
39 39 }
40 40
41 41
42 42 def _inject_req_id(record, with_prefix=True):
43 43 from pyramid.threadlocal import get_current_request
44 44 dummy = '00000000-0000-0000-0000-000000000000'
45 45 req_id = None
46 46
47 47 req = get_current_request()
48 48 if req:
49 49 req_id = getattr(req, 'req_id', None)
50 50 if with_prefix:
51 51 req_id = 'req_id:%-36s' % (req_id or dummy)
52 52 else:
53 53 req_id = (req_id or dummy)
54 54 record.req_id = req_id
55 55
56 56
57 57 def _add_log_to_debug_bucket(formatted_record):
58 58 from pyramid.threadlocal import get_current_request
59 59 req = get_current_request()
60 60 if req:
61 61 req.req_id_bucket.append(formatted_record)
62 62
63 63
64 64 def one_space_trim(s):
65 65 if s.find(" ") == -1:
66 66 return s
67 67 else:
68 68 s = s.replace(' ', ' ')
69 69 return one_space_trim(s)
70 70
71 71
72 72 def format_sql(sql):
73 73 sql = sql.replace('\n', '')
74 74 sql = one_space_trim(sql)
75 75 sql = sql\
76 76 .replace(',', ',\n\t')\
77 77 .replace('SELECT', '\n\tSELECT \n\t')\
78 78 .replace('UPDATE', '\n\tUPDATE \n\t')\
79 79 .replace('DELETE', '\n\tDELETE \n\t')\
80 80 .replace('FROM', '\n\tFROM')\
81 81 .replace('ORDER BY', '\n\tORDER BY')\
82 82 .replace('LIMIT', '\n\tLIMIT')\
83 83 .replace('WHERE', '\n\tWHERE')\
84 84 .replace('AND', '\n\tAND')\
85 85 .replace('LEFT', '\n\tLEFT')\
86 86 .replace('INNER', '\n\tINNER')\
87 87 .replace('INSERT', '\n\tINSERT')\
88 88 .replace('DELETE', '\n\tDELETE')
89 89 return sql
90 90
91 91
92 92 class ExceptionAwareFormatter(logging.Formatter):
93 93 """
94 94 Extended logging formatter which prints out remote tracebacks.
95 95 """
96 96
97 97 def formatException(self, ei):
98 98 ex_type, ex_value, ex_tb = ei
99 99
100 100 local_tb = logging.Formatter.formatException(self, ei)
101 101 if hasattr(ex_value, '_vcs_server_traceback'):
102 102
103 103 def formatRemoteTraceback(remote_tb_lines):
104 104 result = ["\n +--- This exception occurred remotely on VCSServer - Remote traceback:\n\n"]
105 105 result.append(remote_tb_lines)
106 106 result.append("\n +--- End of remote traceback\n")
107 107 return result
108 108
109 109 try:
110 110 if ex_type is not None and ex_value is None and ex_tb is None:
111 111 # possible old (3.x) call syntax where caller is only
112 112 # providing exception object
113 113 if type(ex_type) is not type:
114 114 raise TypeError(
115 115 "invalid argument: ex_type should be an exception "
116 116 "type, or just supply no arguments at all")
117 117 if ex_type is None and ex_tb is None:
118 118 ex_type, ex_value, ex_tb = sys.exc_info()
119 119
120 120 remote_tb = getattr(ex_value, "_vcs_server_traceback", None)
121 121
122 122 if remote_tb:
123 123 remote_tb = formatRemoteTraceback(remote_tb)
124 124 return local_tb + ''.join(remote_tb)
125 125 finally:
126 126 # clean up cycle to traceback, to allow proper GC
127 127 del ex_type, ex_value, ex_tb
128 128
129 129 return local_tb
130 130
131 131
132 132 class RequestTrackingFormatter(ExceptionAwareFormatter):
133 133 def format(self, record):
134 134 _inject_req_id(record)
135 135 def_record = logging.Formatter.format(self, record)
136 136 _add_log_to_debug_bucket(def_record)
137 137 return def_record
138 138
139 139
140 140 class ColorFormatter(ExceptionAwareFormatter):
141 141
142 142 def format(self, record):
143 143 """
144 144 Changes record's levelname to use with COLORS enum
145 145 """
146 146 def_record = super(ColorFormatter, self).format(record)
147 147
148 148 levelname = record.levelname
149 149 start = COLOR_SEQ % (COLORS[levelname])
150 150 end = RESET_SEQ
151 151
152 152 colored_record = ''.join([start, def_record, end])
153 153 return colored_record
154 154
155 155
156 156 class ColorRequestTrackingFormatter(RequestTrackingFormatter):
157 157
158 158 def format(self, record):
159 159 """
160 160 Changes record's levelname to use with COLORS enum
161 161 """
162 162 def_record = super(ColorRequestTrackingFormatter, self).format(record)
163 163
164 164 levelname = record.levelname
165 165 start = COLOR_SEQ % (COLORS[levelname])
166 166 end = RESET_SEQ
167 167
168 168 colored_record = ''.join([start, def_record, end])
169 169 return colored_record
170 170
171 171
172 172 class ColorFormatterSql(logging.Formatter):
173 173
174 174 def format(self, record):
175 175 """
176 176 Changes record's levelname to use with COLORS enum
177 177 """
178 178
179 179 start = COLOR_SEQ % (COLORS['SQL'])
180 180 def_record = format_sql(logging.Formatter.format(self, record))
181 181 end = RESET_SEQ
182 182
183 183 colored_record = ''.join([start, def_record, end])
184 184 return colored_record
185 185
186 186 # marcink: needs to stay with this name for backward .ini compatibility
187 187 Pyro4AwareFormatter = ExceptionAwareFormatter
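# Hedged wiring sketch (not part of the changeset): attaching one of the
# formatters above to a plain logging handler; the format string is
# illustrative.
import logging

handler = logging.StreamHandler()
handler.setFormatter(
    ColorFormatter('%(asctime)s %(levelname)-5.5s [%(name)s] %(message)s'))
logging.getLogger().addHandler(handler)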
@@ -1,580 +1,580 b''
1 1 # -*- coding: utf-8 -*-
2 2
3 3 # Copyright (C) 2011-2020 RhodeCode GmbH
4 4 #
5 5 # This program is free software: you can redistribute it and/or modify
6 6 # it under the terms of the GNU Affero General Public License, version 3
7 7 # (only), as published by the Free Software Foundation.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU Affero General Public License
15 15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 16 #
17 17 # This program is dual-licensed. If you wish to learn more about the
18 18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20 20
21 21
22 22 """
23 23 Renderer for markup languages with ability to parse using rst or markdown
24 24 """
25 25
26 26 import re
27 27 import os
28 28 import lxml
29 29 import logging
30 30 import urllib.parse
31 31 import bleach
32 32
33 33 from mako.lookup import TemplateLookup
34 34 from mako.template import Template as MakoTemplate
35 35
36 36 from docutils.core import publish_parts
37 37 from docutils.parsers.rst import directives
38 38 from docutils import writers
39 39 from docutils.writers import html4css1
40 40 import markdown
41 41
42 42 from rhodecode.lib.markdown_ext import GithubFlavoredMarkdownExtension
43 43 from rhodecode.lib.utils2 import (safe_unicode, md5_safe, MENTIONS_REGEX)
44 44
45 45 log = logging.getLogger(__name__)
46 46
47 47 # default renderer used to generate automated comments
48 48 DEFAULT_COMMENTS_RENDERER = 'rst'
49 49
50 50 try:
51 51 from lxml.html import fromstring
52 52 from lxml.html import tostring
53 53 except ImportError:
54 54 log.exception('Failed to import lxml')
55 55 fromstring = None
56 56 tostring = None
57 57
58 58
59 59 class CustomHTMLTranslator(writers.html4css1.HTMLTranslator):
60 60 """
61 61 Custom HTML Translator used for sandboxing potential
62 62 JS injections in ref links
63 63 """
64 64 def visit_literal_block(self, node):
65 65 self.body.append(self.starttag(node, 'pre', CLASS='codehilite literal-block'))
66 66
67 67 def visit_reference(self, node):
68 68 if 'refuri' in node.attributes:
69 69 refuri = node['refuri']
70 70 if ':' in refuri:
71 71 prefix, link = refuri.lstrip().split(':', 1)
72 72 prefix = prefix or ''
73 73
74 74 if prefix.lower() == 'javascript':
75 75 # we don't allow javascript type of refs...
76 76 node['refuri'] = 'javascript:alert("SandBoxedJavascript")'
77 77
78 78 # old style class requires this...
79 79 return html4css1.HTMLTranslator.visit_reference(self, node)
80 80
81 81
82 82 class RhodeCodeWriter(writers.html4css1.Writer):
83 83 def __init__(self):
84 84 writers.Writer.__init__(self)
85 85 self.translator_class = CustomHTMLTranslator
86 86
87 87
88 88 def relative_links(html_source, server_paths):
89 89 if not html_source:
90 90 return html_source
91 91
92 92 if not (fromstring and tostring):
93 93 return html_source
94 94
95 95 try:
96 96 doc = lxml.html.fromstring(html_source)
97 97 except Exception:
98 98 return html_source
99 99
100 100 for el in doc.cssselect('img, video'):
101 101 src = el.attrib.get('src')
102 102 if src:
103 103 el.attrib['src'] = relative_path(src, server_paths['raw'])
104 104
105 105 for el in doc.cssselect('a:not(.gfm)'):
106 106 src = el.attrib.get('href')
107 107 if src:
108 108 raw_mode = el.attrib['href'].endswith('?raw=1')
109 109 if raw_mode:
110 110 el.attrib['href'] = relative_path(src, server_paths['raw'])
111 111 else:
112 112 el.attrib['href'] = relative_path(src, server_paths['standard'])
113 113
114 114 return lxml.html.tostring(doc)
115 115
116 116
117 117 def relative_path(path, request_path, is_repo_file=None):
118 118 """
119 119 relative link support, path is a rel path, and request_path is current
120 120 server path (not absolute)
121 121
122 122 e.g.
123 123
124 124 path = '../logo.png'
125 125 request_path= '/repo/files/path/file.md'
126 126 produces: '/repo/files/logo.png'
127 127 """
128 128 # TODO(marcink): unicode/str support ?
129 129 # maybe=> safe_unicode(urllib.quote(safe_str(final_path), '/:'))
130 130
131 131 def dummy_check(p):
132 132 return True # assume default is a valid file path
133 133
134 134 is_repo_file = is_repo_file or dummy_check
135 135 if not path:
136 136 return request_path
137 137
138 138 path = safe_unicode(path)
139 139 request_path = safe_unicode(request_path)
140 140
141 if path.startswith((u'data:', u'javascript:', u'#', u':')):
141 if path.startswith(('data:', 'javascript:', '#', ':')):
142 142 # skip data, anchor, invalid links
143 143 return path
144 144
145 145 is_absolute = bool(urllib.parse.urlparse(path).netloc)
146 146 if is_absolute:
147 147 return path
148 148
149 149 if not request_path:
150 150 return path
151 151
152 if path.startswith(u'/'):
152 if path.startswith('/'):
153 153 path = path[1:]
154 154
155 if path.startswith(u'./'):
155 if path.startswith('./'):
156 156 path = path[2:]
157 157
158 158 parts = request_path.split('/')
159 159 # compute how deep we need to traverse the request_path
160 160 depth = 0
161 161
162 162 if is_repo_file(request_path):
163 163 # if request path is a VALID file, we use a relative path with
164 164 # one level up
165 165 depth += 1
166 166
167 while path.startswith(u'../'):
167 while path.startswith('../'):
168 168 depth += 1
169 169 path = path[3:]
170 170
171 171 if depth > 0:
172 172 parts = parts[:-depth]
173 173
174 174 parts.append(path)
175 final_path = u'/'.join(parts).lstrip(u'/')
175 final_path = '/'.join(parts).lstrip('/')
176 176
177 return u'/' + final_path
177 return '/' + final_path
178 178
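# Worked examples matching the docstring above: one path level is stripped
# for the current file, plus one per leading '../' segment.
assert relative_path('../logo.png', '/repo/files/path/file.md') == '/repo/files/logo.png'
assert relative_path('img/a.png', '/repo/files/path/file.md') == '/repo/files/path/img/a.png'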
179 179
180 180 _cached_markdown_renderer = None
181 181
182 182
183 183 def get_markdown_renderer(extensions, output_format):
184 184 global _cached_markdown_renderer
185 185
186 186 if _cached_markdown_renderer is None:
187 187 _cached_markdown_renderer = markdown.Markdown(
188 188 extensions=extensions,
189 189 enable_attributes=False, output_format=output_format)
190 190 return _cached_markdown_renderer
191 191
192 192
193 193 _cached_markdown_renderer_flavored = None
194 194
195 195
196 196 def get_markdown_renderer_flavored(extensions, output_format):
197 197 global _cached_markdown_renderer_flavored
198 198
199 199 if _cached_markdown_renderer_flavored is None:
200 200 _cached_markdown_renderer_flavored = markdown.Markdown(
201 201 extensions=extensions + [GithubFlavoredMarkdownExtension()],
202 202 enable_attributes=False, output_format=output_format)
203 203 return _cached_markdown_renderer_flavored
204 204
205 205
206 206 class MarkupRenderer(object):
207 207 RESTRUCTUREDTEXT_DISALLOWED_DIRECTIVES = ['include', 'meta', 'raw']
208 208
209 209 MARKDOWN_PAT = re.compile(r'\.(md|mkdn?|mdown|markdown)$', re.IGNORECASE)
210 210 RST_PAT = re.compile(r'\.re?st$', re.IGNORECASE)
211 211 JUPYTER_PAT = re.compile(r'\.(ipynb)$', re.IGNORECASE)
212 212 PLAIN_PAT = re.compile(r'^readme$', re.IGNORECASE)
213 213
214 214 URL_PAT = re.compile(r'(http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]'
215 215 r'|[!*\(\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+)')
216 216
217 217 MENTION_PAT = re.compile(MENTIONS_REGEX)
218 218
219 219 extensions = ['markdown.extensions.codehilite', 'markdown.extensions.extra',
220 220 'markdown.extensions.def_list', 'markdown.extensions.sane_lists']
221 221
222 222 output_format = 'html4'
223 223
224 224 # extensions together with weights. Lower comes first, which lets us control
225 225 # how extensions are attached to readme names.
226 226 PLAIN_EXTS = [
227 227 # prefer no extension
228 228 ('', 0), # special case that renders READMES names without extension
229 229 ('.text', 2), ('.TEXT', 2),
230 230 ('.txt', 3), ('.TXT', 3)
231 231 ]
232 232
233 233 RST_EXTS = [
234 234 ('.rst', 1), ('.rest', 1),
235 235 ('.RST', 2), ('.REST', 2)
236 236 ]
237 237
238 238 MARKDOWN_EXTS = [
239 239 ('.md', 1), ('.MD', 1),
240 240 ('.mkdn', 2), ('.MKDN', 2),
241 241 ('.mdown', 3), ('.MDOWN', 3),
242 242 ('.markdown', 4), ('.MARKDOWN', 4)
243 243 ]
244 244
245 245 def _detect_renderer(self, source, filename=None):
246 246 """
247 247 runs detection of what renderer should be used for generating html
248 248 from a markup language
249 249
250 250 filename can be also explicitly a renderer name
251 251
252 252 :param source:
253 253 :param filename:
254 254 """
255 255
256 256 if MarkupRenderer.MARKDOWN_PAT.findall(filename):
257 257 detected_renderer = 'markdown'
258 258 elif MarkupRenderer.RST_PAT.findall(filename):
259 259 detected_renderer = 'rst'
260 260 elif MarkupRenderer.JUPYTER_PAT.findall(filename):
261 261 detected_renderer = 'jupyter'
262 262 elif MarkupRenderer.PLAIN_PAT.findall(filename):
263 263 detected_renderer = 'plain'
264 264 else:
265 265 detected_renderer = 'plain'
266 266
267 267 return getattr(MarkupRenderer, detected_renderer)
268 268
269 269 @classmethod
270 270 def bleach_clean(cls, text):
271 271 from .bleach_whitelist import markdown_attrs, markdown_tags
272 272 allowed_tags = markdown_tags
273 273 allowed_attrs = markdown_attrs
274 274
275 275 try:
276 276 return bleach.clean(text, tags=allowed_tags, attributes=allowed_attrs)
277 277 except Exception:
278 278 return 'UNPARSEABLE TEXT'
279 279
280 280 @classmethod
281 281 def renderer_from_filename(cls, filename, exclude):
282 282 """
283 283 Detect renderer markdown/rst from filename and optionally use exclude
284 284 list to remove some options. This is mostly used in helpers.
285 285 Returns None when no renderer can be detected.
286 286 """
287 287 def _filter(elements):
288 288 if isinstance(exclude, (list, tuple)):
289 289 return [x for x in elements if x not in exclude]
290 290 return elements
291 291
292 292 if filename.endswith(
293 293 tuple(_filter([x[0] for x in cls.MARKDOWN_EXTS if x[0]]))):
294 294 return 'markdown'
295 295 if filename.endswith(tuple(_filter([x[0] for x in cls.RST_EXTS if x[0]]))):
296 296 return 'rst'
297 297
298 298 return None
299 299
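# Illustrative checks for renderer_from_filename(); the exclude list lets
# callers veto a detected renderer.
assert MarkupRenderer.renderer_from_filename('README.md', exclude=None) == 'markdown'
assert MarkupRenderer.renderer_from_filename('docs/index.rst', exclude=None) == 'rst'
assert MarkupRenderer.renderer_from_filename('README.md', exclude=['.md']) is None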
300 300 def render(self, source, filename=None):
301 301 """
302 302 Renders a given source using a detected renderer;
303 303 renderers are detected based on file extension or mimetype.
304 304 As a last resort it does simple html, replacing new lines with <br/>.
305 305
306 306 :param filename:
307 307 :param source:
308 308 """
309 309
310 310 renderer = self._detect_renderer(source, filename)
311 311 readme_data = renderer(source)
312 312 return readme_data
313 313
314 314 @classmethod
315 315 def _flavored_markdown(cls, text):
316 316 """
317 317 Github style flavored markdown
318 318
319 319 :param text:
320 320 """
321 321
322 322 # Extract pre blocks.
323 323 extractions = {}
324 324
325 325 def pre_extraction_callback(matchobj):
326 326 digest = md5_safe(matchobj.group(0))
327 327 extractions[digest] = matchobj.group(0)
328 328 return "{gfm-extraction-%s}" % digest
329 329 pattern = re.compile(r'<pre>.*?</pre>', re.MULTILINE | re.DOTALL)
330 330 text = re.sub(pattern, pre_extraction_callback, text)
331 331
332 332 # Prevent foo_bar_baz from ending up with an italic word in the middle.
333 333 def italic_callback(matchobj):
334 334 s = matchobj.group(0)
335 335 if list(s).count('_') >= 2:
336 336 return s.replace('_', r'\_')
337 337 return s
338 338 text = re.sub(r'^(?! {4}|\t)\w+_\w+_\w[\w_]*', italic_callback, text)
339 339
340 340 # Insert pre block extractions.
341 341 def pre_insert_callback(matchobj):
342 342 return '\n\n' + extractions[matchobj.group(1)]
343 343 text = re.sub(r'\{gfm-extraction-([0-9a-f]{32})\}',
344 344 pre_insert_callback, text)
345 345
346 346 return text
347 347
348 348 @classmethod
349 349 def urlify_text(cls, text):
350 350 def url_func(match_obj):
351 351 url_full = match_obj.groups()[0]
352 352 return '<a href="%(url)s">%(url)s</a>' % ({'url': url_full})
353 353
354 354 return cls.URL_PAT.sub(url_func, text)
355 355
356 356 @classmethod
357 357 def convert_mentions(cls, text, mode):
358 358 mention_pat = cls.MENTION_PAT
359 359
360 360 def wrapp(match_obj):
361 361 uname = match_obj.groups()[0]
362 362 hovercard_url = "pyroutes.url('hovercard_username', {'username': '%s'});" % uname
363 363
364 364 if mode == 'markdown':
365 365 tmpl = '<strong class="tooltip-hovercard" data-hovercard-alt="{uname}" data-hovercard-url="{hovercard_url}">@{uname}</strong>'
366 366 elif mode == 'rst':
367 367 tmpl = ' **@{uname}** '
368 368 else:
369 369 raise ValueError('mode must be rst or markdown')
370 370
371 371 return tmpl.format(**{'uname': uname,
372 372 'hovercard_url': hovercard_url})
373 373
374 374 return mention_pat.sub(wrapp, text).strip()
375 375
376 376 @classmethod
377 377 def plain(cls, source, universal_newline=True, leading_newline=True):
378 378 source = safe_unicode(source)
379 379 if universal_newline:
380 380 newline = '\n'
381 381 source = newline.join(source.splitlines())
382 382
383 383 rendered_source = cls.urlify_text(source)
384 384 source = ''
385 385 if leading_newline:
386 386 source += '<br />'
387 387 source += rendered_source.replace("\n", '<br />')
388 388
389 389 rendered = cls.bleach_clean(source)
390 390 return rendered
391 391
392 392 @classmethod
393 393 def markdown(cls, source, safe=True, flavored=True, mentions=False,
394 394 clean_html=True):
395 395 """
396 396 returns markdown rendered code cleaned by the bleach library
397 397 """
398 398
399 399 if flavored:
400 400 markdown_renderer = get_markdown_renderer_flavored(
401 401 cls.extensions, cls.output_format)
402 402 else:
403 403 markdown_renderer = get_markdown_renderer(
404 404 cls.extensions, cls.output_format)
405 405
406 406 if mentions:
407 407 mention_hl = cls.convert_mentions(source, mode='markdown')
408 408 # we extracted mentions; render again with mentions=False
409 409 return cls.markdown(mention_hl, safe=safe, flavored=flavored,
410 410 mentions=False)
411 411
412 412 source = safe_unicode(source)
413 413
414 414 try:
415 415 if flavored:
416 416 source = cls._flavored_markdown(source)
417 417 rendered = markdown_renderer.convert(source)
418 418 except Exception:
419 419 log.exception('Error when rendering Markdown')
420 420 if safe:
421 421 log.debug('Fallback to render in plain mode')
422 422 rendered = cls.plain(source)
423 423 else:
424 424 raise
425 425
426 426 if clean_html:
427 427 rendered = cls.bleach_clean(rendered)
428 428 return rendered
429 429
430 430 @classmethod
431 431 def rst(cls, source, safe=True, mentions=False, clean_html=False):
432 432 if mentions:
433 433 mention_hl = cls.convert_mentions(source, mode='rst')
434 434 # we extracted mentions; render again with mentions=False
435 435 return cls.rst(mention_hl, safe=safe, mentions=False)
436 436
437 437 source = safe_unicode(source)
438 438 try:
439 439 docutils_settings = dict(
440 440 [(alias, None) for alias in
441 441 cls.RESTRUCTUREDTEXT_DISALLOWED_DIRECTIVES])
442 442
443 443 docutils_settings.update({
444 444 'input_encoding': 'unicode',
445 445 'report_level': 4,
446 446 'syntax_highlight': 'short',
447 447 })
448 448
449 449 for k, v in docutils_settings.items():
450 450 directives.register_directive(k, v)
451 451
452 452 parts = publish_parts(source=source,
453 453 writer=RhodeCodeWriter(),
454 454 settings_overrides=docutils_settings)
455 455 rendered = parts["fragment"]
456 456 if clean_html:
457 457 rendered = cls.bleach_clean(rendered)
458 458 return parts['html_title'] + rendered
459 459 except Exception:
460 460 log.exception('Error when rendering RST')
461 461 if safe:
462 462 log.debug('Fallback to render in plain mode')
463 463 return cls.plain(source)
464 464 else:
465 465 raise
466 466
467 467 @classmethod
468 468 def jupyter(cls, source, safe=True):
469 469 from rhodecode.lib import helpers
470 470
471 471 from traitlets.config import Config
472 472 import nbformat
473 473 from nbconvert import HTMLExporter
474 474 from nbconvert.preprocessors import Preprocessor
475 475
476 476 class CustomHTMLExporter(HTMLExporter):
477 477 def _template_file_default(self):
478 478 return 'basic'
479 479
480 480 class Sandbox(Preprocessor):
481 481
482 482 def preprocess(self, nb, resources):
483 483 sandbox_text = 'SandBoxed(IPython.core.display.Javascript object)'
484 484 for cell in nb['cells']:
485 485 if not safe:
486 486 continue
487 487
488 488 if 'outputs' in cell:
489 489 for cell_output in cell['outputs']:
490 490 if 'data' in cell_output:
491 491 if 'application/javascript' in cell_output['data']:
492 492 cell_output['data']['text/plain'] = sandbox_text
493 493 cell_output['data'].pop('application/javascript', None)
494 494
495 495 if 'source' in cell and cell['cell_type'] == 'markdown':
496 496 # sanitize similarly to markdown
497 497 cell['source'] = cls.bleach_clean(cell['source'])
498 498
499 499 return nb, resources
500 500
501 501 def _sanitize_resources(input_resources):
502 502 """
503 503 Skip/sanitize some of the CSS generated and included in jupyter
504 504 so it doesn't mess up the UI so much
505 505 """
506 506
507 507 # TODO(marcink): probably we should replace this with whole custom
508 508 # CSS set that doesn't screw up, but jupyter generated html has some
509 509 # special markers, so it requires Custom HTML exporter template with
510 510 # _default_template_path_default, to achieve that
511 511
512 512 # strip the reset CSS
513 513 input_resources[0] = input_resources[0][input_resources[0].find('/*! Source'):]
514 514 return input_resources
515 515
516 516 def as_html(notebook):
517 517 conf = Config()
518 518 conf.CustomHTMLExporter.preprocessors = [Sandbox]
519 519 html_exporter = CustomHTMLExporter(config=conf)
520 520
521 521 (body, resources) = html_exporter.from_notebook_node(notebook)
522 522 header = '<!-- ## IPYTHON NOTEBOOK RENDERING ## -->'
523 523 js = MakoTemplate(r'''
524 524 <!-- MathJax configuration -->
525 525 <script type="text/x-mathjax-config">
526 526 MathJax.Hub.Config({
527 527 jax: ["input/TeX","output/HTML-CSS", "output/PreviewHTML"],
528 528 extensions: ["tex2jax.js","MathMenu.js","MathZoom.js", "fast-preview.js", "AssistiveMML.js", "[Contrib]/a11y/accessibility-menu.js"],
529 529 TeX: {
530 530 extensions: ["AMSmath.js","AMSsymbols.js","noErrors.js","noUndefined.js"]
531 531 },
532 532 tex2jax: {
533 533 inlineMath: [ ['$','$'], ["\\(","\\)"] ],
534 534 displayMath: [ ['$$','$$'], ["\\[","\\]"] ],
535 535 processEscapes: true,
536 536 processEnvironments: true
537 537 },
538 538 // Center justify equations in code and markdown cells. Elsewhere
539 539 // we use CSS to left justify single line equations in code cells.
540 540 displayAlign: 'center',
541 541 "HTML-CSS": {
542 542 styles: {'.MathJax_Display': {"margin": 0}},
543 543 linebreaks: { automatic: true },
544 544 availableFonts: ["STIX", "TeX"]
545 545 },
546 546 showMathMenu: false
547 547 });
548 548 </script>
549 549 <!-- End of MathJax configuration -->
550 550 <script src="${h.asset('js/src/math_jax/MathJax.js')}"></script>
551 551 ''').render(h=helpers)
552 552
553 553 css = MakoTemplate(r'''
554 554 <link rel="stylesheet" type="text/css" href="${h.asset('css/style-ipython.css', ver=ver)}" media="screen"/>
555 555 ''').render(h=helpers, ver='ver1')
556 556
557 557 body = '\n'.join([header, css, js, body])
558 558 return body, resources
559 559
560 560 notebook = nbformat.reads(source, as_version=4)
561 561 (body, resources) = as_html(notebook)
562 562 return body
563 563
564 564
565 565 class RstTemplateRenderer(object):
566 566
567 567 def __init__(self):
568 568 base = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
569 569 rst_template_dirs = [os.path.join(base, 'templates', 'rst_templates')]
570 570 self.template_store = TemplateLookup(
571 571 directories=rst_template_dirs,
572 572 input_encoding='utf-8',
573 573 imports=['from rhodecode.lib import helpers as h'])
574 574
575 575 def _get_template(self, templatename):
576 576 return self.template_store.get_template(templatename)
577 577
578 578 def render(self, template_name, **kwargs):
579 579 template = self._get_template(template_name)
580 580 return template.render(**kwargs)
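# Hedged usage sketch for the renderer above (not part of the changeset):
# render() picks the renderer from the filename and falls back to plain text.
renderer = MarkupRenderer()
md_html = renderer.render('# Hello\n\nsome *markdown* text', filename='README.md')
rst_html = renderer.render('Title\n=====\n\nsome rst text', filename='docs/index.rst')
plain_html = renderer.render('just text', filename='readme')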
@@ -1,156 +1,156 b''
1 1 # -*- coding: utf-8 -*-
2 2
3 3 # Copyright (C) 2010-2020 RhodeCode GmbH
4 4 #
5 5 # This program is free software: you can redistribute it and/or modify
6 6 # it under the terms of the GNU Affero General Public License, version 3
7 7 # (only), as published by the Free Software Foundation.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU Affero General Public License
15 15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 16 #
17 17 # This program is dual-licensed. If you wish to learn more about the
18 18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20 20
21 21 """
22 22 SimpleGit middleware for handling git protocol request (push/clone etc.)
23 23 It's implemented with basic auth function
24 24 """
25 25 import os
26 26 import re
27 27 import logging
28 28 import urllib.parse
29 29
30 30 import rhodecode
31 31 from rhodecode.lib import utils
32 32 from rhodecode.lib import utils2
33 33 from rhodecode.lib.middleware import simplevcs
34 34
35 35 log = logging.getLogger(__name__)
36 36
37 37
38 38 GIT_PROTO_PAT = re.compile(
39 39 r'^/(.+)/(info/refs|info/lfs/(.+)|git-upload-pack|git-receive-pack)')
40 40 GIT_LFS_PROTO_PAT = re.compile(r'^/(.+)/(info/lfs/(.+))')
41 41
42 42
43 43 def default_lfs_store():
44 44 """
45 45 Default LFS store location; it's consistent with Mercurial's largefiles
46 46 store, which is in .cache/largefiles
47 47 """
48 48 from rhodecode.lib.vcs.backends.git import lfs_store
49 49 user_home = os.path.expanduser("~")
50 50 return lfs_store(user_home)
51 51
52 52
53 53 class SimpleGit(simplevcs.SimpleVCS):
54 54
55 55 SCM = 'git'
56 56
57 57 def _get_repository_name(self, environ):
58 58 """
59 59 Gets repository name out of PATH_INFO header
60 60
61 61 :param environ: environ where PATH_INFO is stored
62 62 """
63 63 repo_name = GIT_PROTO_PAT.match(environ['PATH_INFO']).group(1)
64 64 # for GIT LFS and bare format, strip the .git suffix from names
65 65 if repo_name.endswith('.git'):
66 66 repo_name = repo_name[:-4]
67 67 return repo_name
68 68
69 69 def _get_lfs_action(self, path, request_method):
70 70 """
71 71 return an action based on LFS requests type.
72 72 Those routes are handled inside vcsserver app.
73 73
74 74 batch -> POST to /info/lfs/objects/batch => PUSH/PULL
75 75 batch is based on the `operation`,
76 76 which could be download or upload, but those are only
77 77 instructions to fetch, so we always return pull
78 78
79 79 download -> GET to /info/lfs/{oid} => PULL
80 80 upload -> PUT to /info/lfs/{oid} => PUSH
81 81
82 82 verification -> POST to /info/lfs/verify => PULL
83 83
84 84 """
85 85
86 86 match_obj = GIT_LFS_PROTO_PAT.match(path)
87 87 _parts = match_obj.groups()
88 88 repo_name, path, operation = _parts
89 89 log.debug(
90 90 'LFS: detecting operation based on following '
91 91 'data: %s, req_method:%s', _parts, request_method)
92 92
93 93 if operation == 'verify':
94 94 return 'pull'
95 95 elif operation == 'objects/batch':
96 96 # batch sends back instructions for the API to download/upload; we report it
97 97 # as pull
98 98 if request_method == 'POST':
99 99 return 'pull'
100 100
101 101 elif operation:
102 102 # probably an OID; upload is a PUT, download a GET
103 103 if request_method == 'GET':
104 104 return 'pull'
105 105 else:
106 106 return 'push'
107 107
108 108 # if no match is found, default to push as the action
109 109 return 'push'
110 110
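# Illustrative request -> action mapping, following the docstring above
# (repo name and oid are sample values):
#
#   POST /myrepo/info/lfs/objects/batch -> 'pull'
#   POST /myrepo/info/lfs/verify        -> 'pull'
#   GET  /myrepo/info/lfs/<oid>         -> 'pull'
#   PUT  /myrepo/info/lfs/<oid>         -> 'push'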
111 111 _ACTION_MAPPING = {
112 112 'git-receive-pack': 'push',
113 113 'git-upload-pack': 'pull',
114 114 }
115 115
116 116 def _get_action(self, environ):
117 117 """
118 118 Maps git request commands into a pull or push command.
119 119 In case of unknown/unexpected data, it returns 'pull' to be safe.
120 120
121 121 :param environ:
122 122 """
123 123 path = environ['PATH_INFO']
124 124
125 125 if path.endswith('/info/refs'):
126 query = urllib.parse.urlparse.parse_qs(environ['QUERY_STRING'])
126 query = urllib.parse.parse_qs(environ['QUERY_STRING'])
127 127 service_cmd = query.get('service', [''])[0]
128 128 return self._ACTION_MAPPING.get(service_cmd, 'pull')
129 129
130 130 elif GIT_LFS_PROTO_PAT.match(environ['PATH_INFO']):
131 131 return self._get_lfs_action(
132 132 environ['PATH_INFO'], environ['REQUEST_METHOD'])
133 133
134 134 elif path.endswith('/git-receive-pack'):
135 135 return 'push'
136 136 elif path.endswith('/git-upload-pack'):
137 137 return 'pull'
138 138
139 139 return 'pull'
140 140
141 141 def _create_wsgi_app(self, repo_path, repo_name, config):
142 142 return self.scm_app.create_git_wsgi_app(
143 143 repo_path, repo_name, config)
144 144
145 145 def _create_config(self, extras, repo_name, scheme='http'):
146 146 extras['git_update_server_info'] = utils2.str2bool(
147 147 rhodecode.CONFIG.get('git_update_server_info'))
148 148
149 149 config = utils.make_db_config(repo=repo_name)
150 150 custom_store = config.get('vcs_git_lfs', 'store_location')
151 151
152 152 extras['git_lfs_enabled'] = utils2.str2bool(
153 153 config.get('vcs_git_lfs', 'enabled'))
154 154 extras['git_lfs_store_path'] = custom_store or default_lfs_store()
155 155 extras['git_lfs_http_scheme'] = scheme
156 156 return extras
@@ -1,160 +1,159 b''
1 1 # -*- coding: utf-8 -*-
2 2
3 3 # Copyright (C) 2010-2020 RhodeCode GmbH
4 4 #
5 5 # This program is free software: you can redistribute it and/or modify
6 6 # it under the terms of the GNU Affero General Public License, version 3
7 7 # (only), as published by the Free Software Foundation.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU Affero General Public License
15 15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 16 #
17 17 # This program is dual-licensed. If you wish to learn more about the
18 18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20 20
21 21 """
22 22 SimpleHG middleware for handling mercurial protocol request
23 23 (push/clone etc.). It's implemented with basic auth function
24 24 """
25 25
26 26 import logging
27 27 import urllib.parse
28 28 import urllib.request, urllib.parse, urllib.error
29 29
30 30 from rhodecode.lib import utils
31 31 from rhodecode.lib.ext_json import json
32 32 from rhodecode.lib.middleware import simplevcs
33 33
34 34 log = logging.getLogger(__name__)
35 35
36 36
37 37 class SimpleHg(simplevcs.SimpleVCS):
38 38
39 39 SCM = 'hg'
40 40
41 41 def _get_repository_name(self, environ):
42 42 """
43 43 Gets repository name out of PATH_INFO header
44 44
45 45 :param environ: environ where PATH_INFO is stored
46 46 """
47 47 repo_name = environ['PATH_INFO']
48 48 if repo_name and repo_name.startswith('/'):
49 49 # remove only the first leading /
50 50 repo_name = repo_name[1:]
51 51 return repo_name.rstrip('/')
52 52
53 53 _ACTION_MAPPING = {
54 54 'changegroup': 'pull',
55 55 'changegroupsubset': 'pull',
56 56 'getbundle': 'pull',
57 57 'stream_out': 'pull',
58 58 'listkeys': 'pull',
59 59 'between': 'pull',
60 60 'branchmap': 'pull',
61 61 'branches': 'pull',
62 62 'clonebundles': 'pull',
63 63 'capabilities': 'pull',
64 64 'debugwireargs': 'pull',
65 65 'heads': 'pull',
66 66 'lookup': 'pull',
67 67 'hello': 'pull',
68 68 'known': 'pull',
69 69
70 70 # largefiles
71 71 'putlfile': 'push',
72 72 'getlfile': 'pull',
73 73 'statlfile': 'pull',
74 74 'lheads': 'pull',
75 75
76 76 # evolve
77 77 'evoext_obshashrange_v1': 'pull',
78 78 'evoext_obshash': 'pull',
79 79 'evoext_obshash1': 'pull',
80 80
81 81 'unbundle': 'push',
82 82 'pushkey': 'push',
83 83 }
84 84
85 85 @classmethod
86 86 def _get_xarg_headers(cls, environ):
87 87 i = 1
88 88 chunks = [] # gather chunks stored in multiple 'hgarg_N'
89 89 while True:
90 90 head = environ.get('HTTP_X_HGARG_{}'.format(i))
91 91 if not head:
92 92 break
93 93 i += 1
94 94 chunks.append(urllib.parse.unquote_plus(head))
95 95 full_arg = ''.join(chunks)
96 96 pref = 'cmds='
97 97 if full_arg.startswith(pref):
98 98 # strip the cmds= header defining our batch commands
99 99 full_arg = full_arg[len(pref):]
100 100 cmds = full_arg.split(';')
101 101 return cmds
102 102
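# Worked example with a sample batch header (the exact value is illustrative);
# unquoting and stripping the 'cmds=' prefix yields the command list.
environ = {'HTTP_X_HGARG_1': 'cmds=heads%20%3Bknown%20nodes%3D'}
assert SimpleHg._get_xarg_headers(environ) == ['heads ', 'known nodes=']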
103 103 @classmethod
104 104 def _get_batch_cmd(cls, environ):
105 105 """
106 106 Handle commands sent by the batch command. Those are ';'-separated
107 107 commands that the server needs to execute. We need to extract
108 108 those, and map them to our ACTION_MAPPING to get all push/pull commands
109 109 specified in the batch
110 110 """
111 111 default = 'push'
112 112 batch_cmds = []
113 113 try:
114 114 cmds = cls._get_xarg_headers(environ)
115 115 for pair in cmds:
116 116 parts = pair.split(' ', 1)
117 117 if len(parts) != 2:
118 118 continue
119 119 # entry should be in a format `key ARGS`
120 120 cmd, args = parts
121 121 action = cls._ACTION_MAPPING.get(cmd, default)
122 122 batch_cmds.append(action)
123 123 except Exception:
124 124 log.exception('Failed to extract batch commands operations')
125 125
126 126 # in case we failed (e.g. malformed data), assume it's a PUSH sub-command
127 127 # for safety
128 128 return batch_cmds or [default]
129 129
130 130 def _get_action(self, environ):
131 131 """
132 132 Maps mercurial request commands into a pull or push command.
133 133 In case of unknown/unexpected data, it returns 'push' to be safe.
134 134
135 135 :param environ:
136 136 """
137 137 default = 'push'
138 query = urllib.parse.urlparse.parse_qs(environ['QUERY_STRING'],
139 keep_blank_values=True)
138 query = urllib.parse.parse_qs(environ['QUERY_STRING'], keep_blank_values=True)
140 139
141 140 if 'cmd' in query:
142 141 cmd = query['cmd'][0]
143 142 if cmd == 'batch':
144 143 cmds = self._get_batch_cmd(environ)
145 144 if 'push' in cmds:
146 145 return 'push'
147 146 else:
148 147 return 'pull'
149 148 return self._ACTION_MAPPING.get(cmd, default)
150 149
151 150 return default
152 151
153 152 def _create_wsgi_app(self, repo_path, repo_name, config):
154 153 return self.scm_app.create_hg_wsgi_app(repo_path, repo_name, config)
155 154
156 155 def _create_config(self, extras, repo_name, scheme='http'):
157 156 config = utils.make_db_config(repo=repo_name)
158 157 config.set('rhodecode', 'RC_SCM_DATA', json.dumps(extras))
159 158
160 159 return config.serialize()
@@ -1,679 +1,679 b''
1 1 # -*- coding: utf-8 -*-
2 2
3 3 # Copyright (C) 2014-2020 RhodeCode GmbH
4 4 #
5 5 # This program is free software: you can redistribute it and/or modify
6 6 # it under the terms of the GNU Affero General Public License, version 3
7 7 # (only), as published by the Free Software Foundation.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU Affero General Public License
15 15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 16 #
17 17 # This program is dual-licensed. If you wish to learn more about the
18 18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20 20
21 21 """
22 22 SimpleVCS middleware for handling protocol requests (push/clone etc.).
23 23 It's implemented on top of a basic auth function.
24 24 """
25 25
26 26 import os
27 27 import re
28 import io
28 29 import logging
29 30 import importlib
30 31 from functools import wraps
31 from io import StringIO
32 32 from lxml import etree
33 33
34 34 import time
35 35 from paste.httpheaders import REMOTE_USER, AUTH_TYPE
36 36
37 37 from pyramid.httpexceptions import (
38 38 HTTPNotFound, HTTPForbidden, HTTPNotAcceptable, HTTPInternalServerError)
39 39 from zope.cachedescriptors.property import Lazy as LazyProperty
40 40
41 41 import rhodecode
42 42 from rhodecode.authentication.base import authenticate, VCS_TYPE, loadplugin
43 43 from rhodecode.lib import rc_cache
44 44 from rhodecode.lib.auth import AuthUser, HasPermissionAnyMiddleware
45 45 from rhodecode.lib.base import (
46 46 BasicAuth, get_ip_addr, get_user_agent, vcs_operation_context)
47 47 from rhodecode.lib.exceptions import (UserCreationError, NotAllowedToCreateUserError)
48 48 from rhodecode.lib.hooks_daemon import prepare_callback_daemon
49 49 from rhodecode.lib.middleware import appenlight
50 50 from rhodecode.lib.middleware.utils import scm_app_http
51 51 from rhodecode.lib.utils import is_valid_repo, SLUG_RE
52 52 from rhodecode.lib.utils2 import safe_str, fix_PATH, str2bool, safe_unicode
53 53 from rhodecode.lib.vcs.conf import settings as vcs_settings
54 54 from rhodecode.lib.vcs.backends import base
55 55
56 56 from rhodecode.model import meta
57 57 from rhodecode.model.db import User, Repository, PullRequest
58 58 from rhodecode.model.scm import ScmModel
59 59 from rhodecode.model.pull_request import PullRequestModel
60 60 from rhodecode.model.settings import SettingsModel, VcsSettingsModel
61 61
62 62 log = logging.getLogger(__name__)
63 63
64 64
65 65 def extract_svn_txn_id(acl_repo_name, data):
66 66 """
67 67 Helper method for extracting the svn txn_id from XML data submitted
68 68 during POST operations
69 69 """
70 70 try:
71 71 root = etree.fromstring(data)
72 72 pat = re.compile(r'/txn/(?P<txn_id>.*)')
73 73 for el in root:
74 74 if el.tag == '{DAV:}source':
75 75 for sub_el in el:
76 76 if sub_el.tag == '{DAV:}href':
77 77 match = pat.search(sub_el.text)
78 78 if match:
79 79 svn_tx_id = match.groupdict()['txn_id']
80 80 txn_id = rc_cache.utils.compute_key_from_params(
81 81 acl_repo_name, svn_tx_id)
82 82 return txn_id
83 83 except Exception:
84 84 log.exception('Failed to extract txn_id')
85 85
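As a rough illustration of the extraction above (the txn id and repo name are made up): a DAV MERGE body like the following carries the href `/repo/!svn/txn/120-4b`, from which the `/txn/(?P<txn_id>.*)` pattern pulls `120-4b` before it is combined with the repo name into a cache key:

    data = b'''<?xml version="1.0" encoding="utf-8"?>
    <D:merge xmlns:D="DAV:">
      <D:source><D:href>/repo/!svn/txn/120-4b</D:href></D:source>
    </D:merge>'''
    # extract_svn_txn_id('my-repo', data)
    # -> rc_cache.utils.compute_key_from_params('my-repo', '120-4b')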
86 86
87 87 def initialize_generator(factory):
88 88 """
89 89 Initializes the returned generator by draining its first element.
90 90
91 91 This can be used to give a generator an initializer, which is the code
92 92 up to the first yield statement. This decorator enforces that the first
93 93 produced element has the value ``"__init__"`` to make its special
94 94 purpose explicit in the calling code.
95 95 """
96 96
97 97 @wraps(factory)
98 98 def wrapper(*args, **kwargs):
99 99 gen = factory(*args, **kwargs)
100 100 try:
101 101 init = next(gen)
102 102 except StopIteration:
103 103 raise ValueError('Generator must yield at least one element.')
104 104 if init != "__init__":
105 105 raise ValueError('First yielded element must be "__init__".')
106 106 return gen
107 107 return wrapper
108 108
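A minimal usage sketch of the decorator above; `open_connection` is a hypothetical setup step, and the point is that it runs as soon as the generator is created rather than on first iteration:

    @initialize_generator
    def respond():
        conn = open_connection()  # hypothetical setup, executed eagerly
        yield "__init__"          # drained by the decorator
        yield "first real chunk"

    gen = respond()   # setup has already run at this point
    next(gen)         # -> "first real chunk"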
109 109
110 110 class SimpleVCS(object):
111 111 """Common functionality for SCM HTTP handlers."""
112 112
113 113 SCM = 'unknown'
114 114
115 115 acl_repo_name = None
116 116 url_repo_name = None
117 117 vcs_repo_name = None
118 118 rc_extras = {}
119 119
120 120 # We have to handle requests to shadow repositories differently from
121 121 # requests to normal repositories, so we have to distinguish them. To do
122 122 # this we use this regex, which will match only URLs pointing to shadow
123 123 # repositories.
124 124 shadow_repo_re = re.compile(
125 '(?P<groups>(?:{slug_pat}/)*)' # repo groups
126 '(?P<target>{slug_pat})/' # target repo
127 'pull-request/(?P<pr_id>\d+)/' # pull request
128 'repository$' # shadow repo
125 '(?P<groups>(?:{slug_pat}/)*)' # repo groups
126 '(?P<target>{slug_pat})/' # target repo
127 'pull-request/(?P<pr_id>\\d+)/' # pull request
128 'repository$' # shadow repo
129 129 .format(slug_pat=SLUG_RE.pattern))
130 130
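For example, assuming SLUG_RE matches plain repo names, a shadow-repository URL matches like this, while an ordinary repo path does not match at all:

    m = SimpleVCS.shadow_repo_re.match('RepoGroup/MyRepo/pull-request/3/repository')
    # m.groupdict() -> {'groups': 'RepoGroup/', 'target': 'MyRepo', 'pr_id': '3'}
    SimpleVCS.shadow_repo_re.match('RepoGroup/MyRepo')  # -> None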
131 131 def __init__(self, config, registry):
132 132 self.registry = registry
133 133 self.config = config
134 134 # re-populated by specialized middleware
135 135 self.repo_vcs_config = base.Config()
136 136
137 137 rc_settings = SettingsModel().get_all_settings(cache=True, from_request=False)
138 138 realm = rc_settings.get('rhodecode_realm') or 'RhodeCode AUTH'
139 139
140 140 # authenticate this VCS request using authfunc
141 141 auth_ret_code_detection = \
142 142 str2bool(self.config.get('auth_ret_code_detection', False))
143 143 self.authenticate = BasicAuth(
144 144 '', authenticate, registry, config.get('auth_ret_code'),
145 145 auth_ret_code_detection, rc_realm=realm)
146 146 self.ip_addr = '0.0.0.0'
147 147
148 148 @LazyProperty
149 149 def global_vcs_config(self):
150 150 try:
151 151 return VcsSettingsModel().get_ui_settings_as_config_obj()
152 152 except Exception:
153 153 return base.Config()
154 154
155 155 @property
156 156 def base_path(self):
157 157 settings_path = self.repo_vcs_config.get(*VcsSettingsModel.PATH_SETTING)
158 158
159 159 if not settings_path:
160 160 settings_path = self.global_vcs_config.get(*VcsSettingsModel.PATH_SETTING)
161 161
162 162 if not settings_path:
163 163 # try, maybe it was passed in explicitly as a config option
164 164 settings_path = self.config.get('base_path')
165 165
166 166 if not settings_path:
167 167 raise ValueError('FATAL: base_path is empty')
168 168 return settings_path
169 169
170 170 def set_repo_names(self, environ):
171 171 """
172 172 This will populate the attributes acl_repo_name, url_repo_name,
173 173 vcs_repo_name and is_shadow_repo. In case of requests to normal (non
174 174 shadow) repositories all names are equal. In case of requests to a
175 175 shadow repository the acl-name points to the target repo of the pull
176 176 request and the vcs-name points to the shadow repo file system path.
177 177 The url-name is always the URL used by the vcs client program.
178 178
179 179 Example in case of a shadow repo:
180 180 acl_repo_name = RepoGroup/MyRepo
181 181 url_repo_name = RepoGroup/MyRepo/pull-request/3/repository
182 182 vcs_repo_name = /repo/base/path/RepoGroup/.__shadow_MyRepo_pr-3
183 183 """
184 184 # First we set the repo name from URL for all attributes. This is the
185 185 # default if handling normal (non shadow) repo requests.
186 186 self.url_repo_name = self._get_repository_name(environ)
187 187 self.acl_repo_name = self.vcs_repo_name = self.url_repo_name
188 188 self.is_shadow_repo = False
189 189
190 190 # Check if this is a request to a shadow repository.
191 191 match = self.shadow_repo_re.match(self.url_repo_name)
192 192 if match:
193 193 match_dict = match.groupdict()
194 194
195 195 # Build acl repo name from regex match.
196 196 acl_repo_name = safe_unicode('{groups}{target}'.format(
197 197 groups=match_dict['groups'] or '',
198 198 target=match_dict['target']))
199 199
200 200 # Retrieve pull request instance by ID from regex match.
201 201 pull_request = PullRequest.get(match_dict['pr_id'])
202 202
203 203 # Only proceed if we got a pull request and if acl repo name from
204 204 # URL equals the target repo name of the pull request.
205 205 if pull_request and (acl_repo_name == pull_request.target_repo.repo_name):
206 206
207 207 # Get file system path to shadow repository.
208 208 workspace_id = PullRequestModel()._workspace_id(pull_request)
209 209 vcs_repo_name = pull_request.target_repo.get_shadow_repository_path(workspace_id)
210 210
211 211 # Store names for later usage.
212 212 self.vcs_repo_name = vcs_repo_name
213 213 self.acl_repo_name = acl_repo_name
214 214 self.is_shadow_repo = True
215 215
216 216 log.debug('Setting all VCS repository names: %s', {
217 217 'acl_repo_name': self.acl_repo_name,
218 218 'url_repo_name': self.url_repo_name,
219 219 'vcs_repo_name': self.vcs_repo_name,
220 220 })
221 221
222 222 @property
223 223 def scm_app(self):
224 224 custom_implementation = self.config['vcs.scm_app_implementation']
225 225 if custom_implementation == 'http':
226 226 log.debug('Using HTTP implementation of scm app.')
227 227 scm_app_impl = scm_app_http
228 228 else:
229 229 log.debug('Using custom implementation of scm_app: "{}"'.format(
230 230 custom_implementation))
231 231 scm_app_impl = importlib.import_module(custom_implementation)
232 232 return scm_app_impl
233 233
234 234 def _get_by_id(self, repo_name):
235 235 """
236 236 Gets the special pattern _<ID> from the clone url and tries to replace it
237 237 with a repository name, to support non-changeable _<ID> urls
238 238 """
239 239
240 240 data = repo_name.split('/')
241 241 if len(data) >= 2:
242 242 from rhodecode.model.repo import RepoModel
243 243 by_id_match = RepoModel().get_repo_by_id(repo_name)
244 244 if by_id_match:
245 245 data[1] = by_id_match.repo_name
246 246
247 247 return safe_str('/'.join(data))
248 248
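A sketch of the translation above (the id and repo name are illustrative): a PATH_INFO of `/_23` splits into `['', '_23']`, and if `RepoModel().get_repo_by_id()` resolves the id, the second element is swapped for the real repository name:

    # '/_23'      -> '/RepoGroup/MyRepo'  (if repo id 23 resolves)
    # '/no/match' -> '/no/match'          (left untouched)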
249 249 def _invalidate_cache(self, repo_name):
250 250 """
251 251 Sets the cache for this repository for invalidation on next access
252 252
253 253 :param repo_name: full repo name, also a cache key
254 254 """
255 255 ScmModel().mark_for_invalidation(repo_name)
256 256
257 257 def is_valid_and_existing_repo(self, repo_name, base_path, scm_type):
258 258 db_repo = Repository.get_by_repo_name(repo_name)
259 259 if not db_repo:
260 260 log.debug('Repository `%s` not found inside the database.',
261 261 repo_name)
262 262 return False
263 263
264 264 if db_repo.repo_type != scm_type:
265 265 log.warning(
266 266 'Repository `%s` has incorrect scm_type, expected %s got %s',
267 267 repo_name, scm_type, db_repo.repo_type)
268 268 return False
269 269
270 270 config = db_repo._config
271 271 config.set('extensions', 'largefiles', '')
272 272 return is_valid_repo(
273 273 repo_name, base_path,
274 274 explicit_scm=scm_type, expect_scm=scm_type, config=config)
275 275
276 276 def valid_and_active_user(self, user):
277 277 """
278 278 Checks that the user is not empty and, if it is an actual user
279 279 object, that it is active.
280 280
281 281 :param user: user object or None
282 282 :return: boolean
283 283 """
284 284 if user is None:
285 285 return False
286 286
287 287 elif user.active:
288 288 return True
289 289
290 290 return False
291 291
292 292 @property
293 293 def is_shadow_repo_dir(self):
294 294 return os.path.isdir(self.vcs_repo_name)
295 295
296 296 def _check_permission(self, action, user, auth_user, repo_name, ip_addr=None,
297 297 plugin_id='', plugin_cache_active=False, cache_ttl=0):
298 298 """
299 299 Checks permissions using the action (push/pull), user and repository
300 300 name. If plugin_cache_active and cache_ttl are set, it will use the
301 301 plugin which authenticated the user to store the cached permissions
302 302 result for cache_ttl seconds
303 303
304 304 :param action: push or pull action
305 305 :param user: user instance
306 306 :param repo_name: repository name
307 307 """
308 308
309 309 log.debug('AUTH_CACHE_TTL for permissions `%s` active: %s (TTL: %s)',
310 310 plugin_id, plugin_cache_active, cache_ttl)
311 311
312 312 user_id = user.user_id
313 313 cache_namespace_uid = 'cache_user_auth.{}'.format(user_id)
314 314 region = rc_cache.get_or_create_region('cache_perms', cache_namespace_uid)
315 315
316 316 @region.conditional_cache_on_arguments(namespace=cache_namespace_uid,
317 317 expiration_time=cache_ttl,
318 318 condition=plugin_cache_active)
319 319 def compute_perm_vcs(
320 320 cache_name, plugin_id, action, user_id, repo_name, ip_addr):
321 321
322 322 log.debug('auth: calculating permission access now...')
323 323 # check IP
324 324 inherit = user.inherit_default_permissions
325 325 ip_allowed = AuthUser.check_ip_allowed(
326 326 user_id, ip_addr, inherit_from_default=inherit)
327 327 if ip_allowed:
328 328 log.info('Access for IP:%s allowed', ip_addr)
329 329 else:
330 330 return False
331 331
332 332 if action == 'push':
333 333 perms = ('repository.write', 'repository.admin')
334 334 if not HasPermissionAnyMiddleware(*perms)(auth_user, repo_name):
335 335 return False
336 336
337 337 else:
338 338 # any other action need at least read permission
339 339 perms = (
340 340 'repository.read', 'repository.write', 'repository.admin')
341 341 if not HasPermissionAnyMiddleware(*perms)(auth_user, repo_name):
342 342 return False
343 343
344 344 return True
345 345
346 346 start = time.time()
347 347 log.debug('Running plugin `%s` permissions check', plugin_id)
348 348
349 349 # for environ based auth, password can be empty, but then the validation is
350 350 # on the server that fills in the env data needed for authentication
351 351 perm_result = compute_perm_vcs(
352 352 'vcs_permissions', plugin_id, action, user.user_id, repo_name, ip_addr)
353 353
354 354 auth_time = time.time() - start
355 355 log.debug('Permissions for plugin `%s` completed in %.4fs, '
356 356 'expiration time of fetched cache %.1fs.',
357 357 plugin_id, auth_time, cache_ttl)
358 358
359 359 return perm_result
360 360
361 361 def _get_http_scheme(self, environ):
362 362 try:
363 363 return environ['wsgi.url_scheme']
364 364 except Exception:
365 365 log.exception('Failed to read http scheme')
366 366 return 'http'
367 367
368 368 def _check_ssl(self, environ, start_response):
369 369 """
370 370 Checks the SSL check flag and returns False if SSL is not present
371 371 and required True otherwise
372 372 """
373 373 org_proto = environ['wsgi._org_proto']
374 374 # check if we have SSL required! if it's missing, it's a bad request!
375 375 require_ssl = str2bool(self.repo_vcs_config.get('web', 'push_ssl'))
376 376 if require_ssl and org_proto == 'http':
377 377 log.debug(
378 378 'Bad request: detected protocol is `%s` and '
379 379 'SSL/HTTPS is required.', org_proto)
380 380 return False
381 381 return True
382 382
383 383 def _get_default_cache_ttl(self):
384 384 # take AUTH_CACHE_TTL from the `rhodecode` auth plugin
385 385 plugin = loadplugin('egg:rhodecode-enterprise-ce#rhodecode')
386 386 plugin_settings = plugin.get_settings()
387 387 plugin_cache_active, cache_ttl = plugin.get_ttl_cache(
388 388 plugin_settings) or (False, 0)
389 389 return plugin_cache_active, cache_ttl
390 390
391 391 def __call__(self, environ, start_response):
392 392 try:
393 393 return self._handle_request(environ, start_response)
394 394 except Exception:
395 395 log.exception("Exception while handling request")
396 396 appenlight.track_exception(environ)
397 397 return HTTPInternalServerError()(environ, start_response)
398 398 finally:
399 399 meta.Session.remove()
400 400
401 401 def _handle_request(self, environ, start_response):
402 402 if not self._check_ssl(environ, start_response):
403 403 reason = ('SSL required, while RhodeCode was unable '
404 404 'to detect this as an SSL request')
405 405 log.debug('User not allowed to proceed, %s', reason)
406 406 return HTTPNotAcceptable(reason)(environ, start_response)
407 407
408 408 if not self.url_repo_name:
409 409 log.warning('Repository name is empty: %s', self.url_repo_name)
410 410 # failed to get repo name, we fail now
411 411 return HTTPNotFound()(environ, start_response)
412 412 log.debug('Extracted repo name is %s', self.url_repo_name)
413 413
414 414 ip_addr = get_ip_addr(environ)
415 415 user_agent = get_user_agent(environ)
416 416 username = None
417 417
418 418 # skip passing error to error controller
419 419 environ['pylons.status_code_redirect'] = True
420 420
421 421 # ======================================================================
422 422 # GET ACTION PULL or PUSH
423 423 # ======================================================================
424 424 action = self._get_action(environ)
425 425
426 426 # ======================================================================
427 427 # Check if this is a request to a shadow repository of a pull request.
428 428 # In this case only pull action is allowed.
429 429 # ======================================================================
430 430 if self.is_shadow_repo and action != 'pull':
431 431 reason = 'Only pull action is allowed for shadow repositories.'
432 432 log.debug('User not allowed to proceed, %s', reason)
433 433 return HTTPNotAcceptable(reason)(environ, start_response)
434 434
435 435 # Check if the shadow repo actually exists, in case someone refers
436 436 # to it after it has been deleted because of a successful merge.
437 437 if self.is_shadow_repo and not self.is_shadow_repo_dir:
438 438 log.debug(
439 439 'Shadow repo detected, and shadow repo dir `%s` is missing',
440 440 self.vcs_repo_name)
441 441 return HTTPNotFound()(environ, start_response)
442 442
443 443 # ======================================================================
444 444 # CHECK ANONYMOUS PERMISSION
445 445 # ======================================================================
446 446 detect_force_push = False
447 447 check_branch_perms = False
448 448 if action in ['pull', 'push']:
449 449 user_obj = anonymous_user = User.get_default_user()
450 450 auth_user = user_obj.AuthUser()
451 451 username = anonymous_user.username
452 452 if anonymous_user.active:
453 453 plugin_cache_active, cache_ttl = self._get_default_cache_ttl()
454 454 # ONLY check permissions if the user is activated
455 455 anonymous_perm = self._check_permission(
456 456 action, anonymous_user, auth_user, self.acl_repo_name, ip_addr,
457 457 plugin_id='anonymous_access',
458 458 plugin_cache_active=plugin_cache_active,
459 459 cache_ttl=cache_ttl,
460 460 )
461 461 else:
462 462 anonymous_perm = False
463 463
464 464 if not anonymous_user.active or not anonymous_perm:
465 465 if not anonymous_user.active:
466 466 log.debug('Anonymous access is disabled, running '
467 467 'authentication')
468 468
469 469 if not anonymous_perm:
470 470 log.debug('Not enough credentials to access this '
471 471 'repository as anonymous user')
472 472
473 473 username = None
474 474 # ==============================================================
475 475 # DEFAULT PERM FAILED OR ANONYMOUS ACCESS IS DISABLED SO WE
476 476 # NEED TO AUTHENTICATE AND ASK FOR AUTH USER PERMISSIONS
477 477 # ==============================================================
478 478
479 479 # try to auth based on environ, container auth methods
480 480 log.debug('Running PRE-AUTH for container based authentication')
481 481 pre_auth = authenticate(
482 482 '', '', environ, VCS_TYPE, registry=self.registry,
483 483 acl_repo_name=self.acl_repo_name)
484 484 if pre_auth and pre_auth.get('username'):
485 485 username = pre_auth['username']
486 486 log.debug('PRE-AUTH got %s as username', username)
487 487 if pre_auth:
488 488 log.debug('PRE-AUTH successful from %s',
489 489 pre_auth.get('auth_data', {}).get('_plugin'))
490 490
491 491 # If not authenticated by the container, run basic auth;
492 492 # before that, inject the calling repo_name for special scope checks
493 493 self.authenticate.acl_repo_name = self.acl_repo_name
494 494
495 495 plugin_cache_active, cache_ttl = False, 0
496 496 plugin = None
497 497 if not username:
498 498 self.authenticate.realm = self.authenticate.get_rc_realm()
499 499
500 500 try:
501 501 auth_result = self.authenticate(environ)
502 502 except (UserCreationError, NotAllowedToCreateUserError) as e:
503 503 log.error(e)
504 504 reason = safe_str(e)
505 505 return HTTPNotAcceptable(reason)(environ, start_response)
506 506
507 507 if isinstance(auth_result, dict):
508 508 AUTH_TYPE.update(environ, 'basic')
509 509 REMOTE_USER.update(environ, auth_result['username'])
510 510 username = auth_result['username']
511 511 plugin = auth_result.get('auth_data', {}).get('_plugin')
512 512 log.info(
513 513 'MAIN-AUTH successful for user `%s` from %s plugin',
514 514 username, plugin)
515 515
516 516 plugin_cache_active, cache_ttl = auth_result.get(
517 517 'auth_data', {}).get('_ttl_cache') or (False, 0)
518 518 else:
519 519 return auth_result.wsgi_application(environ, start_response)
520 520
521 521 # ==============================================================
522 522 # CHECK PERMISSIONS FOR THIS REQUEST USING GIVEN USERNAME
523 523 # ==============================================================
524 524 user = User.get_by_username(username)
525 525 if not self.valid_and_active_user(user):
526 526 return HTTPForbidden()(environ, start_response)
527 527 username = user.username
528 528 user_id = user.user_id
529 529
530 530 # check user attributes for password change flag
531 531 user_obj = user
532 532 auth_user = user_obj.AuthUser()
533 533 if user_obj and user_obj.username != User.DEFAULT_USER and \
534 534 user_obj.user_data.get('force_password_change'):
535 535 reason = 'password change required'
536 536 log.debug('User not allowed to authenticate, %s', reason)
537 537 return HTTPNotAcceptable(reason)(environ, start_response)
538 538
539 539 # check permissions for this repository
540 540 perm = self._check_permission(
541 541 action, user, auth_user, self.acl_repo_name, ip_addr,
542 542 plugin, plugin_cache_active, cache_ttl)
543 543 if not perm:
544 544 return HTTPForbidden()(environ, start_response)
545 545 environ['rc_auth_user_id'] = user_id
546 546
547 547 if action == 'push':
548 548 perms = auth_user.get_branch_permissions(self.acl_repo_name)
549 549 if perms:
550 550 check_branch_perms = True
551 551 detect_force_push = True
552 552
553 553 # extras are injected into UI object and later available
554 554 # in hooks executed by RhodeCode
555 555 check_locking = _should_check_locking(environ.get('QUERY_STRING'))
556 556
557 557 extras = vcs_operation_context(
558 558 environ, repo_name=self.acl_repo_name, username=username,
559 559 action=action, scm=self.SCM, check_locking=check_locking,
560 560 is_shadow_repo=self.is_shadow_repo, check_branch_perms=check_branch_perms,
561 561 detect_force_push=detect_force_push
562 562 )
563 563
564 564 # ======================================================================
565 565 # REQUEST HANDLING
566 566 # ======================================================================
567 567 repo_path = os.path.join(
568 568 safe_str(self.base_path), safe_str(self.vcs_repo_name))
569 569 log.debug('Repository path is %s', repo_path)
570 570
571 571 fix_PATH()
572 572
573 573 log.info(
574 574 '%s action on %s repo "%s" by "%s" from %s %s',
575 575 action, self.SCM, safe_str(self.url_repo_name),
576 576 safe_str(username), ip_addr, user_agent)
577 577
578 578 return self._generate_vcs_response(
579 579 environ, start_response, repo_path, extras, action)
580 580
581 581 @initialize_generator
582 582 def _generate_vcs_response(
583 583 self, environ, start_response, repo_path, extras, action):
584 584 """
585 585 Returns a generator for the response content.
586 586
587 587 This method is implemented as a generator, so that it can trigger
588 588 the cache validation after all content sent back to the client. It
589 589 also handles the locking exceptions which will be triggered when
590 590 the first chunk is produced by the underlying WSGI application.
591 591 """
592 592 txn_id = ''
593 593 if 'CONTENT_LENGTH' in environ and environ['REQUEST_METHOD'] == 'MERGE':
594 594 # SVN case: we want to re-use the callback daemon port, so we use
595 595 # the txn_id; for this we peek at the body, and still save it
596 596 # as wsgi.input
597 597 data = environ['wsgi.input'].read()
598 environ['wsgi.input'] = StringIO(data)
598 environ['wsgi.input'] = io.BytesIO(data)  # wsgi.input carries bytes, not text
599 599 txn_id = extract_svn_txn_id(self.acl_repo_name, data)
600 600
601 601 callback_daemon, extras = self._prepare_callback_daemon(
602 602 extras, environ, action, txn_id=txn_id)
603 603 log.debug('HOOKS extras is %s', extras)
604 604
605 605 http_scheme = self._get_http_scheme(environ)
606 606
607 607 config = self._create_config(extras, self.acl_repo_name, scheme=http_scheme)
608 608 app = self._create_wsgi_app(repo_path, self.url_repo_name, config)
609 609 with callback_daemon:
610 610 app.rc_extras = extras
611 611
612 612 try:
613 613 response = app(environ, start_response)
614 614 finally:
615 615 # This statement works together with the decorator
616 616 # "initialize_generator" above. The decorator ensures that
617 617 # we hit the first yield statement before the generator is
618 618 # returned back to the WSGI server. This is needed to
619 619 # ensure that the call to "app" above triggers the
620 620 # needed callback to "start_response" before the
621 621 # generator is actually used.
622 622 yield "__init__"
623 623
624 624 # iter content
625 625 for chunk in response:
626 626 yield chunk
627 627
628 628 try:
629 629 # invalidate cache on push
630 630 if action == 'push':
631 631 self._invalidate_cache(self.url_repo_name)
632 632 finally:
633 633 meta.Session.remove()
634 634
635 635 def _get_repository_name(self, environ):
636 636 """Get repository name out of the environmnent
637 637
638 638 :param environ: WSGI environment
639 639 """
640 640 raise NotImplementedError()
641 641
642 642 def _get_action(self, environ):
643 643 """Map request commands into a pull or push command.
644 644
645 645 :param environ: WSGI environment
646 646 """
647 647 raise NotImplementedError()
648 648
649 649 def _create_wsgi_app(self, repo_path, repo_name, config):
650 650 """Return the WSGI app that will finally handle the request."""
651 651 raise NotImplementedError()
652 652
653 653 def _create_config(self, extras, repo_name, scheme='http'):
654 654 """Create a safe config representation."""
655 655 raise NotImplementedError()
656 656
657 657 def _should_use_callback_daemon(self, extras, environ, action):
658 658 if extras.get('is_shadow_repo'):
659 659 # we don't want to execute hooks or the callback daemon for shadow repos
660 660 return False
661 661 return True
662 662
663 663 def _prepare_callback_daemon(self, extras, environ, action, txn_id=None):
664 664 direct_calls = vcs_settings.HOOKS_DIRECT_CALLS
665 665 if not self._should_use_callback_daemon(extras, environ, action):
666 666 # disable callback daemon for actions that don't require it
667 667 direct_calls = True
668 668
669 669 return prepare_callback_daemon(
670 670 extras, protocol=vcs_settings.HOOKS_PROTOCOL,
671 671 host=vcs_settings.HOOKS_HOST, use_direct_calls=direct_calls, txn_id=txn_id)
672 672
673 673
674 674 def _should_check_locking(query_string):
675 675 # this is kind of hacky, but due to how mercurial handles the
676 676 # client-server exchange, the server sees operations on commits,
677 677 # bookmarks, phases and obsolescence markers in different
678 678 # transactions; we don't want to check locking on those
679 679 return query_string not in ['cmd=listkeys']
@@ -1,284 +1,284 b''
1 1 # -*- coding: utf-8 -*-
2 2
3 3 # Copyright (C) 2010-2020 RhodeCode GmbH
4 4 #
5 5 # This program is free software: you can redistribute it and/or modify
6 6 # it under the terms of the GNU Affero General Public License, version 3
7 7 # (only), as published by the Free Software Foundation.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU Affero General Public License
15 15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 16 #
17 17 # This program is dual-licensed. If you wish to learn more about the
18 18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20 20
21 21 import gzip
22 22 import shutil
23 23 import logging
24 24 import tempfile
25 25 import urllib.parse
26 26
27 27 from webob.exc import HTTPNotFound
28 28
29 29 import rhodecode
30 30 from rhodecode.lib.middleware.appenlight import wrap_in_appenlight_if_enabled
31 31 from rhodecode.lib.middleware.simplegit import SimpleGit, GIT_PROTO_PAT
32 32 from rhodecode.lib.middleware.simplehg import SimpleHg
33 33 from rhodecode.lib.middleware.simplesvn import SimpleSvn
34 34 from rhodecode.model.settings import VcsSettingsModel
35 35
36 36 log = logging.getLogger(__name__)
37 37
38 38 VCS_TYPE_KEY = '_rc_vcs_type'
39 39 VCS_TYPE_SKIP = '_rc_vcs_skip'
40 40
41 41
42 42 def is_git(environ):
43 43 """
44 44 Returns True if the request should be handled by the GIT wsgi middleware
45 45 """
46 46 is_git_path = GIT_PROTO_PAT.match(environ['PATH_INFO'])
47 47 log.debug(
48 48 'request path: `%s` detected as GIT PROTOCOL %s', environ['PATH_INFO'],
49 49 is_git_path is not None)
50 50
51 51 return is_git_path
52 52
53 53
54 54 def is_hg(environ):
55 55 """
56 56 Returns True if the request's target is a mercurial server - the header
57 57 ``HTTP_ACCEPT`` of such a request starts with ``application/mercurial``.
58 58 """
59 59 is_hg_path = False
60 60
61 61 http_accept = environ.get('HTTP_ACCEPT')
62 62
63 63 if http_accept and http_accept.startswith('application/mercurial'):
64 query = urllib.parse.urlparse.parse_qs(environ['QUERY_STRING'])
64 query = urllib.parse.parse_qs(environ['QUERY_STRING'])
65 65 if 'cmd' in query:
66 66 is_hg_path = True
67 67
68 68 log.debug(
69 69 'request path: `%s` detected as HG PROTOCOL %s', environ['PATH_INFO'],
70 70 is_hg_path)
71 71
72 72 return is_hg_path
73 73
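For instance, a typical Mercurial client request would be detected like this (header values are illustrative):

    environ = {
        'PATH_INFO': '/RepoGroup/MyRepo',
        'HTTP_ACCEPT': 'application/mercurial-0.1',
        'QUERY_STRING': 'cmd=capabilities',
    }
    # is_hg(environ) -> True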
74 74
75 75 def is_svn(environ):
76 76 """
77 77 Returns True if the request's target is a Subversion server
78 78 """
79 79
80 80 http_dav = environ.get('HTTP_DAV', '')
81 81 magic_path_segment = rhodecode.CONFIG.get(
82 82 'rhodecode_subversion_magic_path', '/!svn')
83 83 is_svn_path = (
84 84 'subversion' in http_dav or
85 85 magic_path_segment in environ['PATH_INFO']
86 86 or environ['REQUEST_METHOD'] in ['PROPFIND', 'PROPPATCH']
87 87 )
88 88 log.debug(
89 89 'request path: `%s` detected as SVN PROTOCOL %s', environ['PATH_INFO'],
90 90 is_svn_path)
91 91
92 92 return is_svn_path
93 93
94 94
95 95 class GunzipMiddleware(object):
96 96 """
97 97 WSGI middleware that unzips gzip-encoded requests before
98 98 passing on to the underlying application.
99 99 """
100 100
101 101 def __init__(self, application):
102 102 self.app = application
103 103
104 104 def __call__(self, environ, start_response):
105 105 content_encoding_header = environ.get('HTTP_CONTENT_ENCODING', '')
106 106 
107 107 if 'gzip' in content_encoding_header:
108 108 log.debug('gzip detected, now running gunzip wrapper')
109 109 wsgi_input = environ['wsgi.input']
110 110
111 111 if not hasattr(environ['wsgi.input'], 'seek'):
112 112 # The gzip implementation in the standard library of Python 2.x
113 113 # requires the '.seek()' and '.tell()' methods to be available
114 114 # on the input stream. Read the data into a temporary file to
115 115 # work around this limitation.
116 116
117 117 wsgi_input = tempfile.SpooledTemporaryFile(64 * 1024 * 1024)
118 118 shutil.copyfileobj(environ['wsgi.input'], wsgi_input)
119 119 wsgi_input.seek(0)
120 120
121 121 environ['wsgi.input'] = gzip.GzipFile(fileobj=wsgi_input, mode='r')
122 122 # since we've un-gzipped the content, we declare it no longer has a
123 123 # gzip content encoding
124 124 del environ['HTTP_CONTENT_ENCODING']
125 125
126 126 # the content length changed after decompression, so drop the stale header
127 127 if 'CONTENT_LENGTH' in environ:
128 128 del environ['CONTENT_LENGTH']
129 129 else:
130 130 log.debug('content not gzipped, gzipMiddleware passing '
131 131 'request further')
132 132 return self.app(environ, start_response)
133 133
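A minimal sketch of wiring the middleware above around any WSGI application; `inner_app` is a placeholder:

    def inner_app(environ, start_response):
        start_response('200 OK', [('Content-Type', 'text/plain')])
        return [b'ok']

    app = GunzipMiddleware(inner_app)
    # gzip-encoded request bodies now reach inner_app decompressed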
134 134
135 135 def is_vcs_call(environ):
136 136 if VCS_TYPE_KEY in environ:
137 137 raw_type = environ[VCS_TYPE_KEY]
138 138 return raw_type and raw_type != VCS_TYPE_SKIP
139 139 return False
140 140
141 141
142 142 def get_path_elem(route_path):
143 143 if not route_path:
144 144 return None
145 145
146 146 cleaned_route_path = route_path.lstrip('/')
147 147 if cleaned_route_path:
148 148 cleaned_route_path_elems = cleaned_route_path.split('/')
149 149 if cleaned_route_path_elems:
150 150 return cleaned_route_path_elems[0]
151 151 return None
152 152
153 153
154 154 def detect_vcs_request(environ, backends):
155 155 checks = {
156 156 'hg': (is_hg, SimpleHg),
157 157 'git': (is_git, SimpleGit),
158 158 'svn': (is_svn, SimpleSvn),
159 159 }
160 160 handler = None
161 161 # list of first URL path chunks (views) for which we skip all checks
162 162 white_list = [
163 163 # e.g /_file_store/download
164 164 '_file_store',
165 165
166 166 # static files no detection
167 167 '_static',
168 168
169 169 # skip ops ping, status
170 170 '_admin/ops/ping',
171 171 '_admin/ops/status',
172 172
173 173 # full channelstream connect should be VCS skipped
174 174 '_admin/channelstream/connect',
175 175 ]
176 176
177 177 path_info = environ['PATH_INFO']
178 178
179 179 path_elem = get_path_elem(path_info)
180 180
181 181 if path_elem in white_list:
182 182 log.debug('path `%s` in whitelist, skipping...', path_info)
183 183 return handler
184 184
185 185 path_url = path_info.lstrip('/')
186 186 if path_url in white_list:
187 187 log.debug('full url path `%s` in whitelist, skipping...', path_url)
188 188 return handler
189 189
190 190 if VCS_TYPE_KEY in environ:
191 191 raw_type = environ[VCS_TYPE_KEY]
192 192 if raw_type == VCS_TYPE_SKIP:
193 193 log.debug('got `skip` marker for vcs detection, skipping...')
194 194 return handler
195 195
196 196 _check, handler = checks.get(raw_type) or [None, None]
197 197 if handler:
198 198 log.debug('got handler:%s from environ', handler)
199 199
200 200 if not handler:
201 201 log.debug('request start: checking if request for `%s` is of VCS type in order: %s', path_elem, backends)
202 202 for vcs_type in backends:
203 203 vcs_check, _handler = checks[vcs_type]
204 204 if vcs_check(environ):
205 205 log.debug('vcs handler found %s', _handler)
206 206 handler = _handler
207 207 break
208 208
209 209 return handler
210 210
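A sketch of the detection above; with the same illustrative hg environ as before, and backends listed in priority order, the hg check matches first:

    handler = detect_vcs_request(environ, ['hg', 'git', 'svn'])
    # -> SimpleHg, since is_hg() matches before the git/svn checks run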
211 211
212 212 class VCSMiddleware(object):
213 213
214 214 def __init__(self, app, registry, config, appenlight_client):
215 215 self.application = app
216 216 self.registry = registry
217 217 self.config = config
218 218 self.appenlight_client = appenlight_client
219 219 self.use_gzip = True
220 220 # order in which we check the middlewares, based on vcs.backends config
221 221 self.check_middlewares = config['vcs.backends']
222 222
223 223 def vcs_config(self, repo_name=None):
224 224 """
225 225 returns serialized VcsSettings
226 226 """
227 227 try:
228 228 return VcsSettingsModel(
229 229 repo=repo_name).get_ui_settings_as_config_obj()
230 230 except Exception:
231 231 pass
232 232
233 233 def wrap_in_gzip_if_enabled(self, app, config):
234 234 if self.use_gzip:
235 235 app = GunzipMiddleware(app)
236 236 return app
237 237
238 238 def _get_handler_app(self, environ):
239 239 app = None
240 240 log.debug('VCSMiddleware: detecting vcs type.')
241 241 handler = detect_vcs_request(environ, self.check_middlewares)
242 242 if handler:
243 243 app = handler(self.config, self.registry)
244 244
245 245 return app
246 246
247 247 def __call__(self, environ, start_response):
248 248 # check if we handle one of the interesting protocols; optionally extract
249 249 # specific vcsSettings and allow changes of how things are wrapped
250 250 vcs_handler = self._get_handler_app(environ)
251 251 if vcs_handler:
252 252 # translate the _REPO_ID into real repo NAME for usage
253 253 # in middleware
254 254 environ['PATH_INFO'] = vcs_handler._get_by_id(environ['PATH_INFO'])
255 255
256 256 # Set acl, url and vcs repo names.
257 257 vcs_handler.set_repo_names(environ)
258 258
259 259 # register repo config back to the handler
260 260 vcs_conf = self.vcs_config(vcs_handler.acl_repo_name)
261 261 # maybe damaged/non-existent settings. We still want to get past
262 262 # this point, to validate via is_valid_and_existing_repo
263 263 # and return a proper HTTP code back to the client
264 264 if vcs_conf:
265 265 vcs_handler.repo_vcs_config = vcs_conf
266 266
267 267 # check for type, presence in database and on filesystem
268 268 if not vcs_handler.is_valid_and_existing_repo(
269 269 vcs_handler.acl_repo_name,
270 270 vcs_handler.base_path,
271 271 vcs_handler.SCM):
272 272 return HTTPNotFound()(environ, start_response)
273 273
274 274 environ['REPO_NAME'] = vcs_handler.url_repo_name
275 275
276 276 # Wrap handler in middlewares if they are enabled.
277 277 vcs_handler = self.wrap_in_gzip_if_enabled(
278 278 vcs_handler, self.config)
279 279 vcs_handler, _ = wrap_in_appenlight_if_enabled(
280 280 vcs_handler, self.config, self.appenlight_client)
281 281
282 282 return vcs_handler(environ, start_response)
283 283
284 284 return self.application(environ, start_response)
@@ -1,1061 +1,1061 b''
1 1 # -*- coding: utf-8 -*-
2 2
3 3 # Copyright (c) 2007-2012 Christoph Haas <email@christoph-haas.de>
4 4 # NOTE: MIT license based code, backported and edited by RhodeCode GmbH
5 5
6 6 """
7 7 paginate: helps split up large collections into individual pages
8 8 ================================================================
9 9
10 10 What is pagination?
11 11 ---------------------
12 12
13 13 This module helps split large lists of items into pages. The user is shown one page at a time and
14 14 can navigate to other pages. Imagine you offer a company phonebook and let the user search
15 15 the entries. The entire search result may contain 23 entries but you want to display no more than
16 16 10 entries at once. The first page contains entries 1-10, the second 11-20 and the third 21-23.
17 17 Each "Page" instance represents the items of one of these three pages.
18 18
19 19 See the documentation of the "Page" class for more information.
20 20
21 21 How do I use it?
22 22 ------------------
23 23
24 24 A page of items is represented by the *Page* object. A *Page* gets initialized with these arguments:
25 25
26 26 - The collection of items to pick a range from. Usually just a list.
27 27 - The page number you want to display. Default is 1: the first page.
28 28
29 29 Now we can make up a collection and create a Page instance of it::
30 30
31 31 # Create a sample collection of 1000 items
32 32 >> my_collection = range(1000)
33 33
34 34 # Create a Page object for the 3rd page (20 items per page is the default)
35 35 >> my_page = Page(my_collection, page=3)
36 36
37 37 # The page object can be printed as a string to get its details
38 38 >> str(my_page)
39 39 Page:
40 40 Collection type: <type 'range'>
41 41 Current page: 3
42 42 First item: 41
43 43 Last item: 60
44 44 First page: 1
45 45 Last page: 50
46 46 Previous page: 2
47 47 Next page: 4
48 48 Items per page: 20
49 49 Number of items: 1000
50 50 Number of pages: 50
51 51
52 52 # Print a list of items on the current page
53 53 >> my_page.items
54 54 [40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59]
55 55
56 56 # The *Page* object can be used as an iterator:
57 57 >> for my_item in my_page: print(my_item)
58 58 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59
59 59
60 60 # The .pager() method returns an HTML fragment with links to surrounding pages.
61 61 >> my_page.pager(url="http://example.org/foo/page=$page")
62 62
63 63 <a href="http://example.org/foo/page=1">1</a>
64 64 <a href="http://example.org/foo/page=2">2</a>
65 65 3
66 66 <a href="http://example.org/foo/page=4">4</a>
67 67 <a href="http://example.org/foo/page=5">5</a>
68 68 ..
69 69 <a href="http://example.org/foo/page=50">50</a>'
70 70
71 71 # Without the HTML it would just look like:
72 72 # 1 2 [3] 4 5 .. 50
73 73
74 74 # The pager can be customized:
75 75 >> my_page.pager('$link_previous ~3~ $link_next (Page $page of $page_count)',
76 76 url="http://example.org/foo/page=$page")
77 77
78 78 <a href="http://example.org/foo/page=2">&lt;</a>
79 79 <a href="http://example.org/foo/page=1">1</a>
80 80 <a href="http://example.org/foo/page=2">2</a>
81 81 3
82 82 <a href="http://example.org/foo/page=4">4</a>
83 83 <a href="http://example.org/foo/page=5">5</a>
84 84 <a href="http://example.org/foo/page=6">6</a>
85 85 ..
86 86 <a href="http://example.org/foo/page=50">50</a>
87 87 <a href="http://example.org/foo/page=4">&gt;</a>
88 88 (Page 3 of 50)
89 89
90 90 # Without the HTML it would just look like:
91 91 # 1 2 [3] 4 5 6 .. 50 > (Page 3 of 50)
92 92
93 93 # The url argument to the pager method can be omitted when an url_maker is
94 94 # given during instantiation:
95 95 >> my_page = Page(my_collection, page=3,
96 96 url_maker=lambda p: "http://example.org/%s" % p)
97 97 >> page.pager()
98 98
99 99 There are some interesting parameters that customize the Page's behavior. See the documentation on
100 100 ``Page`` and ``Page.pager()``.
101 101
102 102
103 103 Notes
104 104 -------
105 105
106 106 Page numbers and item numbers start at 1. This concept has been used because users expect that the
107 107 first page has number 1 and the first item on a page also has number 1. So if you want to use the
108 108 page's items by their index number please note that you have to subtract 1.
109 109 """
110 110
111 111 import re
112 112 import sys
113 113 from string import Template
114 114 from webhelpers2.html import literal
115 115
116 116 # are we running at least python 3.x ?
117 117 PY3 = sys.version_info[0] >= 3
118 118
119 119 if PY3:
120 120 unicode = str
121 121
122 122
123 123 def make_html_tag(tag, text=None, **params):
124 124 """Create an HTML tag string.
125 125
126 126 tag
127 127 The HTML tag to use (e.g. 'a', 'span' or 'div')
128 128
129 129 text
130 130 The text to enclose between opening and closing tag. If no text is specified then only
131 131 the opening tag is returned.
132 132
133 133 Example::
134 134 make_html_tag('a', text="Hello", href="/another/page")
135 135 -> <a href="/another/page">Hello</a>
136 136
137 137 To use reserved Python keywords like "class" as a parameter prepend it with
138 138 an underscore. Instead of "class='green'" use "_class='green'".
139 139
140 140 Warning: Quotes and apostrophes are not escaped."""
141 141 params_string = ""
142 142
143 143 # Parameters are passed. Turn the dict into a string like "a=1 b=2 c=3" string.
144 144 for key, value in sorted(params.items()):
145 145 # Strip off a leading underscore from the attribute's key to allow attributes like '_class'
146 146 # to be used as a CSS class specification instead of the reserved Python keyword 'class'.
147 147 key = key.lstrip("_")
148 148
149 params_string += u' {0}="{1}"'.format(key, value)
149 params_string += ' {0}="{1}"'.format(key, value)
150 150
151 151 # Create the tag string
152 tag_string = u"<{0}{1}>".format(tag, params_string)
152 tag_string = "<{0}{1}>".format(tag, params_string)
153 153
154 154 # Add text and closing tag if required.
155 155 if text:
156 tag_string += u"{0}</{1}>".format(text, tag)
156 tag_string += "{0}</{1}>".format(text, tag)
157 157
158 158 return tag_string
159 159
160 160
161 161 # Since the items on a page are mainly a list we subclass the "list" type
162 162 class _Page(list):
163 163 """A list/iterator representing the items on one page of a larger collection.
164 164
165 165 An instance of the "Page" class is created from a _collection_ which is any
166 166 list-like object that allows random access to its elements.
167 167
168 168 The instance works as an iterator running from the first item to the last item on the given
169 169 page. The Page.pager() method creates a link list allowing the user to go to other pages.
170 170
171 171 A "Page" does not only carry the items on a certain page. It gives you additional information
172 172 about the page in these "Page" object attributes:
173 173
174 174 item_count
175 175 Number of items in the collection
176 176
177 177 **WARNING:** Unless you pass in an item_count, a count will be
178 178 performed on the collection every time a Page instance is created.
179 179
180 180 page
181 181 Number of the current page
182 182
183 183 items_per_page
184 184 Maximal number of items displayed on a page
185 185
186 186 first_page
187 187 Number of the first page - usually 1 :)
188 188
189 189 last_page
190 190 Number of the last page
191 191
192 192 previous_page
193 193 Number of the previous page. If this is the first page it returns None.
194 194
195 195 next_page
196 196 Number of the next page. If this is the last page it returns None.
197 197
198 198 page_count
199 199 Number of pages
200 200
201 201 items
202 202 Sequence/iterator of items on the current page
203 203
204 204 first_item
205 205 Index of first item on the current page - starts with 1
206 206
207 207 last_item
208 208 Index of last item on the current page
209 209 """
210 210
211 211 def __init__(
212 212 self,
213 213 collection,
214 214 page=1,
215 215 items_per_page=20,
216 216 item_count=None,
217 217 wrapper_class=None,
218 218 url_maker=None,
219 219 bar_size=10,
220 220 **kwargs
221 221 ):
222 222 """Create a "Page" instance.
223 223
224 224 Parameters:
225 225
226 226 collection
227 227 Sequence representing the collection of items to page through.
228 228
229 229 page
230 230 The requested page number - starts with 1. Default: 1.
231 231
232 232 items_per_page
233 233 The maximal number of items to be displayed per page.
234 234 Default: 20.
235 235
236 236 item_count (optional)
237 237 The total number of items in the collection - if known.
238 238 If this parameter is not given then the paginator will count
239 239 the number of elements in the collection every time a "Page"
240 240 is created. Giving this parameter will speed things up. In a busy
241 241 real-life application you may want to cache the number of items.
242 242
243 243 url_maker (optional)
244 244 Callback to generate the URL of other pages, given its numbers.
245 245 Must accept one int parameter and return a URI string.
246 246
247 247 bar_size
248 248 maximum number of rendered page numbers within the radius
249 249
250 250 """
251 251 if collection is not None:
252 252 if wrapper_class is None:
253 253 # Default case. The collection is already a list-type object.
254 254 self.collection = collection
255 255 else:
256 256 # Special case. A custom wrapper class is used to access elements of the collection.
257 257 self.collection = wrapper_class(collection)
258 258 else:
259 259 self.collection = []
260 260
261 261 self.collection_type = type(collection)
262 262
263 263 if url_maker is not None:
264 264 self.url_maker = url_maker
265 265 else:
266 266 self.url_maker = self._default_url_maker
267 267 self.bar_size = bar_size
268 268 # Assign kwargs to self
269 269 self.kwargs = kwargs
270 270
271 271 # The self.page is the number of the current page.
272 272 # The first page has the number 1!
273 273 try:
274 274 self.page = int(page) # make it int() if we get it as a string
275 275 except (ValueError, TypeError):
276 276 self.page = 1
277 277 # normally page should always be at least 1, but the original maintainer
278 278 # decided that for an empty collection and empty page it can be... 0? (based on tests)
279 279 # preserving behavior for BW compat
280 280 if self.page < 1:
281 281 self.page = 1
282 282
283 283 self.items_per_page = items_per_page
284 284
285 285 # We subclassed "list" so we need to call its init() method
286 286 # and fill the new list with the items to be displayed on the page.
287 287 # We use list() so that the items on the current page are retrieved
288 288 # only once. In an SQL context that could otherwise lead to running the
289 289 # same SQL query every time items would be accessed.
290 290 # We do this here, prior to calling len() on the collection so that a
291 291 # wrapper class can execute a query with the knowledge of what the
292 292 # slice will be (for efficiency) and, in the same query, ask for the
293 293 # total number of items and only execute one query.
294 294
295 295 try:
296 296 first = (self.page - 1) * items_per_page
297 297 last = first + items_per_page
298 298 self.items = list(self.collection[first:last])
299 299 except TypeError as err:
300 300 raise TypeError(
301 301 f"Your collection of type {type(self.collection)} cannot be handled "
302 302 f"by paginate. ERROR:{err}"
303 303 )
304 304
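A quick worked example of the slice arithmetic above: for page=3 and items_per_page=20,

    first = (3 - 1) * 20   # 40
    last = first + 20      # 60 -> items = collection[40:60]

which matches the 'First item: 41 / Last item: 60' figures in the module docstring; item numbers are 1-based while slice indices are 0-based.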
305 305 # Unless the user tells us how many items the collections has
306 306 # we calculate that ourselves.
307 307 if item_count is not None:
308 308 self.item_count = item_count
309 309 else:
310 310 self.item_count = len(self.collection)
311 311
312 312 # Compute the number of the first and last available page
313 313 if self.item_count > 0:
314 314 self.first_page = 1
315 315 self.page_count = ((self.item_count - 1) // self.items_per_page) + 1
316 316 self.last_page = self.first_page + self.page_count - 1
317 317
318 318 # Make sure that the requested page number is in the range of valid pages
319 319 if self.page > self.last_page:
320 320 self.page = self.last_page
321 321 elif self.page < self.first_page:
322 322 self.page = self.first_page
323 323
324 324 # Note: the number of items on this page can be less than
325 325 # items_per_page if the last page is not full
326 326 self.first_item = (self.page - 1) * items_per_page + 1
327 327 self.last_item = min(self.first_item + items_per_page - 1, self.item_count)
328 328
329 329 # Links to previous and next page
330 330 if self.page > self.first_page:
331 331 self.previous_page = self.page - 1
332 332 else:
333 333 self.previous_page = None
334 334
335 335 if self.page < self.last_page:
336 336 self.next_page = self.page + 1
337 337 else:
338 338 self.next_page = None
339 339
340 340 # No items available
341 341 else:
342 342 self.first_page = None
343 343 self.page_count = 0
344 344 self.last_page = None
345 345 self.first_item = None
346 346 self.last_item = None
347 347 self.previous_page = None
348 348 self.next_page = None
349 349 self.items = []
350 350
351 351 # This is a subclass of the 'list' type. Initialise the list now.
352 352 list.__init__(self, self.items)
353 353
354 354 def __str__(self):
355 355 return (
356 356 "Page:\n"
357 357 "Collection type: {0.collection_type}\n"
358 358 "Current page: {0.page}\n"
359 359 "First item: {0.first_item}\n"
360 360 "Last item: {0.last_item}\n"
361 361 "First page: {0.first_page}\n"
362 362 "Last page: {0.last_page}\n"
363 363 "Previous page: {0.previous_page}\n"
364 364 "Next page: {0.next_page}\n"
365 365 "Items per page: {0.items_per_page}\n"
366 366 "Total number of items: {0.item_count}\n"
367 367 "Number of pages: {0.page_count}\n"
368 368 ).format(self)
369 369
370 370 def __repr__(self):
371 371 return "<paginate.Page: Page {0}/{1}>".format(self.page, self.page_count)
372 372
373 373 def pager(
374 374 self,
375 375 tmpl_format="~2~",
376 376 url=None,
377 377 show_if_single_page=False,
378 378 separator=" ",
379 379 symbol_first="&lt;&lt;",
380 380 symbol_last="&gt;&gt;",
381 381 symbol_previous="&lt;",
382 382 symbol_next="&gt;",
383 383 link_attr=None,
384 384 curpage_attr=None,
385 385 dotdot_attr=None,
386 386 link_tag=None,
387 387 ):
388 388 """
389 389 Return string with links to other pages (e.g. '1 .. 5 6 7 [8] 9 10 11 .. 50').
390 390
391 391 tmpl_format:
392 392 Format string that defines how the pager is rendered. The string
393 393 can contain the following $-tokens that are substituted by the
394 394 string.Template module:
395 395
396 396 - $first_page: number of first reachable page
397 397 - $last_page: number of last reachable page
398 398 - $page: number of currently selected page
399 399 - $page_count: number of reachable pages
400 400 - $items_per_page: maximal number of items per page
401 401 - $first_item: index of first item on the current page
402 402 - $last_item: index of last item on the current page
403 403 - $item_count: total number of items
404 404 - $link_first: link to first page (unless this is first page)
405 405 - $link_last: link to last page (unless this is last page)
406 406 - $link_previous: link to previous page (unless this is first page)
407 407 - $link_next: link to next page (unless this is last page)
408 408
409 409 To render a range of pages the token '~3~' can be used. The
410 410 number sets the radius of pages around the current page.
411 411 Example for a range with radius 3:
412 412
413 413 '1 .. 5 6 7 [8] 9 10 11 .. 50'
414 414
415 415 Default: '~2~'
416 416
417 417 url
418 418 The URL that page links will point to. Make sure it contains the string
419 419 $page which will be replaced by the actual page number.
420 420 Must be given unless a url_maker is specified to __init__, in which
421 421 case this parameter is ignored.
422 422
423 423 symbol_first
424 424 String to be displayed as the text for the $link_first link above.
425 425
426 426 Default: '&lt;&lt;' (<<)
427 427
428 428 symbol_last
429 429 String to be displayed as the text for the $link_last link above.
430 430
431 431 Default: '&gt;&gt;' (>>)
432 432
433 433 symbol_previous
434 434 String to be displayed as the text for the $link_previous link above.
435 435
436 436 Default: '&lt;' (<)
437 437
438 438 symbol_next
439 439 String to be displayed as the text for the $link_next link above.
440 440
441 441 Default: '&gt;' (>)
442 442
443 443 separator:
444 444 String that is used to separate page links/numbers in the above range of pages.
445 445
446 446 Default: ' '
447 447
448 448 show_if_single_page:
449 449 if True the navigator will be shown even if there is only one page.
450 450
451 451 Default: False
452 452
453 453 link_attr (optional)
454 454 A dictionary of attributes that get added to A-HREF links pointing to other pages. Can
455 455 be used to define a CSS style or class to customize the look of links.
456 456
457 457 Example: { 'style':'border: 1px solid green' }
458 458 Example: { 'class':'pager_link' }
459 459
460 460 curpage_attr (optional)
461 461 A dictionary of attributes that get added to the current page number in the pager (which
462 462 is obviously not a link). If this dictionary is not empty then the elements will be
463 463 wrapped in a SPAN tag with the given attributes.
464 464
465 465 Example: { 'style':'border: 3px solid blue' }
466 466 Example: { 'class':'pager_curpage' }
467 467
468 468 dotdot_attr (optional)
469 469 A dictionary of attributes that get added to the '..' string in the pager (which is
470 470 obviously not a link). If this dictionary is not empty then the elements will be wrapped
471 471 in a SPAN tag with the given attributes.
472 472
473 473 Example: { 'style':'color: #808080' }
474 474 Example: { 'class':'pager_dotdot' }
475 475
476 476 link_tag (optional)
477 477 A callable that accepts a single argument `page` (page link information)
478 478 and generates string with html that represents the link for specific page.
479 479 Page objects are supplied from `link_map()` so the keys are the same.
480 480
481 481
482 482 """
483 483 link_attr = link_attr or {}
484 484 curpage_attr = curpage_attr or {}
485 485 dotdot_attr = dotdot_attr or {}
486 486 self.curpage_attr = curpage_attr
487 487 self.separator = separator
488 488 self.link_attr = link_attr
489 489 self.dotdot_attr = dotdot_attr
490 490 self.url = url
491 491 self.link_tag = link_tag or self.default_link_tag
492 492
493 493 # Don't show navigator if there is no more than one page
494 494 if self.page_count == 0 or (self.page_count == 1 and not show_if_single_page):
495 495 return ""
496 496
497 497 regex_res = re.search(r"~(\d+)~", tmpl_format)
498 498 if regex_res:
499 499 radius = regex_res.group(1)
500 500 else:
501 501 radius = 2
502 502
503 503 self.radius = int(radius)
504 504 link_map = self.link_map(
505 505 tmpl_format=tmpl_format,
506 506 url=url,
507 507 show_if_single_page=show_if_single_page,
508 508 separator=separator,
509 509 symbol_first=symbol_first,
510 510 symbol_last=symbol_last,
511 511 symbol_previous=symbol_previous,
512 512 symbol_next=symbol_next,
513 513 link_attr=link_attr,
514 514 curpage_attr=curpage_attr,
515 515 dotdot_attr=dotdot_attr,
516 516 link_tag=link_tag,
517 517 )
518 518 links_markup = self._range(link_map, self.radius)
519 519
520 520 # Replace the ~...~ token in tmpl_format with the range of pages
521 521 result = re.sub(r"~(\d+)~", links_markup, tmpl_format)
522 522
523 523 link_first = (
524 524 self.page > self.first_page and self.link_tag(link_map["first_page"]) or ""
525 525 )
526 526 link_last = (
527 527 self.page < self.last_page and self.link_tag(link_map["last_page"]) or ""
528 528 )
529 529 link_previous = (
530 530 self.previous_page and self.link_tag(link_map["previous_page"]) or ""
531 531 )
532 532 link_next = self.next_page and self.link_tag(link_map["next_page"]) or ""
533 533 # Interpolate '$' variables
534 534 result = Template(result).safe_substitute(
535 535 {
536 536 "first_page": self.first_page,
537 537 "last_page": self.last_page,
538 538 "page": self.page,
539 539 "page_count": self.page_count,
540 540 "items_per_page": self.items_per_page,
541 541 "first_item": self.first_item,
542 542 "last_item": self.last_item,
543 543 "item_count": self.item_count,
544 544 "link_first": link_first,
545 545 "link_last": link_last,
546 546 "link_previous": link_previous,
547 547 "link_next": link_next,
548 548 }
549 549 )
550 550
551 551 return result
552 552
553 553 def _get_edges(self, cur_page, max_page, items):
554 554 cur_page = int(cur_page)
555 edge = (items / 2) + 1
555 edge = (items // 2) + 1
556 556 if cur_page <= edge:
557 radius = max(items / 2, items - cur_page)
557 radius = max(items // 2, items - cur_page)
558 558 elif (max_page - cur_page) < edge:
559 559 radius = (items - 1) - (max_page - cur_page)
560 560 else:
561 radius = (items / 2) - 1
561 radius = (items // 2) - 1
562 562
563 563 left = max(1, (cur_page - radius))
564 564 right = min(max_page, cur_page + radius)
565 565 return left, right
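The `/` to `//` changes in _get_edges are needed because Python 3's `/` on two ints yields a float; a float radius would propagate into `leftmost_page`/`rightmost_page` and make the later `range(leftmost_page, rightmost_page + 1)` call in `link_map()` raise TypeError. A minimal standalone sketch:

    items = (2 * 3) + 1    # window size for radius 3
    items / 2              # 3.5 under Python 3 true division
    items // 2             # 3, matching Python 2's int / int behaviour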
566 566
567 567 def link_map(
568 568 self,
569 569 tmpl_format="~2~",
570 570 url=None,
571 571 show_if_single_page=False,
572 572 separator=" ",
573 573 symbol_first="&lt;&lt;",
574 574 symbol_last="&gt;&gt;",
575 575 symbol_previous="&lt;",
576 576 symbol_next="&gt;",
577 577 link_attr=None,
578 578 curpage_attr=None,
579 579 dotdot_attr=None,
580 580 link_tag=None
581 581 ):
582 582 """ Return map with links to other pages if default pager() function is not suitable solution.
583 583 tmpl_format:
584 584 Format string that defines how the pager would be normally rendered rendered. Uses same arguments as pager()
585 585 method, but returns a simple dictionary in form of:
586 586 {'current_page': {'attrs': {},
587 587 'href': 'http://example.org/foo/page=1',
588 588 'value': 1},
589 589 'first_page': {'attrs': {},
590 590 'href': 'http://example.org/foo/page=1',
591 591 'type': 'first_page',
592 592 'value': 1},
593 593 'last_page': {'attrs': {},
594 594 'href': 'http://example.org/foo/page=8',
595 595 'type': 'last_page',
596 596 'value': 8},
597 597 'next_page': {'attrs': {}, 'href': 'HREF', 'type': 'next_page', 'value': 2},
598 598 'previous_page': None,
599 599 'range_pages': [{'attrs': {},
600 600 'href': 'http://example.org/foo/page=1',
601 601 'type': 'current_page',
602 602 'value': 1},
603 603 ....
604 604 {'attrs': {}, 'href': '', 'type': 'span', 'value': '..'}]}
605 605
606 606
607 607 The string can contain the following $-tokens that are substituted by the
608 608 string.Template module:
609 609
610 610 - $first_page: number of first reachable page
611 611 - $last_page: number of last reachable page
612 612 - $page: number of currently selected page
613 613 - $page_count: number of reachable pages
614 614 - $items_per_page: maximal number of items per page
615 615 - $first_item: index of first item on the current page
616 616 - $last_item: index of last item on the current page
617 617 - $item_count: total number of items
618 618 - $link_first: link to first page (unless this is first page)
619 619 - $link_last: link to last page (unless this is last page)
620 620 - $link_previous: link to previous page (unless this is first page)
621 621 - $link_next: link to next page (unless this is last page)
622 622
623 623 To render a range of pages the token '~3~' can be used. The
624 624 number sets the radius of pages around the current page.
625 625 Example for a range with radius 3:
626 626
627 627 '1 .. 5 6 7 [8] 9 10 11 .. 50'
628 628
629 629 Default: '~2~'
630 630
631 631 url
632 632 The URL that page links will point to. Make sure it contains the string
633 633 $page which will be replaced by the actual page number.
634 634 Must be given unless a url_maker is specified to __init__, in which
635 635 case this parameter is ignored.
636 636
637 637 symbol_first
638 638 String to be displayed as the text for the $link_first link above.
639 639
640 640 Default: '&lt;&lt;' (<<)
641 641
642 642 symbol_last
643 643 String to be displayed as the text for the $link_last link above.
644 644
645 645 Default: '&gt;&gt;' (>>)
646 646
647 647 symbol_previous
648 648 String to be displayed as the text for the $link_previous link above.
649 649
650 650 Default: '&lt;' (<)
651 651
652 652 symbol_next
653 653 String to be displayed as the text for the $link_next link above.
654 654
655 655 Default: '&gt;' (>)
656 656
657 657 separator
658 658 String that is used to separate page links/numbers in the above range of pages.
659 659
660 660 Default: ' '
661 661
662 662 show_if_single_page
663 663 If True the navigator will be shown even if there is only one page.
664 664
665 665 Default: False
666 666
667 667 link_attr (optional)
668 668 A dictionary of attributes that get added to A-HREF links pointing to other pages. Can
669 669 be used to define a CSS style or class to customize the look of links.
670 670
671 671 Example: { 'style':'border: 1px solid green' }
672 672 Example: { 'class':'pager_link' }
673 673
674 674 curpage_attr (optional)
675 675 A dictionary of attributes that get added to the current page number in the pager (which
676 676 is obviously not a link). If this dictionary is not empty then the elements will be
677 677 wrapped in a SPAN tag with the given attributes.
678 678
679 679 Example: { 'style':'border: 3px solid blue' }
680 680 Example: { 'class':'pager_curpage' }
681 681
682 682 dotdot_attr (optional)
683 683 A dictionary of attributes that get added to the '..' string in the pager (which is
684 684 obviously not a link). If this dictionary is not empty then the elements will be wrapped
685 685 in a SPAN tag with the given attributes.
686 686
687 687 Example: { 'style':'color: #808080' }
688 688 Example: { 'class':'pager_dotdot' }
689 689 """
690 690 link_attr = link_attr or {}
691 691 curpage_attr = curpage_attr or {}
692 692 dotdot_attr = dotdot_attr or {}
693 693 self.curpage_attr = curpage_attr
694 694 self.separator = separator
695 695 self.link_attr = link_attr
696 696 self.dotdot_attr = dotdot_attr
697 697 self.url = url
698 698
699 699 regex_res = re.search(r"~(\d+)~", tmpl_format)
700 700 if regex_res:
701 701 radius = regex_res.group(1)
702 702 else:
703 703 radius = 2
704 704
705 705 self.radius = int(radius)
706 706
707 707 # Compute the first and last page number within the radius
708 708 # e.g. '1 .. 5 6 [7] 8 9 .. 12'
709 709 # -> leftmost_page = 5
710 710 # -> rightmost_page = 9
711 711 leftmost_page, rightmost_page = self._get_edges(
712 712 self.page, self.last_page, (self.radius * 2) + 1)
713 713
714 714 nav_items = {
715 715 "first_page": None,
716 716 "last_page": None,
717 717 "previous_page": None,
718 718 "next_page": None,
719 719 "current_page": None,
720 720 "radius": self.radius,
721 721 "range_pages": [],
722 722 }
723 723
724 724 if leftmost_page is None or rightmost_page is None:
725 725 return nav_items
726 726
727 727 nav_items["first_page"] = {
728 728 "type": "first_page",
729 "value": unicode(symbol_first),
729 "value": str(symbol_first),
730 730 "attrs": self.link_attr,
731 731 "number": self.first_page,
732 732 "href": self.url_maker(self.first_page),
733 733 }
734 734
735 735 # Insert dots if there are pages between the first page
736 736 # and the currently displayed page range
737 737 if leftmost_page - self.first_page > 1:
738 738 # Wrap in a SPAN tag if dotdot_attr is set
739 739 nav_items["range_pages"].append(
740 740 {
741 741 "type": "span",
742 742 "value": "..",
743 743 "attrs": self.dotdot_attr,
744 744 "href": "",
745 745 "number": None,
746 746 }
747 747 )
748 748
749 749 for this_page in range(leftmost_page, rightmost_page + 1):
750 750 # Highlight the current page number and do not use a link
751 751 if this_page == self.page:
752 752 # Wrap in a SPAN tag if curpage_attr is set
753 753 nav_items["range_pages"].append(
754 754 {
755 755 "type": "current_page",
756 "value": unicode(this_page),
756 "value": str(this_page),
757 757 "number": this_page,
758 758 "attrs": self.curpage_attr,
759 759 "href": self.url_maker(this_page),
760 760 }
761 761 )
762 762 nav_items["current_page"] = {
763 763 "value": this_page,
764 764 "attrs": self.curpage_attr,
765 765 "type": "current_page",
766 766 "href": self.url_maker(this_page),
767 767 }
768 768 # Otherwise create just a link to that page
769 769 else:
770 770 nav_items["range_pages"].append(
771 771 {
772 772 "type": "page",
773 "value": unicode(this_page),
773 "value": str(this_page),
774 774 "number": this_page,
775 775 "attrs": self.link_attr,
776 776 "href": self.url_maker(this_page),
777 777 }
778 778 )
779 779
780 780 # Insert dots if there are pages between the displayed
781 781 # page numbers and the end of the page range
782 782 if self.last_page - rightmost_page > 1:
783 783 # Wrap in a SPAN tag if dotdot_attr is set
784 784 nav_items["range_pages"].append(
785 785 {
786 786 "type": "span",
787 787 "value": "..",
788 788 "attrs": self.dotdot_attr,
789 789 "href": "",
790 790 "number": None,
791 791 }
792 792 )
793 793
794 794 # Create a link to the very last page (unless we are on the last
795 795 # page or there would be no need to insert '..' spacers)
796 796 nav_items["last_page"] = {
797 797 "type": "last_page",
798 "value": unicode(symbol_last),
798 "value": str(symbol_last),
799 799 "attrs": self.link_attr,
800 800 "href": self.url_maker(self.last_page),
801 801 "number": self.last_page,
802 802 }
803 803
804 804 nav_items["previous_page"] = {
805 805 "type": "previous_page",
806 "value": unicode(symbol_previous),
806 "value": str(symbol_previous),
807 807 "attrs": self.link_attr,
808 808 "number": self.previous_page or self.first_page,
809 809 "href": self.url_maker(self.previous_page or self.first_page),
810 810 }
811 811
812 812 nav_items["next_page"] = {
813 813 "type": "next_page",
814 "value": unicode(symbol_next),
814 "value": str(symbol_next),
815 815 "attrs": self.link_attr,
816 816 "number": self.next_page or self.last_page,
817 817 "href": self.url_maker(self.next_page or self.last_page),
818 818 }
819 819
820 820 return nav_items
821 821
822 822 def _range(self, link_map, radius):
823 823 """
824 824 Return range of linked pages to substitute placeholder in pattern
825 825 """
826 826 # Compute the first and last page number within the radius
827 827 # e.g. '1 .. 5 6 [7] 8 9 .. 12'
828 828 # -> leftmost_page = 5
829 829 # -> rightmost_page = 9
830 830 leftmost_page, rightmost_page = self._get_edges(
831 831 self.page, self.last_page, (radius * 2) + 1)
832 832
833 833 nav_items = []
834 834 # Create a link to the first page (unless we are on the first page
835 835 # or there would be no need to insert '..' spacers)
836 836 if self.first_page and self.page != self.first_page and self.first_page < leftmost_page:
837 837 page = link_map["first_page"].copy()
838 page["value"] = unicode(page["number"])
838 page["value"] = str(page["number"])
839 839 nav_items.append(self.link_tag(page))
840 840
841 841 for item in link_map["range_pages"]:
842 842 nav_items.append(self.link_tag(item))
843 843
844 844 # Create a link to the very last page (unless we are on the last
845 845 # page or there would be no need to insert '..' spacers)
846 846 if self.last_page and self.page != self.last_page and rightmost_page < self.last_page:
847 847 page = link_map["last_page"].copy()
848 page["value"] = unicode(page["number"])
848 page["value"] = str(page["number"])
849 849 nav_items.append(self.link_tag(page))
850 850
851 851 return self.separator.join(nav_items)
852 852
853 853 def _default_url_maker(self, page_number):
854 854 if self.url is None:
855 855 raise Exception(
856 856 "You need to specify a 'url' parameter containing a '$page' placeholder."
857 857 )
858 858
859 859 if "$page" not in self.url:
860 860 raise Exception("The 'url' parameter must contain a '$page' placeholder.")
861 861
862 return self.url.replace("$page", unicode(page_number))
862 return self.url.replace("$page", str(page_number))
863 863
864 864 @staticmethod
865 865 def default_link_tag(item):
866 866 """
867 867 Create an A-HREF tag that points to another page.
868 868 """
869 869 text = item["value"]
870 870 target_url = item["href"]
871 871
872 872 if not item["href"] or item["type"] in ("span", "current_page"):
873 873 if item["attrs"]:
874 874 text = make_html_tag("span", **item["attrs"]) + text + "</span>"
875 875 return text
876 876
877 877 return make_html_tag("a", text=text, href=target_url, **item["attrs"])
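Since `link_tag` receives the same item dicts that `link_map()` produces (keys: `value`, `href`, `attrs`, `type`), callers can swap in their own renderer. A hypothetical sketch; the CSS class names are illustrative and not part of this codebase:

    def bootstrap_link_tag(item):
        # non-link items ('span', 'current_page') mirror default_link_tag
        if not item["href"] or item["type"] in ("span", "current_page"):
            return '<span class="page-item active">%s</span>' % item["value"]
        return '<a class="page-link" href="%s">%s</a>' % (item["href"], item["value"])

    # pager.link_tag = bootstrap_link_tag  # set before rendering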
878 878
879 879 # Below is RhodeCode custom code
880 880
881 881 # Copyright (C) 2010-2020 RhodeCode GmbH
882 882 #
883 883 # This program is free software: you can redistribute it and/or modify
884 884 # it under the terms of the GNU Affero General Public License, version 3
885 885 # (only), as published by the Free Software Foundation.
886 886 #
887 887 # This program is distributed in the hope that it will be useful,
888 888 # but WITHOUT ANY WARRANTY; without even the implied warranty of
889 889 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
890 890 # GNU General Public License for more details.
891 891 #
892 892 # You should have received a copy of the GNU Affero General Public License
893 893 # along with this program. If not, see <http://www.gnu.org/licenses/>.
894 894 #
895 895 # This program is dual-licensed. If you wish to learn more about the
896 896 # RhodeCode Enterprise Edition, including its added features, Support services,
897 897 # and proprietary license terms, please see https://rhodecode.com/licenses/
898 898
899 899
900 900 PAGE_FORMAT = '$link_previous ~3~ $link_next'
901 901
902 902
903 903 class SqlalchemyOrmWrapper(object):
904 904 """Wrapper class to access elements of a collection."""
905 905
906 906 def __init__(self, pager, collection):
907 907 self.pager = pager
908 908 self.collection = collection
909 909
910 910 def __getitem__(self, range):
911 911 # Return a range of objects of an sqlalchemy.orm.query.Query object
912 912 return self.collection[range]
913 913
914 914 def __len__(self):
915 915 # support empty types, without actually making a query.
916 916 if self.collection is None or self.collection == []:
917 917 return 0
918 918
919 919 # Count the number of objects in an sqlalchemy.orm.query.Query object
920 920 return self.collection.count()
921 921
922 922
923 923 class CustomPager(_Page):
924 924
925 925 @staticmethod
926 926 def disabled_link_tag(item):
927 927 """
928 928 Create an A-HREF tag that is disabled
929 929 """
930 930 text = item['value']
931 931 attrs = item['attrs'].copy()
932 932 attrs['class'] = 'disabled ' + attrs['class']
933 933
934 934 return make_html_tag('a', text=text, **attrs)
935 935
936 936 def render(self):
937 937 # Don't show the navigator if there are no pages (page_count == 0)
938 938 if self.page_count == 0:
939 939 return ""
940 940
941 941 self.link_tag = self.default_link_tag
942 942
943 943 link_map = self.link_map(
944 944 tmpl_format=PAGE_FORMAT, url=None,
945 945 show_if_single_page=False, separator=' ',
946 946 symbol_first='<<', symbol_last='>>',
947 947 symbol_previous='<', symbol_next='>',
948 948 link_attr={'class': 'pager_link'},
949 949 curpage_attr={'class': 'pager_curpage'},
950 950 dotdot_attr={'class': 'pager_dotdot'})
951 951
952 952 links_markup = self._range(link_map, self.radius)
953 953
954 954 link_first = (
955 955 self.page > self.first_page and self.link_tag(link_map['first_page']) or ''
956 956 )
957 957 link_last = (
958 958 self.page < self.last_page and self.link_tag(link_map['last_page']) or ''
959 959 )
960 960
961 961 link_previous = (
962 962 self.previous_page and self.link_tag(link_map['previous_page'])
963 963 or self.disabled_link_tag(link_map['previous_page'])
964 964 )
965 965 link_next = (
966 966 self.next_page and self.link_tag(link_map['next_page'])
967 967 or self.disabled_link_tag(link_map['next_page'])
968 968 )
969 969
970 970 # Replace the ~...~ token in PAGE_FORMAT with the range of pages,
971 971 # then interpolate the '$' variables
972 972 result = re.sub(r"~(\d+)~", links_markup, PAGE_FORMAT)
973 973 result = Template(result).safe_substitute(
974 974 {
975 975 "links": links_markup,
976 976 "first_page": self.first_page,
977 977 "last_page": self.last_page,
978 978 "page": self.page,
979 979 "page_count": self.page_count,
980 980 "items_per_page": self.items_per_page,
981 981 "first_item": self.first_item,
982 982 "last_item": self.last_item,
983 983 "item_count": self.item_count,
984 984 "link_first": link_first,
985 985 "link_last": link_last,
986 986 "link_previous": link_previous,
987 987 "link_next": link_next,
988 988 }
989 989 )
990 990
991 991 return literal(result)
992 992
993 993
994 994 class Page(CustomPager):
995 995 """
996 996 Custom pager to match rendering style with paginator
997 997 """
998 998
999 999 def __init__(self, collection, page=1, items_per_page=20, item_count=None,
1000 1000 url_maker=None, **kwargs):
1001 1001 """
1002 1002 Special type of pager. The collection is used directly here; the
1003 1003 SqlPage/RepoPage subclasses wrap it in custom logic instead of using wrapper_class
1004 1004 """
1005 1005
1006 1006 super(Page, self).__init__(collection=collection, page=page,
1007 1007 items_per_page=items_per_page, item_count=item_count,
1008 1008 wrapper_class=None, url_maker=url_maker, **kwargs)
1009 1009
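A usage sketch for Page (the collection and URL pattern are assumptions): any sliceable collection works, and `url_maker` receives the page number.

    commits = list(range(95))
    p = Page(commits, page=2, items_per_page=20,
             url_maker=lambda n: '/changelog?page=%d' % n)
    html = p.render()   # PAGE_FORMAT markup: prev link, ~3~ page range, next link
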
1010 1010
1011 1011 class SqlPage(CustomPager):
1012 1012 """
1013 1013 Custom pager to match rendering style with paginator
1014 1014 """
1015 1015
1016 1016 def __init__(self, collection, page=1, items_per_page=20, item_count=None,
1017 1017 url_maker=None, **kwargs):
1018 1018 """
1019 1019 Special type of pager. We intercept collection to wrap it in our custom
1020 1020 logic instead of using wrapper_class
1021 1021 """
1022 1022 collection = SqlalchemyOrmWrapper(self, collection)
1023 1023
1024 1024 super(SqlPage, self).__init__(collection=collection, page=page,
1025 1025 items_per_page=items_per_page, item_count=item_count,
1026 1026 wrapper_class=None, url_maker=url_maker, **kwargs)
1027 1027
1028 1028
1029 1029 class RepoCommitsWrapper(object):
1030 1030 """Wrapper class to access elements of a collection."""
1031 1031
1032 1032 def __init__(self, pager, collection):
1033 1033 self.pager = pager
1034 1034 self.collection = collection
1035 1035
1036 1036 def __getitem__(self, range):
1037 1037 cur_page = self.pager.page
1038 1038 items_per_page = self.pager.items_per_page
1039 1039 first_item = max(0, (len(self.collection) - (cur_page * items_per_page)))
1040 1040 last_item = ((len(self.collection) - 1) - items_per_page * (cur_page - 1))
1041 1041 return reversed(list(self.collection[first_item:last_item + 1]))
1042 1042
1043 1043 def __len__(self):
1044 1044 return len(self.collection)
1045 1045
1046 1046
1047 1047 class RepoPage(CustomPager):
1048 1048 """
1049 1049 Create a "RepoPage" instance: a special pager for paging repository commits
1050 1050 """
1051 1051
1052 1052 def __init__(self, collection, page=1, items_per_page=20, item_count=None,
1053 1053 url_maker=None, **kwargs):
1054 1054 """
1055 1055 Special type of pager. We intercept collection to wrap it in our custom
1056 1056 logic instead of using wrapper_class
1057 1057 """
1058 1058 collection = RepoCommitsWrapper(self, collection)
1059 1059 super(RepoPage, self).__init__(collection=collection, page=page,
1060 1060 items_per_page=items_per_page, item_count=item_count,
1061 1061 wrapper_class=None, url_maker=url_maker, **kwargs)
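The index arithmetic in `RepoCommitsWrapper.__getitem__` pages an oldest-to-newest commit list from the end, newest first. A standalone worked example mirroring the expressions above:

    collection = list(range(100))   # index 0 = oldest commit, 99 = newest
    items_per_page = 20
    for cur_page in (1, 2):
        first_item = max(0, len(collection) - (cur_page * items_per_page))
        last_item = (len(collection) - 1) - items_per_page * (cur_page - 1)
        page = list(reversed(collection[first_item:last_item + 1]))
        # cur_page 1 -> 99..80, cur_page 2 -> 79..60: newest commits first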
@@ -1,264 +1,264 b''
1 1 # -*- coding: utf-8 -*-
2 2
3 3 # Copyright (C) 2017-2020 RhodeCode GmbH
4 4 #
5 5 # This program is free software: you can redistribute it and/or modify
6 6 # it under the terms of the GNU Affero General Public License, version 3
7 7 # (only), as published by the Free Software Foundation.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU Affero General Public License
15 15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 16 #
17 17 # This program is dual-licensed. If you wish to learn more about the
18 18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20 20
21 21 import os
22 22 import re
23 23 import time
24 24 import datetime
25 25 import dateutil
26 26 import pickle
27 27
28 28 from rhodecode.model.db import DbSession, Session
29 29
30 30
31 31 class CleanupCommand(Exception):
32 32 pass
33 33
34 34
35 35 class BaseAuthSessions(object):
36 36 SESSION_TYPE = None
37 37 NOT_AVAILABLE = 'NOT AVAILABLE'
38 38
39 39 def __init__(self, config):
40 40 session_conf = {}
41 41 for k, v in config.items():
42 42 if k.startswith('beaker.session'):
43 43 session_conf[k] = v
44 44 self.config = session_conf
45 45
46 46 def get_count(self):
47 47 raise NotImplementedError
48 48
49 49 def get_expired_count(self, older_than_seconds=None):
50 50 raise NotImplementedError
51 51
52 52 def clean_sessions(self, older_than_seconds=None):
53 53 raise NotImplementedError
54 54
55 55 def _seconds_to_date(self, seconds):
56 56 return datetime.datetime.utcnow() - dateutil.relativedelta.relativedelta(
57 57 seconds=seconds)
58 58
59 59
60 60 class DbAuthSessions(BaseAuthSessions):
61 61 SESSION_TYPE = 'ext:database'
62 62
63 63 def get_count(self):
64 64 return DbSession.query().count()
65 65
66 66 def get_expired_count(self, older_than_seconds=None):
67 67 expiry_date = self._seconds_to_date(older_than_seconds)
68 68 return DbSession.query().filter(DbSession.accessed < expiry_date).count()
69 69
70 70 def clean_sessions(self, older_than_seconds=None):
71 71 expiry_date = self._seconds_to_date(older_than_seconds)
72 72 to_remove = DbSession.query().filter(DbSession.accessed < expiry_date).count()
73 73 DbSession.query().filter(DbSession.accessed < expiry_date).delete()
74 74 Session().commit()
75 75 return to_remove
76 76
77 77
78 78 class FileAuthSessions(BaseAuthSessions):
79 79 SESSION_TYPE = 'file sessions'
80 80
81 81 def _get_sessions_dir(self):
82 82 data_dir = self.config.get('beaker.session.data_dir')
83 83 return data_dir
84 84
85 85 def _count_on_filesystem(self, path, older_than=0, callback=None):
86 86 value = dict(percent=0, used=0, total=0, items=0, callbacks=0,
87 87 path=path, text='')
88 88 items_count = 0
89 89 used = 0
90 90 callbacks = 0
91 91 cur_time = time.time()
92 92 for root, dirs, files in os.walk(path):
93 93 for f in files:
94 94 final_path = os.path.join(root, f)
95 95 try:
96 96 mtime = os.stat(final_path).st_mtime
97 97 if (cur_time - mtime) > older_than:
98 98 items_count += 1
99 99 if callback:
100 100 callback_res = callback(final_path)
101 101 callbacks += 1
102 102 else:
103 103 used += os.path.getsize(final_path)
104 104 except OSError:
105 105 pass
106 106 value.update({
107 107 'percent': 100,
108 108 'used': used,
109 109 'total': used,
110 110 'items': items_count,
111 111 'callbacks': callbacks
112 112 })
113 113 return value
114 114
115 115 def get_count(self):
116 116 try:
117 117 sessions_dir = self._get_sessions_dir()
118 118 items_count = self._count_on_filesystem(sessions_dir)['items']
119 119 except Exception:
120 120 items_count = self.NOT_AVAILABLE
121 121 return items_count
122 122
123 123 def get_expired_count(self, older_than_seconds=0):
124 124 try:
125 125 sessions_dir = self._get_sessions_dir()
126 126 items_count = self._count_on_filesystem(
127 127 sessions_dir, older_than=older_than_seconds)['items']
128 128 except Exception:
129 129 items_count = self.NOT_AVAILABLE
130 130 return items_count
131 131
132 132 def clean_sessions(self, older_than_seconds=0):
133 133 # find . -mtime +60 -exec rm {} \;
134 134
135 135 sessions_dir = self._get_sessions_dir()
136 136
137 137 def remove_item(path):
138 138 os.remove(path)
139 139
140 140 stats = self._count_on_filesystem(
141 141 sessions_dir, older_than=older_than_seconds,
142 142 callback=remove_item)
143 143 return stats['callbacks']
144 144
145 145
146 146 class MemcachedAuthSessions(BaseAuthSessions):
147 147 SESSION_TYPE = 'ext:memcached'
148 148 _key_regex = re.compile(r'ITEM (.*_session) \[(.*); (.*)\]')
149 149
150 150 def _get_client(self):
151 151 import memcache
152 152 client = memcache.Client([self.config.get('beaker.session.url')])
153 153 return client
154 154
155 155 def _get_telnet_client(self, host, port):
156 156 import telnetlib
157 157 client = telnetlib.Telnet(host, port, None)
158 158 return client
159 159
160 160 def _run_telnet_cmd(self, client, cmd):
161 client.write("%s\n" % cmd)
161 client.write(("%s\n" % cmd).encode("utf-8"))
162 return client.read_until('END')
162 return client.read_until(b"END").decode("utf-8")
163 163
164 164 def key_details(self, client, slab_ids, limit=100):
165 165 """ Return a list of tuples containing keys and details """
166 166 cmd = 'stats cachedump %s %s'
167 167 for slab_id in slab_ids:
168 168 for key in self._key_regex.finditer(
169 169 self._run_telnet_cmd(client, cmd % (slab_id, limit))):
170 170 yield key
171 171
172 172 def get_count(self):
173 173 client = self._get_client()
174 174 count = self.NOT_AVAILABLE
175 175 try:
176 176 slabs = []
177 177 for server, slabs_data in client.get_slabs():
178 slabs.extend(slabs_data.keys())
178 slabs.extend(list(slabs_data.keys()))
179 179
180 180 host, port = client.servers[0].address
181 181 telnet_client = self._get_telnet_client(host, port)
182 182 keys = self.key_details(telnet_client, slabs)
183 183 count = 0
184 184 for _k in keys:
185 185 count += 1
186 186 except Exception:
187 187 return count
188 188
189 189 return count
190 190
191 191 def get_expired_count(self, older_than_seconds=None):
192 192 return self.NOT_AVAILABLE
193 193
194 194 def clean_sessions(self, older_than_seconds=None):
195 195 raise CleanupCommand('Cleanup for this session type not yet available')
196 196
197 197
198 198 class RedisAuthSessions(BaseAuthSessions):
199 199 SESSION_TYPE = 'ext:redis'
200 200
201 201 def _get_client(self):
202 202 import redis
203 203 args = {
204 204 'socket_timeout': 60,
205 205 'url': self.config.get('beaker.session.url')
206 206 }
207 207
208 208 client = redis.StrictRedis.from_url(**args)
209 209 return client
210 210
211 211 def get_count(self):
212 212 client = self._get_client()
213 213 return len(client.keys('beaker_cache:*'))
214 214
215 215 def get_expired_count(self, older_than_seconds=None):
216 216 expiry_date = self._seconds_to_date(older_than_seconds)
217 217 return self.NOT_AVAILABLE
218 218
219 219 def clean_sessions(self, older_than_seconds=None):
220 220 client = self._get_client()
221 221 expiry_time = time.time() - older_than_seconds
222 222 deleted_keys = 0
223 223 for key in client.keys('beaker_cache:*'):
224 224 data = client.get(key)
225 225 if data:
226 226 json_data = pickle.loads(data)
227 227 try:
228 228 accessed_time = json_data['_accessed_time']
229 229 except KeyError:
230 230 accessed_time = 0
231 231 if accessed_time < expiry_time:
232 232 client.delete(key)
233 233 deleted_keys += 1
234 234
235 235 return deleted_keys
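Note that `client.keys('beaker_cache:*')` is O(N) and blocks Redis while it walks the whole keyspace. Assuming the installed redis-py supports it (it has for many releases), a cursor-based sketch could be substituted in both `get_count` and `clean_sessions`:

    def _iter_session_keys(self, client):
        # SCAN-based iteration; does not block the server the way KEYS does
        for key in client.scan_iter(match='beaker_cache:*'):
            yield key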
236 236
237 237
238 238 class MemoryAuthSessions(BaseAuthSessions):
239 239 SESSION_TYPE = 'memory'
240 240
241 241 def get_count(self):
242 242 return self.NOT_AVAILABLE
243 243
244 244 def get_expired_count(self, older_than_seconds=None):
245 245 return self.NOT_AVAILABLE
246 246
247 247 def clean_sessions(self, older_than_seconds=None):
248 248 raise CleanupCommand('Cleanup for this session type not yet available')
249 249
250 250
251 251 def get_session_handler(session_type):
252 252 types = {
253 253 'file': FileAuthSessions,
254 254 'ext:memcached': MemcachedAuthSessions,
255 255 'ext:redis': RedisAuthSessions,
256 256 'ext:database': DbAuthSessions,
257 257 'memory': MemoryAuthSessions
258 258 }
259 259
260 260 try:
261 261 return types[session_type]
262 262 except KeyError:
263 263 raise ValueError(
264 264 'This type {} is not supported'.format(session_type))
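A minimal usage sketch; the config keys follow the beaker conventions assumed by BaseAuthSessions, and the values are illustrative:

    config = {
        'beaker.session.type': 'ext:redis',
        'beaker.session.url': 'redis://localhost:6379/0',
    }
    handler = get_session_handler(config['beaker.session.type'])(config)
    print(handler.SESSION_TYPE, handler.get_count())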
@@ -1,799 +1,799 b''
1 1 # -*- coding: utf-8 -*-
2 2
3 3 # Copyright (C) 2010-2020 RhodeCode GmbH
4 4 #
5 5 # This program is free software: you can redistribute it and/or modify
6 6 # it under the terms of the GNU Affero General Public License, version 3
7 7 # (only), as published by the Free Software Foundation.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU Affero General Public License
15 15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 16 #
17 17 # This program is dual-licensed. If you wish to learn more about the
18 18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20 20
21 21 """
22 22 Utilities library for RhodeCode
23 23 """
24 24
25 25 import datetime
26 26 import decorator
27 27 import json
28 28 import logging
29 29 import os
30 30 import re
31 31 import sys
32 32 import shutil
33 33 import socket
34 34 import tempfile
35 35 import traceback
36 36 import tarfile
37 37 import warnings
38 38 import hashlib
39 39 from os.path import join as jn
40 40
41 41 import paste
42 42 import pkg_resources
43 43 from webhelpers2.text import collapse, remove_formatting
44 44 from mako import exceptions
45 45 from pyramid.threadlocal import get_current_registry
46 46
47 47 from rhodecode.lib.vcs.backends.base import Config
48 48 from rhodecode.lib.vcs.exceptions import VCSError
49 49 from rhodecode.lib.vcs.utils.helpers import get_scm, get_scm_backend
50 50 from rhodecode.lib.utils2 import (
51 51 safe_str, safe_unicode, get_current_rhodecode_user, md5, sha1)
52 52 from rhodecode.model import meta
53 53 from rhodecode.model.db import (
54 54 Repository, User, RhodeCodeUi, UserLog, RepoGroup, UserGroup)
55 55 from rhodecode.model.meta import Session
56 56
57 57
58 58 log = logging.getLogger(__name__)
59 59
60 60 REMOVED_REPO_PAT = re.compile(r'rm__\d{8}_\d{6}_\d{6}__.*')
61 61
62 62 # String which contains characters that are not allowed in slug names for
63 63 # repositories or repository groups. It is properly escaped to use it in
64 64 # regular expressions.
65 65 SLUG_BAD_CHARS = re.escape('`?=[]\;\'"<>,/~!@#$%^&*()+{}|:')
66 66
67 67 # Regex that matches forbidden characters in repo/group slugs.
68 68 SLUG_BAD_CHAR_RE = re.compile('[{}\x00-\x08\x0b-\x0c\x0e-\x1f]'.format(SLUG_BAD_CHARS))
69 69
70 70 # Regex that matches allowed characters in repo/group slugs.
71 71 SLUG_GOOD_CHAR_RE = re.compile('[^{}]'.format(SLUG_BAD_CHARS))
72 72
73 73 # Regex that matches whole repo/group slugs.
74 74 SLUG_RE = re.compile('[^{}]+'.format(SLUG_BAD_CHARS))
75 75
76 76 _license_cache = None
77 77
78 78
79 79 def repo_name_slug(value):
80 80 """
81 81 Return slug of name of repository
82 82 This function is called on each creation/modification
83 83 of repository to prevent bad names in repo
84 84 """
85 85 replacement_char = '-'
86 86
87 87 slug = remove_formatting(value)
88 88 slug = SLUG_BAD_CHAR_RE.sub('', slug)
89 slug = re.sub('[\s]+', '-', slug)
89 slug = re.sub(r'[\s]+', '-', slug)
90 90 slug = collapse(slug, replacement_char)
91 91 return slug
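Illustrative inputs and outputs, assuming `remove_formatting()` passes plain ASCII through unchanged:

    repo_name_slug('my awesome repo')       # -> 'my-awesome-repo'
    repo_name_slug('repo/with:bad*chars')   # -> 'repowithbadchars' (bad chars stripped)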
92 92
93 93
94 94 #==============================================================================
95 95 # PERM DECORATOR HELPERS FOR EXTRACTING NAMES FOR PERM CHECKS
96 96 #==============================================================================
97 97 def get_repo_slug(request):
98 98 _repo = ''
99 99
100 100 if hasattr(request, 'db_repo'):
101 101 # if our requests has set db reference use it for name, this
102 102 # translates the example.com/_<id> into proper repo names
103 103 _repo = request.db_repo.repo_name
104 104 elif getattr(request, 'matchdict', None):
105 105 # pyramid
106 106 _repo = request.matchdict.get('repo_name')
107 107
108 108 if _repo:
109 109 _repo = _repo.rstrip('/')
110 110 return _repo
111 111
112 112
113 113 def get_repo_group_slug(request):
114 114 _group = ''
115 115 if hasattr(request, 'db_repo_group'):
116 116 # if our requests has set db reference use it for name, this
117 117 # translates the example.com/_<id> into proper repo group names
118 118 _group = request.db_repo_group.group_name
119 119 elif getattr(request, 'matchdict', None):
120 120 # pyramid
121 121 _group = request.matchdict.get('repo_group_name')
122 122
123 123 if _group:
124 124 _group = _group.rstrip('/')
125 125 return _group
126 126
127 127
128 128 def get_user_group_slug(request):
129 129 _user_group = ''
130 130
131 131 if hasattr(request, 'db_user_group'):
132 132 _user_group = request.db_user_group.users_group_name
133 133 elif getattr(request, 'matchdict', None):
134 134 # pyramid
135 135 _user_group = request.matchdict.get('user_group_id')
136 136 _user_group_name = request.matchdict.get('user_group_name')
137 137 try:
138 138 if _user_group:
139 139 _user_group = UserGroup.get(_user_group)
140 140 elif _user_group_name:
141 141 _user_group = UserGroup.get_by_group_name(_user_group_name)
142 142
143 143 if _user_group:
144 144 _user_group = _user_group.users_group_name
145 145 except Exception:
146 146 log.exception('Failed to get user group by id and name')
147 147 # catch all failures here
148 148 return None
149 149
150 150 return _user_group
151 151
152 152
153 153 def get_filesystem_repos(path, recursive=False, skip_removed_repos=True):
154 154 """
155 155 Scans given path for repos and return (name,(type,path)) tuple
156 156
157 157 :param path: path to scan for repositories
158 158 :param recursive: recursive search and return names with subdirs in front
159 159 """
160 160
161 161 # remove ending slash for better results
162 162 path = path.rstrip(os.sep)
163 163 log.debug('now scanning in %s location recursive:%s...', path, recursive)
164 164
165 165 def _get_repos(p):
166 166 dirpaths = _get_dirpaths(p)
167 167 if not _is_dir_writable(p):
168 168 log.warning('repo path without write access: %s', p)
169 169
170 170 for dirpath in dirpaths:
171 171 if os.path.isfile(os.path.join(p, dirpath)):
172 172 continue
173 173 cur_path = os.path.join(p, dirpath)
174 174
175 175 # skip removed repos
176 176 if skip_removed_repos and REMOVED_REPO_PAT.match(dirpath):
177 177 continue
178 178
179 179 # skip .<something> dirs
180 180 if dirpath.startswith('.'):
181 181 continue
182 182
183 183 try:
184 184 scm_info = get_scm(cur_path)
185 185 yield scm_info[1].split(path, 1)[-1].lstrip(os.sep), scm_info
186 186 except VCSError:
187 187 if not recursive:
188 188 continue
189 189 # check if this dir contains other repos for recursive scan
190 190 rec_path = os.path.join(p, dirpath)
191 191 if os.path.isdir(rec_path):
192 192 for inner_scm in _get_repos(rec_path):
193 193 yield inner_scm
194 194
195 195 return _get_repos(path)
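A usage sketch of the generator (the path is hypothetical); each item matches the `(name, (type, path))` shape from the docstring:

    for name, (repo_type, repo_path) in get_filesystem_repos('/srv/repos', recursive=True):
        print('%s: %s at %s' % (name, repo_type, repo_path))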
196 196
197 197
198 198 def _get_dirpaths(p):
199 199 try:
200 200 # OS-independent way of checking if we have at least read-only
201 201 # access or not.
202 202 dirpaths = os.listdir(p)
203 203 except OSError:
204 204 log.warning('ignoring repo path without read access: %s', p)
205 205 return []
206 206
207 207 # os.listdir has a tweak: If a unicode is passed into it, then it tries to
208 208 # decode paths and suddenly returns unicode objects itself. The items it
209 209 # cannot decode are returned as strings and cause issues.
210 210 #
211 211 # Those paths are ignored here until a solid solution for path handling has
212 212 # been built.
213 213 expected_type = type(p)
214 214
215 215 def _has_correct_type(item):
216 216 if type(item) is not expected_type:
217 217 log.error(
218 u"Ignoring path %s since it cannot be decoded into unicode.",
218 "Ignoring path %s since it cannot be decoded into unicode.",
219 219 # Using "repr" to make sure that we see the byte value in case
220 220 # of support.
221 221 repr(item))
222 222 return False
223 223 return True
224 224
225 225 dirpaths = [item for item in dirpaths if _has_correct_type(item)]
226 226
227 227 return dirpaths
228 228
229 229
230 230 def _is_dir_writable(path):
231 231 """
232 232 Probe if `path` is writable.
233 233
234 234 Due to trouble on Cygwin / Windows, this is actually probing if it is
235 235 possible to create a file inside of `path`, stat does not produce reliable
236 236 results in this case.
237 237 """
238 238 try:
239 239 with tempfile.TemporaryFile(dir=path):
240 240 pass
241 241 except OSError:
242 242 return False
243 243 return True
244 244
245 245
246 246 def is_valid_repo(repo_name, base_path, expect_scm=None, explicit_scm=None, config=None):
247 247 """
248 248 Returns True if the given path is a valid repository, False otherwise.
249 249 If the expect_scm param is given, also compare whether the detected scm
250 250 matches the expected one. If explicit_scm is given, don't try to
251 251 detect the scm, just use the given one to check if the repo is valid
252 252
253 253 :param repo_name:
254 254 :param base_path:
255 255 :param expect_scm:
256 256 :param explicit_scm:
257 257 :param config:
258 258
259 259 :return True: if given path is a valid repository
260 260 """
261 261 full_path = os.path.join(safe_str(base_path), safe_str(repo_name))
262 262 log.debug('Checking if `%s` is a valid path for repository. '
263 263 'Explicit type: %s', repo_name, explicit_scm)
264 264
265 265 try:
266 266 if explicit_scm:
267 267 detected_scms = [get_scm_backend(explicit_scm)(
268 268 full_path, config=config).alias]
269 269 else:
270 270 detected_scms = get_scm(full_path)
271 271
272 272 if expect_scm:
273 273 return detected_scms[0] == expect_scm
274 274 log.debug('path: %s is a vcs object:%s', full_path, detected_scms)
275 275 return True
276 276 except VCSError:
277 277 log.debug('path: %s is not a valid repo !', full_path)
278 278 return False
279 279
280 280
281 281 def is_valid_repo_group(repo_group_name, base_path, skip_path_check=False):
282 282 """
283 283 Returns True if given path is a repository group, False otherwise
284 284
285 285 :param repo_name:
286 286 :param base_path:
287 287 """
288 288 full_path = os.path.join(safe_str(base_path), safe_str(repo_group_name))
289 289 log.debug('Checking if `%s` is a valid path for repository group',
290 290 repo_group_name)
291 291
292 292 # check if it's not a repo
293 293 if is_valid_repo(repo_group_name, base_path):
294 294 log.debug('Repo called %s exists, it is not a valid repo group', repo_group_name)
295 295 return False
296 296
297 297 try:
298 298 # we need to check bare git repos at higher level
299 299 # since we might match branches/hooks/info/objects or possible
300 300 # other things inside bare git repo
301 301 maybe_repo = os.path.dirname(full_path)
302 302 if maybe_repo == base_path:
303 303 # skip root level repo check, we know root location CANNOT BE a repo group
304 304 return False
305 305
306 306 scm_ = get_scm(maybe_repo)
307 307 log.debug('path: %s is a vcs object:%s, not valid repo group', full_path, scm_)
308 308 return False
309 309 except VCSError:
310 310 pass
311 311
312 312 # check if it's a valid path
313 313 if skip_path_check or os.path.isdir(full_path):
314 314 log.debug('path: %s is a valid repo group !', full_path)
315 315 return True
316 316
317 317 log.debug('path: %s is not a valid repo group !', full_path)
318 318 return False
319 319
320 320
321 321 def ask_ok(prompt, retries=4, complaint='[y]es or [n]o please!'):
322 322 while True:
323 ok = raw_input(prompt)
323 ok = input(prompt)
324 324 if ok.lower() in ('y', 'ye', 'yes'):
325 325 return True
326 326 if ok.lower() in ('n', 'no', 'nop', 'nope'):
327 327 return False
328 328 retries = retries - 1
329 329 if retries < 0:
330 330 raise IOError
331 331 print(complaint)
332 332
333 333 # propagated from mercurial documentation
334 334 ui_sections = [
335 335 'alias', 'auth',
336 336 'decode/encode', 'defaults',
337 337 'diff', 'email',
338 338 'extensions', 'format',
339 339 'merge-patterns', 'merge-tools',
340 340 'hooks', 'http_proxy',
341 341 'smtp', 'patch',
342 342 'paths', 'profiling',
343 343 'server', 'trusted',
344 344 'ui', 'web', ]
345 345
346 346
347 347 def config_data_from_db(clear_session=True, repo=None):
348 348 """
349 349 Read the configuration data from the database and return configuration
350 350 tuples.
351 351 """
352 352 from rhodecode.model.settings import VcsSettingsModel
353 353
354 354 config = []
355 355
356 356 sa = meta.Session()
357 357 settings_model = VcsSettingsModel(repo=repo, sa=sa)
358 358
359 359 ui_settings = settings_model.get_ui_settings()
360 360
361 361 ui_data = []
362 362 for setting in ui_settings:
363 363 if setting.active:
364 364 ui_data.append((setting.section, setting.key, setting.value))
365 365 config.append((
366 366 safe_str(setting.section), safe_str(setting.key),
367 367 safe_str(setting.value)))
368 368 if setting.key == 'push_ssl':
369 369 # force set push_ssl requirement to False, rhodecode
370 370 # handles that
371 371 config.append((
372 372 safe_str(setting.section), safe_str(setting.key), False))
373 373 log.debug(
374 374 'settings ui from db@repo[%s]: %s',
375 375 repo,
376 376 ','.join(map(lambda s: '[{}] {}={}'.format(*s), ui_data)))
377 377 if clear_session:
378 378 meta.Session.remove()
379 379
380 380 # TODO: mikhail: probably it makes no sense to re-read hooks information.
381 381 # It's already there and activated/deactivated
382 382 skip_entries = []
383 383 enabled_hook_classes = get_enabled_hook_classes(ui_settings)
384 384 if 'pull' not in enabled_hook_classes:
385 385 skip_entries.append(('hooks', RhodeCodeUi.HOOK_PRE_PULL))
386 386 if 'push' not in enabled_hook_classes:
387 387 skip_entries.append(('hooks', RhodeCodeUi.HOOK_PRE_PUSH))
388 388 skip_entries.append(('hooks', RhodeCodeUi.HOOK_PRETX_PUSH))
389 389 skip_entries.append(('hooks', RhodeCodeUi.HOOK_PUSH_KEY))
390 390
391 391 config = [entry for entry in config if entry[:2] not in skip_entries]
392 392
393 393 return config
394 394
395 395
396 396 def make_db_config(clear_session=True, repo=None):
397 397 """
398 398 Create a :class:`Config` instance based on the values in the database.
399 399 """
400 400 config = Config()
401 401 config_data = config_data_from_db(clear_session=clear_session, repo=repo)
402 402 for section, option, value in config_data:
403 403 config.set(section, option, value)
404 404 return config
405 405
406 406
407 407 def get_enabled_hook_classes(ui_settings):
408 408 """
409 409 Return the enabled hook classes.
410 410
411 411 :param ui_settings: List of ui_settings as returned
412 412 by :meth:`VcsSettingsModel.get_ui_settings`
413 413
414 414 :return: a list with the enabled hook classes. The order is not guaranteed.
415 415 :rtype: list
416 416 """
417 417 enabled_hooks = []
418 418 active_hook_keys = [
419 419 key for section, key, value, active in ui_settings
420 420 if section == 'hooks' and active]
421 421
422 422 hook_names = {
423 423 RhodeCodeUi.HOOK_PUSH: 'push',
424 424 RhodeCodeUi.HOOK_PULL: 'pull',
425 425 RhodeCodeUi.HOOK_REPO_SIZE: 'repo_size'
426 426 }
427 427
428 428 for key in active_hook_keys:
429 429 hook = hook_names.get(key)
430 430 if hook:
431 431 enabled_hooks.append(hook)
432 432
433 433 return enabled_hooks
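For example, given one active and one inactive hook row (the hook values are illustrative placeholders):

    ui_settings = [
        ('hooks', RhodeCodeUi.HOOK_PUSH, 'python:...', True),
        ('hooks', RhodeCodeUi.HOOK_PULL, 'python:...', False),
    ]
    get_enabled_hook_classes(ui_settings)   # -> ['push']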
434 434
435 435
436 436 def set_rhodecode_config(config):
437 437 """
438 438 Updates pyramid config with new settings from database
439 439
440 440 :param config:
441 441 """
442 442 from rhodecode.model.settings import SettingsModel
443 443 app_settings = SettingsModel().get_all_settings()
444 444
445 445 for k, v in app_settings.items():
446 446 config[k] = v
447 447
448 448
449 449 def get_rhodecode_realm():
450 450 """
451 451 Return the rhodecode realm from database.
452 452 """
453 453 from rhodecode.model.settings import SettingsModel
454 454 realm = SettingsModel().get_setting_by_name('realm')
455 455 return safe_str(realm.app_settings_value)
456 456
457 457
458 458 def get_rhodecode_base_path():
459 459 """
460 460 Returns the base path. The base path is the filesystem path which points
461 461 to the repository store.
462 462 """
463 463 from rhodecode.model.settings import SettingsModel
464 464 paths_ui = SettingsModel().get_ui_by_section_and_key('paths', '/')
465 465 return safe_str(paths_ui.ui_value)
466 466
467 467
468 468 def map_groups(path):
469 469 """
470 470 Given a full path to a repository, create all nested groups that this
471 471 repo is inside. This function creates parent-child relationships between
472 472 groups and creates default perms for all new groups.
473 473
474 474 :param paths: full path to repository
475 475 """
476 476 from rhodecode.model.repo_group import RepoGroupModel
477 477 sa = meta.Session()
478 478 groups = path.split(Repository.NAME_SEP)
479 479 parent = None
480 480 group = None
481 481
482 482 # last element is repo in nested groups structure
483 483 groups = groups[:-1]
484 484 rgm = RepoGroupModel(sa)
485 485 owner = User.get_first_super_admin()
486 486 for lvl, group_name in enumerate(groups):
487 487 group_name = '/'.join(groups[:lvl] + [group_name])
488 488 group = RepoGroup.get_by_group_name(group_name)
489 489 desc = '%s group' % group_name
490 490
491 491 # skip folders that are now removed repos
492 492 if REMOVED_REPO_PAT.match(group_name):
493 493 break
494 494
495 495 if group is None:
496 496 log.debug('creating group level: %s group_name: %s',
497 497 lvl, group_name)
498 498 group = RepoGroup(group_name, parent)
499 499 group.group_description = desc
500 500 group.user = owner
501 501 sa.add(group)
502 502 perm_obj = rgm._create_default_perms(group)
503 503 sa.add(perm_obj)
504 504 sa.flush()
505 505
506 506 parent = group
507 507 return group
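For example (the group names are hypothetical):

    group = map_groups('projects/backend/my-repo')
    # creates RepoGroups 'projects' and 'projects/backend' if missing and
    # returns the deepest one; the trailing 'my-repo' segment is the
    # repository itself and is skipped (groups = groups[:-1])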
508 508
509 509
510 510 def repo2db_mapper(initial_repo_list, remove_obsolete=False):
511 511 """
512 512 maps all repos given in initial_repo_list, non existing repositories
513 513 are created, if remove_obsolete is True it also checks for db entries
514 514 that are not in initial_repo_list and removes them.
515 515
516 516 :param initial_repo_list: list of repositories found by scanning methods
517 517 :param remove_obsolete: check for obsolete entries in database
518 518 """
519 519 from rhodecode.model.repo import RepoModel
520 520 from rhodecode.model.repo_group import RepoGroupModel
521 521 from rhodecode.model.settings import SettingsModel
522 522
523 523 sa = meta.Session()
524 524 repo_model = RepoModel()
525 525 user = User.get_first_super_admin()
526 526 added = []
527 527
528 528 # creation defaults
529 529 defs = SettingsModel().get_default_repo_settings(strip_prefix=True)
530 530 enable_statistics = defs.get('repo_enable_statistics')
531 531 enable_locking = defs.get('repo_enable_locking')
532 532 enable_downloads = defs.get('repo_enable_downloads')
533 533 private = defs.get('repo_private')
534 534
535 535 for name, repo in initial_repo_list.items():
536 536 group = map_groups(name)
537 537 unicode_name = safe_unicode(name)
538 538 db_repo = repo_model.get_by_repo_name(unicode_name)
539 539 # found repo that is on filesystem not in RhodeCode database
540 540 if not db_repo:
541 541 log.info('repository %s not found, creating now', name)
542 542 added.append(name)
543 543 desc = (repo.description
544 544 if repo.description != 'unknown'
545 545 else '%s repository' % name)
546 546
547 547 db_repo = repo_model._create_repo(
548 548 repo_name=name,
549 549 repo_type=repo.alias,
550 550 description=desc,
551 551 repo_group=getattr(group, 'group_id', None),
552 552 owner=user,
553 553 enable_locking=enable_locking,
554 554 enable_downloads=enable_downloads,
555 555 enable_statistics=enable_statistics,
556 556 private=private,
557 557 state=Repository.STATE_CREATED
558 558 )
559 559 sa.commit()
560 560 # we just added that repo, so make sure we update the server info
561 561 if db_repo.repo_type == 'git':
562 562 git_repo = db_repo.scm_instance()
563 563 # update repository server-info
564 564 log.debug('Running update server info')
565 565 git_repo._update_server_info()
566 566
567 567 db_repo.update_commit_cache()
568 568
569 569 config = db_repo._config
570 570 config.set('extensions', 'largefiles', '')
571 571 repo = db_repo.scm_instance(config=config)
572 572 repo.install_hooks()
573 573
574 574 removed = []
575 575 if remove_obsolete:
576 576 # remove from database those repositories that are not in the filesystem
577 577 for repo in sa.query(Repository).all():
578 578 if repo.repo_name not in initial_repo_list.keys():
579 579 log.debug("Removing non-existing repository found in db `%s`",
580 580 repo.repo_name)
581 581 try:
582 582 RepoModel(sa).delete(repo, forks='detach', fs_remove=False)
583 583 sa.commit()
584 584 removed.append(repo.repo_name)
585 585 except Exception:
586 586 # don't hold further removals on error
587 587 log.error(traceback.format_exc())
588 588 sa.rollback()
589 589
590 590 def splitter(full_repo_name):
591 591 _parts = full_repo_name.rsplit(RepoGroup.url_sep(), 1)
592 592 gr_name = None
593 593 if len(_parts) == 2:
594 594 gr_name = _parts[0]
595 595 return gr_name
596 596
597 597 initial_repo_group_list = [splitter(x) for x in
598 598 initial_repo_list.keys() if splitter(x)]
599 599
600 600 # remove from database those repository groups that are not on the
601 601 # filesystem. Due to parent-child relationships we need to delete them
602 602 # in a specific order, most nested first
603 603 all_groups = [x.group_name for x in sa.query(RepoGroup).all()]
604 604 nested_sort = lambda gr: len(gr.split('/'))
605 605 for group_name in sorted(all_groups, key=nested_sort, reverse=True):
606 606 if group_name not in initial_repo_group_list:
607 607 repo_group = RepoGroup.get_by_group_name(group_name)
608 608 if (repo_group.children.all() or
609 609 not RepoGroupModel().check_exist_filesystem(
610 610 group_name=group_name, exc_on_failure=False)):
611 611 continue
612 612
613 613 log.info(
614 614 'Removing non-existing repository group found in db `%s`',
615 615 group_name)
616 616 try:
617 617 RepoGroupModel(sa).delete(group_name, fs_remove=False)
618 618 sa.commit()
619 619 removed.append(group_name)
620 620 except Exception:
621 621 # don't hold further removals on error
622 622 log.exception(
623 623 'Unable to remove repository group `%s`',
624 624 group_name)
625 625 sa.rollback()
626 626 raise
627 627
628 628 return added, removed
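A hypothetical invocation; in RhodeCode the name -> repo mapping typically comes from a repository scan (treat the exact import and call as assumptions):

    from rhodecode.model.scm import ScmModel
    scanned = ScmModel().repo_scan(get_rhodecode_base_path())
    added, removed = repo2db_mapper(scanned, remove_obsolete=True)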
629 629
630 630
631 631 def load_rcextensions(root_path):
632 632 import rhodecode
633 633 from rhodecode.config import conf
634 634
635 635 path = os.path.join(root_path)
636 636 sys.path.append(path)
637 637
638 638 try:
639 639 rcextensions = __import__('rcextensions')
640 640 except ImportError:
641 641 if os.path.isdir(os.path.join(path, 'rcextensions')):
642 642 log.warn('Unable to load rcextensions from %s', path)
643 643 rcextensions = None
644 644
645 645 if rcextensions:
646 646 log.info('Loaded rcextensions from %s...', rcextensions)
647 647 rhodecode.EXTENSIONS = rcextensions
648 648
649 649 # Additional mappings that are not present in the pygments lexers
650 650 conf.LANGUAGES_EXTENSIONS_MAP.update(
651 651 getattr(rhodecode.EXTENSIONS, 'EXTRA_MAPPINGS', {}))
652 652
653 653
654 654 def get_custom_lexer(extension):
655 655 """
656 656 returns a custom lexer if it is defined in rcextensions module, or None
657 657 if there's no custom lexer defined
658 658 """
659 659 import rhodecode
660 660 from pygments import lexers
661 661
662 662 # custom override made by RhodeCode
663 663 if extension in ['mako']:
664 664 return lexers.get_lexer_by_name('html+mako')
665 665
666 666 # check if we defined this extension as another lexer
667 667 extensions = rhodecode.EXTENSIONS and getattr(rhodecode.EXTENSIONS, 'EXTRA_LEXERS', None)
668 668 if extensions and extension in rhodecode.EXTENSIONS.EXTRA_LEXERS:
669 669 _lexer_name = rhodecode.EXTENSIONS.EXTRA_LEXERS[extension]
670 670 return lexers.get_lexer_by_name(_lexer_name)
671 671
672 672
673 673 #==============================================================================
674 674 # TEST FUNCTIONS AND CREATORS
675 675 #==============================================================================
676 676 def create_test_index(repo_location, config):
677 677 """
678 678 Makes default test index.
679 679 """
680 680 import rc_testdata
681 681
682 682 rc_testdata.extract_search_index(
683 683 'vcs_search_index', os.path.dirname(config['search.location']))
684 684
685 685
686 686 def create_test_directory(test_path):
687 687 """
688 688 Create test directory if it doesn't exist.
689 689 """
690 690 if not os.path.isdir(test_path):
691 691 log.debug('Creating testdir %s', test_path)
692 692 os.makedirs(test_path)
693 693
694 694
695 695 def create_test_database(test_path, config):
696 696 """
697 697 Makes a fresh database.
698 698 """
699 699 from rhodecode.lib.db_manage import DbManage
700 700
701 701 # PART ONE create db
702 702 dbconf = config['sqlalchemy.db1.url']
703 703 log.debug('making test db %s', dbconf)
704 704
705 705 dbmanage = DbManage(log_sql=False, dbconf=dbconf, root=config['here'],
706 706 tests=True, cli_args={'force_ask': True})
707 707 dbmanage.create_tables(override=True)
708 708 dbmanage.set_db_version()
709 709 # for tests dynamically set new root paths based on generated content
710 710 dbmanage.create_settings(dbmanage.config_prompt(test_path))
711 711 dbmanage.create_default_user()
712 712 dbmanage.create_test_admin_and_users()
713 713 dbmanage.create_permissions()
714 714 dbmanage.populate_default_permissions()
715 715 Session().commit()
716 716
717 717
718 718 def create_test_repositories(test_path, config):
719 719 """
720 720 Creates test repositories in the temporary directory. Repositories are
721 721 extracted from archives within the rc_testdata package.
722 722 """
723 723 import rc_testdata
724 724 from rhodecode.tests import HG_REPO, GIT_REPO, SVN_REPO
725 725
726 726 log.debug('making test vcs repositories')
727 727
728 728 idx_path = config['search.location']
729 729 data_path = config['cache_dir']
730 730
731 731 # clean index and data
732 732 if idx_path and os.path.exists(idx_path):
733 733 log.debug('remove %s', idx_path)
734 734 shutil.rmtree(idx_path)
735 735
736 736 if data_path and os.path.exists(data_path):
737 737 log.debug('remove %s', data_path)
738 738 shutil.rmtree(data_path)
739 739
740 740 rc_testdata.extract_hg_dump('vcs_test_hg', jn(test_path, HG_REPO))
741 741 rc_testdata.extract_git_dump('vcs_test_git', jn(test_path, GIT_REPO))
742 742
743 743 # Note: Subversion is in the process of being integrated with the system,
744 744 # until we have a properly packed version of the test svn repository, this
745 745 # tries to copy over the repo from a package "rc_testdata"
746 746 svn_repo_path = rc_testdata.get_svn_repo_archive()
747 747 with tarfile.open(svn_repo_path) as tar:
748 748 tar.extractall(jn(test_path, SVN_REPO))
749 749
750 750
751 751 def password_changed(auth_user, session):
752 752 # Never report password change in case of default user or anonymous user.
753 753 if auth_user.username == User.DEFAULT_USER or auth_user.user_id is None:
754 754 return False
755 755
756 756 password_hash = md5(auth_user.password) if auth_user.password else None
757 757 rhodecode_user = session.get('rhodecode_user', {})
758 758 session_password_hash = rhodecode_user.get('password', '')
759 759 return password_hash != session_password_hash
760 760
761 761
762 762 def read_opensource_licenses():
763 763 global _license_cache
764 764
765 765 if not _license_cache:
766 766 licenses = pkg_resources.resource_string(
767 767 'rhodecode', 'config/licenses.json')
768 768 _license_cache = json.loads(licenses)
769 769
770 770 return _license_cache
771 771
772 772
773 773 def generate_platform_uuid():
774 774 """
775 775 Generates a platform UUID based on its name
776 776 """
777 777 import platform
778 778
779 779 try:
780 780 uuid_list = [platform.platform()]
781 781 return hashlib.sha256(':'.join(uuid_list).encode('utf-8')).hexdigest()
782 782 except Exception as e:
783 783 log.error('Failed to generate host uuid: %s', e)
784 784 return 'UNDEFINED'
785 785
786 786
787 787 def send_test_email(recipients, email_body='TEST EMAIL'):
788 788 """
789 789 Simple code for generating test emails.
790 790 Usage::
791 791
792 792 from rhodecode.lib import utils
793 793 utils.send_test_email(['user@example.com'])
794 794 """
795 795 from rhodecode.lib.celerylib import tasks, run_task
796 796
797 797 email_body_plaintext = email_body
798 798 subject = 'SUBJECT FROM: {}'.format(socket.gethostname())
799 799 tasks.send_email(recipients, subject, email_body_plaintext, email_body)
@@ -1,1045 +1,1045 b''
1 1 # -*- coding: utf-8 -*-
2 2
3 3 # Copyright (C) 2011-2020 RhodeCode GmbH
4 4 #
5 5 # This program is free software: you can redistribute it and/or modify
6 6 # it under the terms of the GNU Affero General Public License, version 3
7 7 # (only), as published by the Free Software Foundation.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU Affero General Public License
15 15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 16 #
17 17 # This program is dual-licensed. If you wish to learn more about the
18 18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20 20
21 21
22 22 """
23 23 Some simple helper functions
24 24 """
25 25
26 26 import collections
27 27 import datetime
28 28 import dateutil.relativedelta
29 29 import logging
30 30 import re
31 31 import sys
32 32 import time
33 33 import urllib.request, urllib.parse, urllib.error
34 34 import urlobject
35 35 import uuid
36 36 import getpass
37 37 import socket
38 38 import errno
39 39 import random
40 40 from functools import update_wrapper, partial, wraps
41 41 from contextlib import closing
42 42
43 43 import pygments.lexers
44 44 import sqlalchemy
45 45 import sqlalchemy.engine.url
46 46 import sqlalchemy.exc
47 47 import sqlalchemy.sql
48 48 import webob
49 49 import pyramid.threadlocal
50 50 from pyramid.settings import asbool
51 51
52 52 import rhodecode
53 53 from rhodecode.translation import _, _pluralize
54 54 from rhodecode.lib.str_utils import safe_str, safe_int, safe_bytes
55 55 from rhodecode.lib.hash_utils import md5, md5_safe, sha1, sha1_safe
56 56 from rhodecode.lib.type_utils import aslist, str2bool
57 57 from functools import reduce
58 58
59 59 # TODO: safe_unicode no longer exists; it is aliased here for now, but should be removed
60 60 safe_unicode = safe_str
61 61
62 62
63 63 def __get_lem(extra_mapping=None):
64 64 """
65 65 Get language extension map based on what's inside pygments lexers
66 66 """
67 67 d = collections.defaultdict(lambda: [])
68 68
69 69 def __clean(s):
70 70 s = s.lstrip('*')
71 71 s = s.lstrip('.')
72 72
73 73 if s.find('[') != -1:
74 74 exts = []
75 75 start, stop = s.find('['), s.find(']')
76 76
77 77 for suffix in s[start + 1:stop]:
78 78 exts.append(s[:s.find('[')] + suffix)
79 79 return [e.lower() for e in exts]
80 80 else:
81 81 return [s.lower()]
82 82
83 83 for lx, t in sorted(pygments.lexers.LEXERS.items()):
84 84 m = list(map(__clean, t[-2]))
85 85 if m:
86 86 m = reduce(lambda x, y: x + y, m)
87 87 for ext in m:
88 88 desc = lx.replace('Lexer', '')
89 89 d[ext].append(desc)
90 90
91 91 data = dict(d)
92 92
93 93 extra_mapping = extra_mapping or {}
94 94 if extra_mapping:
95 95 for k, v in extra_mapping.items():
96 96 if k not in data:
97 97 # register new extension-to-lexer mapping
98 98 data[k] = [v]
99 99
100 100 return data
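# Illustrative sketch: the bracket expansion performed by ``__clean`` above,
# shown standalone with a hypothetical pattern (same algorithm):
#
# >>> s = 'php[345]'
# >>> start, stop = s.find('['), s.find(']')
# >>> [s[:start] + suffix for suffix in s[start + 1:stop]]
# ['php3', 'php4', 'php5']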
101 101
102 102
103 103 def convert_line_endings(line, mode):
104 104 """
105 105 Converts the line endings of a given line according to the given mode
106 106
107 107 Available modes are::
108 108 0 - Unix
109 109 1 - Mac
110 110 2 - DOS
111 111
112 112 :param line: given line to convert
113 113 :param mode: mode to convert to
114 114 :rtype: str
115 115 :return: converted line according to mode
116 116 """
117 117 if mode == 0:
118 118 line = line.replace('\r\n', '\n')
119 119 line = line.replace('\r', '\n')
120 120 elif mode == 1:
121 121 line = line.replace('\r\n', '\r')
122 122 line = line.replace('\n', '\r')
123 123 elif mode == 2:
124 124 line = re.sub('\r(?!\n)|(?<!\r)\n', '\r\n', line)
125 125 return line
126 126
127 127
128 128 def detect_mode(line, default):
129 129 """
130 130 Detects the line-break style of a given line; if the line break cannot
131 131 be determined, the given default value is returned
132 132
133 133 :param line: str line
134 134 :param default: default
135 135 :rtype: int
136 136 :return: value of line end, one of: 0 - Unix, 1 - Mac, 2 - DOS
137 137 """
138 138 if line.endswith('\r\n'):
139 139 return 2
140 140 elif line.endswith('\n'):
141 141 return 0
142 142 elif line.endswith('\r'):
143 143 return 1
144 144 else:
145 145 return default
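# Doctest-style illustration of the two helpers above (hypothetical inputs):
#
# >>> convert_line_endings('foo\r\nbar\r', 0)   # normalize to Unix
# 'foo\nbar\n'
# >>> convert_line_endings('foo\nbar\n', 2)     # normalize to DOS
# 'foo\r\nbar\r\n'
# >>> detect_mode('foo\r\n', 0)                 # DOS line ending detected
# 2
# >>> detect_mode('foo', 9)                     # no line ending, default wins
# 9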
146 146
147 147
148 148 def remove_suffix(s, suffix):
149 149 if s.endswith(suffix):
150 150 s = s[:-1 * len(suffix)]
151 151 return s
152 152
153 153
154 154 def remove_prefix(s, prefix):
155 155 if s.startswith(prefix):
156 156 s = s[len(prefix):]
157 157 return s
158 158
159 159
160 160 def find_calling_context(ignore_modules=None):
161 161 """
162 162 Look through the calling stack and return the frame which called
163 163 this function and is part of the core module (i.e. rhodecode.*)
164 164
165 165 :param ignore_modules: list of modules to ignore eg. ['rhodecode.lib']
166 166
167 167 usage::
168 168 from rhodecode.lib.utils2 import find_calling_context
169 169
170 170 calling_context = find_calling_context(ignore_modules=[
171 171 'rhodecode.lib.caching_query',
172 172 'rhodecode.model.settings',
173 173 ])
174 174
175 175 if calling_context:
176 176 cc_str = 'call context %s:%s' % (
177 177 calling_context.f_code.co_filename,
178 178 calling_context.f_lineno,
179 179 )
180 180 print(cc_str)
181 181 """
182 182
183 183 ignore_modules = ignore_modules or []
184 184
185 185 f = sys._getframe(2)
186 186 while f.f_back is not None:
187 187 name = f.f_globals.get('__name__')
188 188 if name and name.startswith(__name__.split('.')[0]):
189 189 if name not in ignore_modules:
190 190 return f
191 191 f = f.f_back
192 192 return None
193 193
194 194
195 195 def ping_connection(connection, branch):
196 196 if branch:
197 197 # "branch" refers to a sub-connection of a connection,
198 198 # we don't want to bother pinging on these.
199 199 return
200 200
201 201 # turn off "close with result". This flag is only used with
202 202 # "connectionless" execution, otherwise will be False in any case
203 203 save_should_close_with_result = connection.should_close_with_result
204 204 connection.should_close_with_result = False
205 205
206 206 try:
207 207 # run a SELECT 1. use a core select() so that
208 208 # the SELECT of a scalar value without a table is
209 209 # appropriately formatted for the backend
210 210 connection.scalar(sqlalchemy.sql.select([1]))
211 211 except sqlalchemy.exc.DBAPIError as err:
212 212 # catch SQLAlchemy's DBAPIError, which is a wrapper
213 213 # for the DBAPI's exception. It includes a .connection_invalidated
214 214 # attribute which specifies if this connection is a "disconnect"
215 215 # condition, which is based on inspection of the original exception
216 216 # by the dialect in use.
217 217 if err.connection_invalidated:
218 218 # run the same SELECT again - the connection will re-validate
219 219 # itself and establish a new connection. The disconnect detection
220 220 # here also causes the whole connection pool to be invalidated
221 221 # so that all stale connections are discarded.
222 222 connection.scalar(sqlalchemy.sql.select([1]))
223 223 else:
224 224 raise
225 225 finally:
226 226 # restore "close with result"
227 227 connection.should_close_with_result = save_should_close_with_result
228 228
229 229
230 230 def engine_from_config(configuration, prefix='sqlalchemy.', **kwargs):
231 231 """Custom engine_from_config functions."""
232 232 log = logging.getLogger('sqlalchemy.engine')
233 233 use_ping_connection = asbool(configuration.pop('sqlalchemy.db1.ping_connection', None))
234 234 debug = asbool(configuration.pop('sqlalchemy.db1.debug_query', None))
235 235
236 236 engine = sqlalchemy.engine_from_config(configuration, prefix, **kwargs)
237 237
238 238 def color_sql(sql):
239 239 color_seq = '\033[1;33m' # This is yellow: code 33
240 240 normal = '\x1b[0m'
241 241 return ''.join([color_seq, sql, normal])
242 242
243 243 if use_ping_connection:
244 244 log.debug('Adding ping_connection on the engine config.')
245 245 sqlalchemy.event.listen(engine, "engine_connect", ping_connection)
246 246
247 247 if debug:
248 248 # attach events only for debug configuration
249 249 def before_cursor_execute(conn, cursor, statement,
250 250 parameters, context, executemany):
251 251 setattr(conn, 'query_start_time', time.time())
252 252 log.info(color_sql(">>>>> STARTING QUERY >>>>>"))
253 253 calling_context = find_calling_context(ignore_modules=[
254 254 'rhodecode.lib.caching_query',
255 255 'rhodecode.model.settings',
256 256 ])
257 257 if calling_context:
258 258 log.info(color_sql('call context %s:%s' % (
259 259 calling_context.f_code.co_filename,
260 260 calling_context.f_lineno,
261 261 )))
262 262
263 263 def after_cursor_execute(conn, cursor, statement,
264 264 parameters, context, executemany):
265 265 delattr(conn, 'query_start_time')
266 266
267 267 sqlalchemy.event.listen(engine, "before_cursor_execute", before_cursor_execute)
268 268 sqlalchemy.event.listen(engine, "after_cursor_execute", after_cursor_execute)
269 269
270 270 return engine
271 271
272 272
273 273 def get_encryption_key(config):
274 274 secret = config.get('rhodecode.encrypted_values.secret')
275 275 default = config['beaker.session.secret']
276 276 return secret or default
277 277
278 278
279 279 def age(prevdate, now=None, show_short_version=False, show_suffix=True, short_format=False):
280 280 """
281 281 Turns a datetime into an age string.
282 282 If show_short_version is True, this generates a shorter string with
283 283 an approximate age, e.g. '1 day ago' rather than '1 day and 23 hours ago'.
284 284
285 285 *IMPORTANT*
286 286 This function is written in a special way so that it is easier to
287 287 backport to javascript. If you update it, please also update the
288 288 `jquery.timeago-extension.js` file
289 289
290 290 :param prevdate: datetime object
291 291 :param now: current time; if not defined we use
292 292 `datetime.datetime.now()`
293 293 :param show_short_version: if it should approximate the date and
294 294 return a shorter string
295 295 :param show_suffix: whether to add the 'ago' / 'in' suffix
296 296 :param short_format: show short format, e.g. '2d' instead of '2 days'
297 297 :rtype: str
298 298 :returns: words describing age
299 299 """
300 300
301 301 def _get_relative_delta(now, prevdate):
302 302 base = dateutil.relativedelta.relativedelta(now, prevdate)
303 303 return {
304 304 'year': base.years,
305 305 'month': base.months,
306 306 'day': base.days,
307 307 'hour': base.hours,
308 308 'minute': base.minutes,
309 309 'second': base.seconds,
310 310 }
311 311
312 312 def _is_leap_year(year):
313 313 return year % 4 == 0 and (year % 100 != 0 or year % 400 == 0)
314 314
315 315 def get_month(prevdate):
316 316 return prevdate.month
317 317
318 318 def get_year(prevdate):
319 319 return prevdate.year
320 320
321 321 now = now or datetime.datetime.now()
322 322 order = ['year', 'month', 'day', 'hour', 'minute', 'second']
323 323 deltas = {}
324 324 future = False
325 325
326 326 if prevdate > now:
327 327 now_old = now
328 328 now = prevdate
329 329 prevdate = now_old
330 330 future = True
331 331 if future:
332 332 prevdate = prevdate.replace(microsecond=0)
333 333 # Get date parts deltas
334 334 for part in order:
335 335 rel_delta = _get_relative_delta(now, prevdate)
336 336 deltas[part] = rel_delta[part]
337 337
338 338 # Fix negative offsets (there is 1 second between 10:59:59 and 11:00:00,
339 339 # not 1 hour, -59 minutes and -59 seconds)
340 340 offsets = [[5, 60], [4, 60], [3, 24]]
341 341 for element in offsets: # seconds, minutes, hours
342 342 num = element[0]
343 343 length = element[1]
344 344
345 345 part = order[num]
346 346 carry_part = order[num - 1]
347 347
348 348 if deltas[part] < 0:
349 349 deltas[part] += length
350 350 deltas[carry_part] -= 1
351 351
352 352 # Same thing for days except that the increment depends on the (variable)
353 353 # number of days in the month
354 354 month_lengths = [31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31]
355 355 if deltas['day'] < 0:
356 356 if get_month(prevdate) == 2 and _is_leap_year(get_year(prevdate)):
357 357 deltas['day'] += 29
358 358 else:
359 359 deltas['day'] += month_lengths[get_month(prevdate) - 1]
360 360
361 361 deltas['month'] -= 1
362 362
363 363 if deltas['month'] < 0:
364 364 deltas['month'] += 12
365 365 deltas['year'] -= 1
366 366
367 367 # Format the result
368 368 if short_format:
369 369 fmt_funcs = {
370 'year': lambda d: u'%dy' % d,
371 'month': lambda d: u'%dm' % d,
372 'day': lambda d: u'%dd' % d,
373 'hour': lambda d: u'%dh' % d,
374 'minute': lambda d: u'%dmin' % d,
375 'second': lambda d: u'%dsec' % d,
370 'year': lambda d: '%dy' % d,
371 'month': lambda d: '%dm' % d,
372 'day': lambda d: '%dd' % d,
373 'hour': lambda d: '%dh' % d,
374 'minute': lambda d: '%dmin' % d,
375 'second': lambda d: '%dsec' % d,
376 376 }
377 377 else:
378 378 fmt_funcs = {
379 'year': lambda d: _pluralize(u'${num} year', u'${num} years', d, mapping={'num': d}).interpolate(),
380 'month': lambda d: _pluralize(u'${num} month', u'${num} months', d, mapping={'num': d}).interpolate(),
381 'day': lambda d: _pluralize(u'${num} day', u'${num} days', d, mapping={'num': d}).interpolate(),
382 'hour': lambda d: _pluralize(u'${num} hour', u'${num} hours', d, mapping={'num': d}).interpolate(),
383 'minute': lambda d: _pluralize(u'${num} minute', u'${num} minutes', d, mapping={'num': d}).interpolate(),
384 'second': lambda d: _pluralize(u'${num} second', u'${num} seconds', d, mapping={'num': d}).interpolate(),
379 'year': lambda d: _pluralize('${num} year', '${num} years', d, mapping={'num': d}).interpolate(),
380 'month': lambda d: _pluralize('${num} month', '${num} months', d, mapping={'num': d}).interpolate(),
381 'day': lambda d: _pluralize('${num} day', '${num} days', d, mapping={'num': d}).interpolate(),
382 'hour': lambda d: _pluralize('${num} hour', '${num} hours', d, mapping={'num': d}).interpolate(),
383 'minute': lambda d: _pluralize('${num} minute', '${num} minutes', d, mapping={'num': d}).interpolate(),
384 'second': lambda d: _pluralize('${num} second', '${num} seconds', d, mapping={'num': d}).interpolate(),
385 385 }
386 386
387 387 i = 0
388 388 for part in order:
389 389 value = deltas[part]
390 390 if value != 0:
391 391
392 392 if i < 5:
393 393 sub_part = order[i + 1]
394 394 sub_value = deltas[sub_part]
395 395 else:
396 396 sub_value = 0
397 397
398 398 if sub_value == 0 or show_short_version:
399 399 _val = fmt_funcs[part](value)
400 400 if future:
401 401 if show_suffix:
402 return _(u'in ${ago}', mapping={'ago': _val})
402 return _('in ${ago}', mapping={'ago': _val})
403 403 else:
404 404 return _(_val)
405 405
406 406 else:
407 407 if show_suffix:
408 return _(u'${ago} ago', mapping={'ago': _val})
408 return _('${ago} ago', mapping={'ago': _val})
409 409 else:
410 410 return _(_val)
411 411
412 412 val = fmt_funcs[part](value)
413 413 val_detail = fmt_funcs[sub_part](sub_value)
414 414 mapping = {'val': val, 'detail': val_detail}
415 415
416 416 if short_format:
417 datetime_tmpl = _(u'${val}, ${detail}', mapping=mapping)
417 datetime_tmpl = _('${val}, ${detail}', mapping=mapping)
418 418 if show_suffix:
419 datetime_tmpl = _(u'${val}, ${detail} ago', mapping=mapping)
419 datetime_tmpl = _('${val}, ${detail} ago', mapping=mapping)
420 420 if future:
421 datetime_tmpl = _(u'in ${val}, ${detail}', mapping=mapping)
421 datetime_tmpl = _('in ${val}, ${detail}', mapping=mapping)
422 422 else:
423 datetime_tmpl = _(u'${val} and ${detail}', mapping=mapping)
423 datetime_tmpl = _('${val} and ${detail}', mapping=mapping)
424 424 if show_suffix:
425 datetime_tmpl = _(u'${val} and ${detail} ago', mapping=mapping)
425 datetime_tmpl = _('${val} and ${detail} ago', mapping=mapping)
426 426 if future:
427 datetime_tmpl = _(u'in ${val} and ${detail}', mapping=mapping)
427 datetime_tmpl = _('in ${val} and ${detail}', mapping=mapping)
428 428
429 429 return datetime_tmpl
430 430 i += 1
431 return _(u'just now')
431 return _('just now')
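# Doctest-style illustration (hypothetical values); assuming pyramid-style
# translation strings, ``interpolate()`` renders the default, untranslated
# text:
#
# >>> import datetime
# >>> prev = datetime.datetime.now() - datetime.timedelta(hours=25)
# >>> age(prev).interpolate()
# '1 day and 1 hour ago'
# >>> age(prev, show_short_version=True).interpolate()
# '1 day ago'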
432 432
433 433
434 434 def age_from_seconds(seconds):
435 435 seconds = safe_int(seconds) or 0
436 436 prevdate = time_to_datetime(time.time() + seconds)
437 437 return age(prevdate, show_suffix=False, show_short_version=True)
438 438
439 439
440 440 def cleaned_uri(uri):
441 441 """
442 442 Quotes characters that are not allowed in a uri, e.g. '[' and ']';
443 443 according to RFC 3986 we cannot use such chars in a uri unescaped.
444 444 :param uri: uri to clean
445 445 :return: uri with the offending characters percent-encoded
446 446 """
447 447 return urllib.parse.quote(uri, safe='@$:/')
448 448
449 449
450 450 def credentials_filter(uri):
451 451 """
452 452 Returns a url with removed credentials
453 453
454 454 :param uri:
455 455 """
456 456 import urlobject
457 457 if isinstance(uri, rhodecode.lib.encrypt.InvalidDecryptedValue):
458 458 return 'InvalidDecryptionKey'
459 459
460 460 url_obj = urlobject.URLObject(cleaned_uri(uri))
461 461 url_obj = url_obj.without_password().without_username()
462 462
463 463 return url_obj
464 464
465 465
466 466 def get_host_info(request):
467 467 """
468 468 Generate host info; to obtain the full url, e.g. https://server.com,
469 469 use this:
470 470 `{scheme}://{netloc}`
471 471 """
472 472 if not request:
473 473 return {}
474 474
475 475 qualified_home_url = request.route_url('home')
476 476 parsed_url = urlobject.URLObject(qualified_home_url)
477 477 decoded_path = safe_unicode(urllib.parse.unquote(parsed_url.path.rstrip('/')))
478 478
479 479 return {
480 480 'scheme': parsed_url.scheme,
481 481 'netloc': parsed_url.netloc+decoded_path,
482 482 'hostname': parsed_url.hostname,
483 483 }
484 484
485 485
486 486 def get_clone_url(request, uri_tmpl, repo_name, repo_id, repo_type, **override):
487 487 qualified_home_url = request.route_url('home')
488 488 parsed_url = urlobject.URLObject(qualified_home_url)
489 489 decoded_path = safe_unicode(urllib.parse.unquote(parsed_url.path.rstrip('/')))
490 490
491 491 args = {
492 492 'scheme': parsed_url.scheme,
493 493 'user': '',
494 494 'sys_user': getpass.getuser(),
495 495 # path if we use proxy-prefix
496 496 'netloc': parsed_url.netloc+decoded_path,
497 497 'hostname': parsed_url.hostname,
498 498 'prefix': decoded_path,
499 499 'repo': repo_name,
500 500 'repoid': str(repo_id),
501 501 'repo_type': repo_type
502 502 }
503 503 args.update(override)
504 504 args['user'] = urllib.parse.quote(safe_str(args['user']))
505 505
506 506 for k, v in args.items():
507 507 uri_tmpl = uri_tmpl.replace('{%s}' % k, v)
508 508
509 509 # special case for SVN clone url
510 510 if repo_type == 'svn':
511 511 uri_tmpl = uri_tmpl.replace('ssh://', 'svn+ssh://')
512 512
513 513 # remove leading @ sign if it's present. Case of empty user
514 514 url_obj = urlobject.URLObject(uri_tmpl)
515 515 url = url_obj.with_netloc(url_obj.netloc.lstrip('@'))
516 516
517 517 return safe_unicode(url)
518 518
519 519
520 520 def get_commit_safe(repo, commit_id=None, commit_idx=None, pre_load=None,
521 521 maybe_unreachable=False, reference_obj=None):
522 522 """
523 523 Safe version of get_commit; if the commit doesn't exist for a
524 524 repository it returns an EmptyCommit instead
525 525
526 526 :param repo: repository instance
527 527 :param commit_id: commit id as str
528 528 :param commit_idx: numeric commit index
529 529 :param pre_load: optional list of commit attributes to load
530 530 :param maybe_unreachable: translate unreachable commits on git repos
531 531 :param reference_obj: explicitly search via a reference obj in git. E.g "branch:123" would mean branch "123"
532 532 """
533 533 # TODO(skreft): remove these circular imports
534 534 from rhodecode.lib.vcs.backends.base import BaseRepository, EmptyCommit
535 535 from rhodecode.lib.vcs.exceptions import RepositoryError
536 536 if not isinstance(repo, BaseRepository):
537 537 raise Exception('You must pass a Repository '
538 538 'object as first argument, got %s' % type(repo))
539 539
540 540 try:
541 541 commit = repo.get_commit(
542 542 commit_id=commit_id, commit_idx=commit_idx, pre_load=pre_load,
543 543 maybe_unreachable=maybe_unreachable, reference_obj=reference_obj)
544 544 except (RepositoryError, LookupError):
545 545 commit = EmptyCommit()
546 546 return commit
547 547
548 548
549 549 def datetime_to_time(dt):
550 550 if dt:
551 551 return time.mktime(dt.timetuple())
552 552
553 553
554 554 def time_to_datetime(tm):
555 555 if tm:
556 556 if isinstance(tm, str):
557 557 try:
558 558 tm = float(tm)
559 559 except ValueError:
560 560 return
561 561 return datetime.datetime.fromtimestamp(tm)
562 562
563 563
564 564 def time_to_utcdatetime(tm):
565 565 if tm:
566 566 if isinstance(tm, str):
567 567 try:
568 568 tm = float(tm)
569 569 except ValueError:
570 570 return
571 571 return datetime.datetime.utcfromtimestamp(tm)
572 572
573 573
574 574 MENTIONS_REGEX = re.compile(
575 575 # ^@ or @ without any special chars in front
576 576 r'(?:^@|[^a-zA-Z0-9\-\_\.]@)'
577 577 # main body starts with letter, then can be . - _
578 578 r'([a-zA-Z0-9]{1}[a-zA-Z0-9\-\_\.]+)',
579 579 re.VERBOSE | re.MULTILINE)
580 580
581 581
582 582 def extract_mentioned_users(s):
583 583 """
584 584 Returns unique usernames that are @mentioned in the given string s
585 585
586 586 :param s: string to extract mentions from
587 587 """
588 588 usrs = set()
589 589 for username in MENTIONS_REGEX.findall(s):
590 590 usrs.add(username)
591 591
592 592 return sorted(list(usrs), key=lambda k: k.lower())
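# Illustration of the mention extraction (hypothetical input):
#
# >>> extract_mentioned_users('@alice please review, cc @Bob.K and not-an@email')
# ['alice', 'Bob.K']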
593 593
594 594
595 595 class AttributeDictBase(dict):
596 596 def __getstate__(self):
597 597 odict = self.__dict__ # get attribute dictionary
598 598 return odict
599 599
600 600 def __setstate__(self, dict):
601 601 self.__dict__ = dict
602 602
603 603 __setattr__ = dict.__setitem__
604 604 __delattr__ = dict.__delitem__
605 605
606 606
607 607 class StrictAttributeDict(AttributeDictBase):
608 608 """
609 609 Strict version of AttributeDict which raises an AttributeError when a
610 610 requested attribute is not set
611 611 """
612 612 def __getattr__(self, attr):
613 613 try:
614 614 return self[attr]
615 615 except KeyError:
616 616 raise AttributeError('%s object has no attribute %s' % (
617 617 self.__class__, attr))
618 618
619 619
620 620 class AttributeDict(AttributeDictBase):
621 621 def __getattr__(self, attr):
622 622 return self.get(attr, None)
623 623
624 624
625 625 def fix_PATH(os_=None):
626 626 """
627 627 Gets the directory of the active python interpreter and prepends it
628 628 to PATH, fixing subprocess calls across different python versions
629 629 """
630 630 if os_ is None:
631 631 import os
632 632 else:
633 633 os = os_
634 634
635 635 cur_path = os.path.split(sys.executable)[0]
636 636 if not os.environ['PATH'].startswith(cur_path):
637 637 os.environ['PATH'] = '%s:%s' % (cur_path, os.environ['PATH'])
638 638
639 639
640 640 def obfuscate_url_pw(engine):
641 641 _url = engine or ''
642 642 try:
643 643 _url = sqlalchemy.engine.url.make_url(engine)
644 644 if _url.password:
645 645 _url.password = 'XXXXX'
646 646 except Exception:
647 647 pass
648 648 return str(_url)
649 649
650 650
651 651 def get_server_url(environ):
652 652 req = webob.Request(environ)
653 653 return req.host_url + req.script_name
654 654
655 655
656 656 def unique_id(hexlen=32):
657 657 alphabet = "23456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghjklmnpqrstuvwxyz"
658 658 return suuid(truncate_to=hexlen, alphabet=alphabet)
659 659
660 660
661 661 def suuid(url=None, truncate_to=22, alphabet=None):
662 662 """
663 663 Generate and return a short URL safe UUID.
664 664
665 665 If the url parameter is provided, set the namespace to the provided
666 666 URL and generate a UUID.
667 667
668 668 :param url: url to get the uuid for
669 669 :param truncate_to: truncate the basic 22-character UUID to a shorter version
670 670
671 671 The IDs won't be universally unique any longer, but the probability of
672 672 a collision will still be very low.
673 673 """
674 674 # Define our alphabet.
675 675 _ALPHABET = alphabet or "23456789ABCDEFGHJKLMNPQRSTUVWXYZ"
676 676
677 677 # If no URL is given, generate a random UUID.
678 678 if url is None:
679 679 unique_id = uuid.uuid4().int
680 680 else:
681 681 unique_id = uuid.uuid3(uuid.NAMESPACE_URL, url).int
682 682
683 683 alphabet_length = len(_ALPHABET)
684 684 output = []
685 685 while unique_id > 0:
686 686 digit = unique_id % alphabet_length
687 687 output.append(_ALPHABET[digit])
688 688 unique_id = int(unique_id / alphabet_length)
689 689 return "".join(output)[:truncate_to]
690 690
691 691
692 692 def get_current_rhodecode_user(request=None):
693 693 """
694 694 Gets rhodecode user from request
695 695 """
696 696 pyramid_request = request or pyramid.threadlocal.get_current_request()
697 697
698 698 # web case
699 699 if pyramid_request and hasattr(pyramid_request, 'user'):
700 700 return pyramid_request.user
701 701
702 702 # api case
703 703 if pyramid_request and hasattr(pyramid_request, 'rpc_user'):
704 704 return pyramid_request.rpc_user
705 705
706 706 return None
707 707
708 708
709 709 def action_logger_generic(action, namespace=''):
710 710 """
711 711 A generic logger for actions useful to the system overview; tries to find
712 712 an acting user in the context of the call, otherwise reports an unknown user
713 713
714 714 :param action: logging message e.g. 'comment 5 deleted'
715 715 :type action: string
716 716
717 717 :param namespace: namespace of the logging message e.g. 'repo.comments'
718 718 :type namespace: string
719 719
720 720 """
721 721
722 722 logger_name = 'rhodecode.actions'
723 723
724 724 if namespace:
725 725 logger_name += '.' + namespace
726 726
727 727 log = logging.getLogger(logger_name)
728 728
729 729 # get a user if we can
730 730 user = get_current_rhodecode_user()
731 731
732 732 logfunc = log.info
733 733
734 734 if not user:
735 735 user = '<unknown user>'
736 736 logfunc = log.warning
737 737
738 738 logfunc('Logging action by {}: {}'.format(user, action))
739 739
740 740
741 741 def escape_split(text, sep=',', maxsplit=-1):
742 742 r"""
743 743 Allows for escaping of the separator: e.g. arg='foo\, bar'
744 744
745 745 It should be noted that, given the way bash et al. do command line
746 746 parsing, those single quotes are required.
747 747 """
748 748 escaped_sep = r'\%s' % sep
749 749
750 750 if escaped_sep not in text:
751 751 return text.split(sep, maxsplit)
752 752
753 753 before, _mid, after = text.partition(escaped_sep)
754 754 startlist = before.split(sep, maxsplit) # a regular split is fine here
755 755 unfinished = startlist[-1]
756 756 startlist = startlist[:-1]
757 757
758 758 # recurse because there may be more escaped separators
759 759 endlist = escape_split(after, sep, maxsplit)
760 760
761 761 # finish building the escaped value. we use endlist[0] because the first
762 762 # part of the string sent in recursion is the rest of the escaped value.
763 763 unfinished += sep + endlist[0]
764 764
765 765 return startlist + [unfinished] + endlist[1:] # put together all the parts
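# Doctest-style illustration (hypothetical input):
#
# >>> escape_split('foo\\, bar, baz')   # the backslash escapes the comma
# ['foo, bar', ' baz']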
766 766
767 767
768 768 class OptionalAttr(object):
769 769 """
770 770 Special Optional Option that defines other attribute. Example::
771 771
772 772 def test(apiuser, userid=Optional(OAttr('apiuser'))):
773 773 user = Optional.extract(userid)
774 774 # `user` resolves to 'apiuser' unless an explicit userid was given
775 775
776 776 """
777 777
778 778 def __init__(self, attr_name):
779 779 self.attr_name = attr_name
780 780
781 781 def __repr__(self):
782 782 return '<OptionalAttr:%s>' % self.attr_name
783 783
784 784 def __call__(self):
785 785 return self
786 786
787 787
788 788 # alias
789 789 OAttr = OptionalAttr
790 790
791 791
792 792 class Optional(object):
793 793 """
794 794 Defines an optional parameter::
795 795
796 796 param = param.getval() if isinstance(param, Optional) else param
797 797 param = param() if isinstance(param, Optional) else param
798 798
799 799 is equivalent of::
800 800
801 801 param = Optional.extract(param)
802 802
803 803 """
804 804
805 805 def __init__(self, type_):
806 806 self.type_ = type_
807 807
808 808 def __repr__(self):
809 809 return '<Optional:%s>' % self.type_.__repr__()
810 810
811 811 def __call__(self):
812 812 return self.getval()
813 813
814 814 def getval(self):
815 815 """
816 816 returns value from this Optional instance
817 817 """
818 818 if isinstance(self.type_, OAttr):
819 819 # use params name
820 820 return self.type_.attr_name
821 821 return self.type_
822 822
823 823 @classmethod
824 824 def extract(cls, val):
825 825 """
826 826 Extracts value from Optional() instance
827 827
828 828 :param val:
829 829 :return: original value if it's not Optional instance else
830 830 value of instance
831 831 """
832 832 if isinstance(val, cls):
833 833 return val.getval()
834 834 return val
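# Quick illustration of both classes (hypothetical values):
#
# >>> Optional.extract(Optional('default'))
# 'default'
# >>> Optional.extract('explicit')
# 'explicit'
# >>> Optional.extract(Optional(OAttr('apiuser')))   # resolves to the attr name
# 'apiuser'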
835 835
836 836
837 837 def glob2re(pat):
838 838 """
839 839 Translate a shell PATTERN to a regular expression.
840 840
841 841 There is no way to quote meta-characters.
842 842 """
843 843
844 844 i, n = 0, len(pat)
845 845 res = ''
846 846 while i < n:
847 847 c = pat[i]
848 848 i = i+1
849 849 if c == '*':
850 850 #res = res + '.*'
851 851 res = res + '[^/]*'
852 852 elif c == '?':
853 853 #res = res + '.'
854 854 res = res + '[^/]'
855 855 elif c == '[':
856 856 j = i
857 857 if j < n and pat[j] == '!':
858 858 j = j+1
859 859 if j < n and pat[j] == ']':
860 860 j = j+1
861 861 while j < n and pat[j] != ']':
862 862 j = j+1
863 863 if j >= n:
864 864 res = res + '\\['
865 865 else:
866 866 stuff = pat[i:j].replace('\\','\\\\')
867 867 i = j+1
868 868 if stuff[0] == '!':
869 869 stuff = '^' + stuff[1:]
870 870 elif stuff[0] == '^':
871 871 stuff = '\\' + stuff
872 872 res = '%s[%s]' % (res, stuff)
873 873 else:
874 874 res = res + re.escape(c)
875 875 return '(?ms)' + res + r'\Z'
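# Illustration (hypothetical patterns); note that '*' deliberately stops at
# '/' so globs do not cross path segments:
#
# >>> import re
# >>> bool(re.match(glob2re('*.py'), 'setup.py'))
# True
# >>> bool(re.match(glob2re('*.py'), 'lib/setup.py'))
# False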
876 876
877 877
878 878 def parse_byte_string(size_str):
879 879 match = re.match(r'(\d+)(MB|KB)', size_str, re.IGNORECASE)
880 880 if not match:
881 881 raise ValueError('Given size: %s is invalid, please make sure '
882 882 'to use the format <num>(MB|KB)' % size_str)
883 883
884 884 _parts = match.groups()
885 885 num, type_ = _parts
886 886 return int(num) * {'mb': 1024*1024, 'kb': 1024}[type_.lower()]
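# Illustration (hypothetical sizes):
#
# >>> parse_byte_string('8MB')
# 8388608
# >>> parse_byte_string('512kb')   # the match is case-insensitive
# 524288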
887 887
888 888
889 889 class CachedProperty(object):
890 890 """
891 891 Lazy attributes, with an option to invalidate the cache by running a method
892 892
893 893 >>> class Foo(object):
894 894 ...
895 895 ... @CachedProperty
896 896 ... def heavy_func(self):
897 897 ... return 'super-calculation'
898 898 ...
899 899 ... foo = Foo()
900 900 ... foo.heavy_func() # first computation
901 901 ... foo.heavy_func() # fetch from cache
902 902 ... foo._invalidate_prop_cache('heavy_func')
903 903
904 904 # at this point calling foo.heavy_func() will be re-computed
905 905 """
906 906
907 907 def __init__(self, func, func_name=None):
908 908
909 909 if func_name is None:
910 910 func_name = func.__name__
911 911 self.data = (func, func_name)
912 912 update_wrapper(self, func)
913 913
914 914 def __get__(self, inst, class_):
915 915 if inst is None:
916 916 return self
917 917
918 918 func, func_name = self.data
919 919 value = func(inst)
920 920 inst.__dict__[func_name] = value
921 921 if '_invalidate_prop_cache' not in inst.__dict__:
922 922 inst.__dict__['_invalidate_prop_cache'] = partial(
923 923 self._invalidate_prop_cache, inst)
924 924 return value
925 925
926 926 def _invalidate_prop_cache(self, inst, name):
927 927 inst.__dict__.pop(name, None)
928 928
929 929
930 930 def retry(func=None, exception=Exception, n_tries=5, delay=5, backoff=1, logger=True):
931 931 """
932 932 Retry decorator with exponential backoff.
933 933
934 934 Parameters
935 935 ----------
936 936 func : typing.Callable, optional
937 937 Callable on which the decorator is applied, by default None
938 938 exception : Exception or tuple of Exceptions, optional
939 939 Exception(s) that invoke retry, by default Exception
940 940 n_tries : int, optional
941 941 Number of tries before giving up, by default 5
942 942 delay : int, optional
943 943 Initial delay between retries in seconds, by default 5
944 944 backoff : int, optional
945 945 Backoff multiplier e.g. value of 2 will double the delay, by default 1
946 946 logger : bool, optional
947 947 Option to log (True) or print (False), by default True
948 948
949 949 Returns
950 950 -------
951 951 typing.Callable
952 952 Decorated callable that calls itself when exception(s) occur.
953 953
954 954 Examples
955 955 --------
956 956 >>> import random
957 957 >>> @retry(exception=Exception, n_tries=3)
958 958 ... def test_random(text):
959 959 ... x = random.random()
960 960 ... if x < 0.5:
961 961 ... raise Exception("Fail")
962 962 ... else:
963 963 ... print("Success: ", text)
964 964 >>> test_random("It works!")
965 965 """
966 966
967 967 if func is None:
968 968 return partial(
969 969 retry,
970 970 exception=exception,
971 971 n_tries=n_tries,
972 972 delay=delay,
973 973 backoff=backoff,
974 974 logger=logger,
975 975 )
976 976
977 977 @wraps(func)
978 978 def wrapper(*args, **kwargs):
979 979 _n_tries, n_delay = n_tries, delay
980 980 log = logging.getLogger('rhodecode.retry')
981 981
982 982 while _n_tries > 1:
983 983 try:
984 984 return func(*args, **kwargs)
985 985 except exception as e:
986 986 e_details = repr(e)
987 987 msg = "Exception on calling func {func}: {e}, " \
988 988 "Retrying in {n_delay} seconds..."\
989 989 .format(func=func, e=e_details, n_delay=n_delay)
990 990 if logger:
991 991 log.warning(msg)
992 992 else:
993 993 print(msg)
994 994 time.sleep(n_delay)
995 995 _n_tries -= 1
996 996 n_delay *= backoff
997 997
998 998 return func(*args, **kwargs)
999 999
1000 1000 return wrapper
1001 1001
1002 1002
1003 1003 def user_agent_normalizer(user_agent_raw, safe=True):
1004 1004 log = logging.getLogger('rhodecode.user_agent_normalizer')
1005 1005 ua = (user_agent_raw or '').strip().lower()
1006 1006 ua = ua.replace('"', '')
1007 1007
1008 1008 try:
1009 1009 if 'mercurial/proto-1.0' in ua:
1010 1010 ua = ua.replace('mercurial/proto-1.0', '')
1011 1011 ua = ua.replace('(', '').replace(')', '').strip()
1012 1012 ua = ua.replace('mercurial ', 'mercurial/')
1013 1013 elif ua.startswith('git'):
1014 1014 parts = ua.split(' ')
1015 1015 if parts:
1016 1016 ua = parts[0]
1017 1017 ua = re.sub(r'\.windows\.\d', '', ua).strip()
1018 1018
1019 1019 return ua
1020 1020 except Exception:
1021 1021 log.exception('Failed to parse scm user-agent')
1022 1022 if not safe:
1023 1023 raise
1024 1024
1025 1025 return ua
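# Illustration with hypothetical raw user-agent strings:
#
# >>> user_agent_normalizer('mercurial/proto-1.0 (Mercurial 5.6.1)')
# 'mercurial/5.6.1'
# >>> user_agent_normalizer('git/2.30.0.windows.1')
# 'git/2.30.0'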
1026 1026
1027 1027
1028 1028 def get_available_port(min_port=40000, max_port=55555, use_range=False):
1029 1029 hostname = ''
1030 1030 for _ in range(min_port, max_port):
1031 1031 pick_port = 0
1032 1032 if use_range:
1033 1033 pick_port = random.randint(min_port, max_port)
1034 1034
1035 1035 with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
1036 1036 try:
1037 1037 s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)  # must precede bind()
1038 1038 s.bind((hostname, pick_port))
1039 1039 return s.getsockname()[1]
1040 1040 except OSError:
1041 1041 continue
1042 1042 except socket.error as e:
1043 1043 if e.args[0] in [errno.EADDRINUSE, errno.ECONNREFUSED]:
1044 1044 continue
1045 1045 raise
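# Minimal usage sketch: pick a free port from the range for a throw-away
# listener, e.g. in tests (the exact result depends on what is free on the host):
#
# >>> port = get_available_port(use_range=True)
# >>> 40000 <= port <= 55555
# True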
@@ -1,494 +1,494 b''
1 1 # -*- coding: utf-8 -*-
2 2
3 3 # Copyright (C) 2014-2020 RhodeCode GmbH
4 4 #
5 5 # This program is free software: you can redistribute it and/or modify
6 6 # it under the terms of the GNU Affero General Public License, version 3
7 7 # (only), as published by the Free Software Foundation.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU Affero General Public License
15 15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 16 #
17 17 # This program is dual-licensed. If you wish to learn more about the
18 18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20 20
21 21 """
22 22 GIT commit module
23 23 """
24 24
25 25 import re
26 import io
26 27 import stat
27 28 import configparser
28 29 from itertools import chain
29 from io import StringIO
30 30
31 31 from zope.cachedescriptors.property import Lazy as LazyProperty
32 32
33 33 from rhodecode.lib.datelib import utcdate_fromtimestamp
34 34 from rhodecode.lib.utils import safe_unicode, safe_str
35 35 from rhodecode.lib.utils2 import safe_int
36 36 from rhodecode.lib.vcs.conf import settings
37 37 from rhodecode.lib.vcs.backends import base
38 38 from rhodecode.lib.vcs.exceptions import CommitError, NodeDoesNotExistError
39 39 from rhodecode.lib.vcs.nodes import (
40 40 FileNode, DirNode, NodeKind, RootNode, SubModuleNode,
41 41 ChangedFileNodesGenerator, AddedFileNodesGenerator,
42 42 RemovedFileNodesGenerator, LargeFileNode)
43 43
44 44
45 45 class GitCommit(base.BaseCommit):
46 46 """
47 47 Represents state of the repository at single commit id.
48 48 """
49 49
50 50 _filter_pre_load = [
51 51 # done through a more complex tree walk on parents
52 52 "affected_files",
53 53 # done through subprocess not remote call
54 54 "children",
55 55 # done through a more complex tree walk on parents
56 56 "status",
57 57 # mercurial specific property not supported here
58 58 "_file_paths",
59 59 # mercurial specific property not supported here
60 60 'obsolete',
61 61 # mercurial specific property not supported here
62 62 'phase',
63 63 # mercurial specific property not supported here
64 64 'hidden'
65 65 ]
66 66
67 67 def __init__(self, repository, raw_id, idx, pre_load=None):
68 68 self.repository = repository
69 69 self._remote = repository._remote
70 70 # TODO: johbo: Tweak of raw_id should not be necessary
71 71 self.raw_id = safe_str(raw_id)
72 72 self.idx = idx
73 73
74 74 self._set_bulk_properties(pre_load)
75 75
76 76 # caches
77 77 self._stat_modes = {} # stat info for paths
78 78 self._paths = {} # path processed with parse_tree
79 79 self.nodes = {}
80 80 self._submodules = None
81 81
82 82 def _set_bulk_properties(self, pre_load):
83 83
84 84 if not pre_load:
85 85 return
86 86 pre_load = [entry for entry in pre_load
87 87 if entry not in self._filter_pre_load]
88 88 if not pre_load:
89 89 return
90 90
91 91 result = self._remote.bulk_request(self.raw_id, pre_load)
92 92 for attr, value in result.items():
93 93 if attr in ["author", "message"]:
94 94 if value:
95 95 value = safe_unicode(value)
96 96 elif attr == "date":
97 97 value = utcdate_fromtimestamp(*value)
98 98 elif attr == "parents":
99 99 value = self._make_commits(value)
100 100 elif attr == "branch":
101 101 value = self._set_branch(value)
102 102 self.__dict__[attr] = value
103 103
104 104 @LazyProperty
105 105 def _commit(self):
106 106 return self._remote[self.raw_id]
107 107
108 108 @LazyProperty
109 109 def _tree_id(self):
110 110 return self._remote[self._commit['tree']]['id']
111 111
112 112 @LazyProperty
113 113 def id(self):
114 114 return self.raw_id
115 115
116 116 @LazyProperty
117 117 def short_id(self):
118 118 return self.raw_id[:12]
119 119
120 120 @LazyProperty
121 121 def message(self):
122 122 return safe_unicode(self._remote.message(self.id))
123 123
124 124 @LazyProperty
125 125 def committer(self):
126 126 return safe_unicode(self._remote.author(self.id))
127 127
128 128 @LazyProperty
129 129 def author(self):
130 130 return safe_unicode(self._remote.author(self.id))
131 131
132 132 @LazyProperty
133 133 def date(self):
134 134 unix_ts, tz = self._remote.date(self.raw_id)
135 135 return utcdate_fromtimestamp(unix_ts, tz)
136 136
137 137 @LazyProperty
138 138 def status(self):
139 139 """
140 140 Returns modified, added, removed, deleted files for current commit
141 141 """
142 142 return self.changed, self.added, self.removed
143 143
144 144 @LazyProperty
145 145 def tags(self):
146 146 tags = [safe_unicode(name) for name,
147 147 commit_id in self.repository.tags.items()
148 148 if commit_id == self.raw_id]
149 149 return tags
150 150
151 151 @LazyProperty
152 152 def commit_branches(self):
153 153 branches = []
154 154 for name, commit_id in self.repository.branches.items():
155 155 if commit_id == self.raw_id:
156 156 branches.append(name)
157 157 return branches
158 158
159 159 def _set_branch(self, branches):
160 160 if branches:
161 161 # actually commit can have multiple branches in git
162 162 return safe_unicode(branches[0])
163 163
164 164 @LazyProperty
165 165 def branch(self):
166 166 branches = self._remote.branch(self.raw_id)
167 167 return self._set_branch(branches)
168 168
169 169 def _get_tree_id_for_path(self, path):
170 170 path = safe_str(path)
171 171 if path in self._paths:
172 172 return self._paths[path]
173 173
174 174 tree_id = self._tree_id
175 175
176 176 path = path.strip('/')
177 177 if path == '':
178 178 data = [tree_id, "tree"]
179 179 self._paths[''] = data
180 180 return data
181 181
182 182 tree_id, tree_type, tree_mode = \
183 183 self._remote.tree_and_type_for_path(self.raw_id, path)
184 184 if tree_id is None:
185 185 raise self.no_node_at_path(path)
186 186
187 187 self._paths[path] = [tree_id, tree_type]
188 188 self._stat_modes[path] = tree_mode
189 189
190 190 if path not in self._paths:
191 191 raise self.no_node_at_path(path)
192 192
193 193 return self._paths[path]
194 194
195 195 def _get_kind(self, path):
196 196 tree_id, type_ = self._get_tree_id_for_path(path)
197 197 if type_ == 'blob':
198 198 return NodeKind.FILE
199 199 elif type_ == 'tree':
200 200 return NodeKind.DIR
201 201 elif type_ == 'link':
202 202 return NodeKind.SUBMODULE
203 203 return None
204 204
205 205 def _get_filectx(self, path):
206 206 path = self._fix_path(path)
207 207 if self._get_kind(path) != NodeKind.FILE:
208 208 raise CommitError(
209 209 "File does not exist for commit %s at '%s'" % (self.raw_id, path))
210 210 return path
211 211
212 212 def _get_file_nodes(self):
213 213 return chain(*(t[2] for t in self.walk()))
214 214
215 215 @LazyProperty
216 216 def parents(self):
217 217 """
218 218 Returns list of parent commits.
219 219 """
220 220 parent_ids = self._remote.parents(self.id)
221 221 return self._make_commits(parent_ids)
222 222
223 223 @LazyProperty
224 224 def children(self):
225 225 """
226 226 Returns list of child commits.
227 227 """
228 228
229 229 children = self._remote.children(self.raw_id)
230 230 return self._make_commits(children)
231 231
232 232 def _make_commits(self, commit_ids):
233 233 def commit_maker(_commit_id):
234 234 return self.repository.get_commit(commit_id=commit_id)
235 235
236 236 return [commit_maker(commit_id) for commit_id in commit_ids]
237 237
238 238 def get_file_mode(self, path):
239 239 """
240 240 Returns stat mode of the file at the given `path`.
241 241 """
242 242 path = safe_str(path)
243 243 # ensure path is traversed
244 244 self._get_tree_id_for_path(path)
245 245 return self._stat_modes[path]
246 246
247 247 def is_link(self, path):
248 248 return stat.S_ISLNK(self.get_file_mode(path))
249 249
250 250 def is_node_binary(self, path):
251 251 tree_id, _ = self._get_tree_id_for_path(path)
252 252 return self._remote.is_binary(tree_id)
253 253
254 254 def get_file_content(self, path):
255 255 """
256 256 Returns content of the file at given `path`.
257 257 """
258 258 tree_id, _ = self._get_tree_id_for_path(path)
259 259 return self._remote.blob_as_pretty_string(tree_id)
260 260
261 261 def get_file_content_streamed(self, path):
262 262 tree_id, _ = self._get_tree_id_for_path(path)
263 263 stream_method = getattr(self._remote, 'stream:blob_as_pretty_string')
264 264 return stream_method(tree_id)
265 265
266 266 def get_file_size(self, path):
267 267 """
268 268 Returns size of the file at given `path`.
269 269 """
270 270 tree_id, _ = self._get_tree_id_for_path(path)
271 271 return self._remote.blob_raw_length(tree_id)
272 272
273 273 def get_path_history(self, path, limit=None, pre_load=None):
274 274 """
275 275 Returns history of file as reversed list of `GitCommit` objects for
276 276 which file at given `path` has been modified.
277 277 """
278 278
279 279 path = self._get_filectx(path)
280 280 hist = self._remote.node_history(self.raw_id, path, limit)
281 281 return [
282 282 self.repository.get_commit(commit_id=commit_id, pre_load=pre_load)
283 283 for commit_id in hist]
284 284
285 285 def get_file_annotate(self, path, pre_load=None):
286 286 """
287 287 Returns a generator of four element tuples with
288 288 lineno, commit_id, commit lazy loader and line
289 289 """
290 290
291 291 result = self._remote.node_annotate(self.raw_id, path)
292 292
293 293 for ln_no, commit_id, content in result:
294 294 yield (
295 295 ln_no, commit_id,
296 296 lambda: self.repository.get_commit(commit_id=commit_id, pre_load=pre_load),
297 297 content)
298 298
299 299 def get_nodes(self, path):
300 300
301 301 if self._get_kind(path) != NodeKind.DIR:
302 302 raise CommitError(
303 303 "Directory does not exist for commit %s at '%s'" % (self.raw_id, path))
304 304 path = self._fix_path(path)
305 305
306 306 tree_id, _ = self._get_tree_id_for_path(path)
307 307
308 308 dirnodes = []
309 309 filenodes = []
310 310
311 311 # extracted tree ID gives us our files...
312 312 bytes_path = safe_str(path)  # libgit operates on str paths in py3 (bytes in py2)
313 313 for name, stat_, id_, type_ in self._remote.tree_items(tree_id):
314 314 if type_ == 'link':
315 315 url = self._get_submodule_url('/'.join((bytes_path, name)))
316 316 dirnodes.append(SubModuleNode(
317 317 name, url=url, commit=id_, alias=self.repository.alias))
318 318 continue
319 319
320 320 if bytes_path != '':
321 321 obj_path = '/'.join((bytes_path, name))
322 322 else:
323 323 obj_path = name
324 324 if obj_path not in self._stat_modes:
325 325 self._stat_modes[obj_path] = stat_
326 326
327 327 if type_ == 'tree':
328 328 dirnodes.append(DirNode(obj_path, commit=self))
329 329 elif type_ == 'blob':
330 330 filenodes.append(FileNode(obj_path, commit=self, mode=stat_))
331 331 else:
332 332 raise CommitError(
333 333 "Requested object should be Tree or Blob, is %s", type_)
334 334
335 335 nodes = dirnodes + filenodes
336 336 for node in nodes:
337 337 if node.path not in self.nodes:
338 338 self.nodes[node.path] = node
339 339 nodes.sort()
340 340 return nodes
341 341
342 342 def get_node(self, path, pre_load=None):
343 343 path = self._fix_path(path)
344 344 if path not in self.nodes:
345 345 try:
346 346 tree_id, type_ = self._get_tree_id_for_path(path)
347 347 except CommitError:
348 348 raise NodeDoesNotExistError(
349 349 "Cannot find one of parents' directories for a given "
350 350 "path: %s" % path)
351 351
352 352 if type_ in ['link', 'commit']:
353 353 url = self._get_submodule_url(path)
354 354 node = SubModuleNode(path, url=url, commit=tree_id,
355 355 alias=self.repository.alias)
356 356 elif type_ == 'tree':
357 357 if path == '':
358 358 node = RootNode(commit=self)
359 359 else:
360 360 node = DirNode(path, commit=self)
361 361 elif type_ == 'blob':
362 362 node = FileNode(path, commit=self, pre_load=pre_load)
363 363 self._stat_modes[path] = node.mode
364 364 else:
365 365 raise self.no_node_at_path(path)
366 366
367 367 # cache node
368 368 self.nodes[path] = node
369 369
370 370 return self.nodes[path]
371 371
372 372 def get_largefile_node(self, path):
373 373 tree_id, _ = self._get_tree_id_for_path(path)
374 374 pointer_spec = self._remote.is_large_file(tree_id)
375 375
376 376 if pointer_spec:
377 377 # the content of that regular FileNode is the pointer (hash) of the largefile
378 378 file_id = pointer_spec.get('oid_hash')
379 379 if self._remote.in_largefiles_store(file_id):
380 380 lf_path = self._remote.store_path(file_id)
381 381 return LargeFileNode(lf_path, commit=self, org_path=path)
382 382
383 383 @LazyProperty
384 384 def affected_files(self):
385 385 """
386 386 Gets a fast accessible file changes for given commit
387 387 """
388 388 added, modified, deleted = self._changes_cache
389 389 return list(added.union(modified).union(deleted))
390 390
391 391 @LazyProperty
392 392 def _changes_cache(self):
393 393 added = set()
394 394 modified = set()
395 395 deleted = set()
396 396 _r = self._remote
397 397
398 398 parents = self.parents
399 399 if not self.parents:
400 400 parents = [base.EmptyCommit()]
401 401 for parent in parents:
402 402 if isinstance(parent, base.EmptyCommit):
403 403 oid = None
404 404 else:
405 405 oid = parent.raw_id
406 406 changes = _r.tree_changes(oid, self.raw_id)
407 407 for (oldpath, newpath), (_, _), (_, _) in changes:
408 408 if newpath and oldpath:
409 409 modified.add(newpath)
410 410 elif newpath and not oldpath:
411 411 added.add(newpath)
412 412 elif not newpath and oldpath:
413 413 deleted.add(oldpath)
414 414 return added, modified, deleted
415 415
416 416 def _get_paths_for_status(self, status):
417 417 """
418 418 Returns sorted list of paths for given ``status``.
419 419
420 420 :param status: one of: *added*, *modified* or *deleted*
421 421 """
422 422 added, modified, deleted = self._changes_cache
423 423 return sorted({
424 424 'added': list(added),
425 425 'modified': list(modified),
426 426 'deleted': list(deleted)}[status]
427 427 )
428 428
429 429 @LazyProperty
430 430 def added(self):
431 431 """
432 432 Returns list of added ``FileNode`` objects.
433 433 """
434 434 if not self.parents:
435 435 return list(self._get_file_nodes())
436 436 return AddedFileNodesGenerator(self.added_paths, self)
437 437
438 438 @LazyProperty
439 439 def added_paths(self):
440 440 return [n for n in self._get_paths_for_status('added')]
441 441
442 442 @LazyProperty
443 443 def changed(self):
444 444 """
445 445 Returns list of modified ``FileNode`` objects.
446 446 """
447 447 if not self.parents:
448 448 return []
449 449 return ChangedFileNodesGenerator(self.changed_paths, self)
450 450
451 451 @LazyProperty
452 452 def changed_paths(self):
453 453 return [n for n in self._get_paths_for_status('modified')]
454 454
455 455 @LazyProperty
456 456 def removed(self):
457 457 """
458 458 Returns list of removed ``FileNode`` objects.
459 459 """
460 460 if not self.parents:
461 461 return []
462 462 return RemovedFileNodesGenerator(self.removed_paths, self)
463 463
464 464 @LazyProperty
465 465 def removed_paths(self):
466 466 return [n for n in self._get_paths_for_status('deleted')]
467 467
468 468 def _get_submodule_url(self, submodule_path):
469 469 git_modules_path = '.gitmodules'
470 470
471 471 if self._submodules is None:
472 472 self._submodules = {}
473 473
474 474 try:
475 475 submodules_node = self.get_node(git_modules_path)
476 476 except NodeDoesNotExistError:
477 477 return None
478 478
479 479 # ConfigParser fails if there is leading whitespace, and it needs an
480 480 # iterable of lines (file-like content)
481 481 def iter_content(_content):
482 482 for line in _content.splitlines():
483 483 yield line
484 484
485 485 parser = configparser.RawConfigParser()
486 486 parser.read_file(iter_content(submodules_node.content))
487 487
488 488 for section in parser.sections():
489 489 path = parser.get(section, 'path')
490 490 url = parser.get(section, 'url')
491 491 if path and url:
492 492 self._submodules[path.strip('/')] = url
493 493
494 494 return self._submodules.get(submodule_path.strip('/'))
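# Illustration of the parsing above with a hypothetical .gitmodules payload
# (kept unindented, since indented lines would trip RawConfigParser):
#
# >>> import configparser
# >>> content = '[submodule "vendor/lib"]\npath = vendor/lib\nurl = https://example.com/lib.git'
# >>> parser = configparser.RawConfigParser()
# >>> parser.read_file(iter(content.splitlines()))
# >>> parser.get('submodule "vendor/lib"', 'path')
# 'vendor/lib'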
@@ -1,401 +1,402 b''
1 1 # -*- coding: utf-8 -*-
2 2
3 3 # Copyright (C) 2014-2020 RhodeCode GmbH
4 4 #
5 5 # This program is free software: you can redistribute it and/or modify
6 6 # it under the terms of the GNU Affero General Public License, version 3
7 7 # (only), as published by the Free Software Foundation.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU Affero General Public License
15 15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 16 #
17 17 # This program is dual-licensed. If you wish to learn more about the
18 18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20 20
21 21 """
22 22 HG commit module
23 23 """
24 24
25 25 import os
26 26
27 27 from zope.cachedescriptors.property import Lazy as LazyProperty
28 28
29 29 from rhodecode.lib.datelib import utcdate_fromtimestamp
30 30 from rhodecode.lib.utils import safe_str, safe_unicode
31 31 from rhodecode.lib.vcs import path as vcspath
32 32 from rhodecode.lib.vcs.backends import base
33 33 from rhodecode.lib.vcs.backends.hg.diff import MercurialDiff
34 34 from rhodecode.lib.vcs.exceptions import CommitError
35 35 from rhodecode.lib.vcs.nodes import (
36 36 AddedFileNodesGenerator, ChangedFileNodesGenerator, DirNode, FileNode,
37 37 NodeKind, RemovedFileNodesGenerator, RootNode, SubModuleNode,
38 38 LargeFileNode, LARGEFILE_PREFIX)
39 39 from rhodecode.lib.vcs.utils.paths import get_dirs_for_path
40 40
41 41
42 42 class MercurialCommit(base.BaseCommit):
43 43 """
44 44 Represents state of the repository at the single commit.
45 45 """
46 46
47 47 _filter_pre_load = [
48 48 # git specific property not supported here
49 49 "_commit",
50 50 ]
51 51
52 52 def __init__(self, repository, raw_id, idx, pre_load=None):
53 53 raw_id = safe_str(raw_id)
54 54
55 55 self.repository = repository
56 56 self._remote = repository._remote
57 57
58 58 self.raw_id = raw_id
59 59 self.idx = idx
60 60
61 61 self._set_bulk_properties(pre_load)
62 62
63 63 # caches
64 64 self.nodes = {}
65 65
66 66 def _set_bulk_properties(self, pre_load):
67 67 if not pre_load:
68 68 return
69 69 pre_load = [entry for entry in pre_load
70 70 if entry not in self._filter_pre_load]
71 71 if not pre_load:
72 72 return
73 73
74 74 result = self._remote.bulk_request(self.raw_id, pre_load)
75
75 76 for attr, value in result.items():
76 77 if attr in ["author", "branch", "message"]:
77 78 value = safe_unicode(value)
78 79 elif attr == "affected_files":
79 80 value = map(safe_unicode, value)
80 81 elif attr == "date":
81 82 value = utcdate_fromtimestamp(*value)
82 83 elif attr in ["children", "parents"]:
83 84 value = self._make_commits(value)
84 85 elif attr in ["phase"]:
85 86 value = self._get_phase_text(value)
86 87 self.__dict__[attr] = value
87 88
88 89 @LazyProperty
89 90 def tags(self):
90 91 tags = [name for name, commit_id in self.repository.tags.items()
91 92 if commit_id == self.raw_id]
92 93 return tags
93 94
94 95 @LazyProperty
95 96 def branch(self):
96 97 return safe_unicode(self._remote.ctx_branch(self.raw_id))
97 98
98 99 @LazyProperty
99 100 def bookmarks(self):
100 101 bookmarks = [
101 102 name for name, commit_id in self.repository.bookmarks.items()
102 103 if commit_id == self.raw_id]
103 104 return bookmarks
104 105
105 106 @LazyProperty
106 107 def message(self):
107 108 return safe_unicode(self._remote.ctx_description(self.raw_id))
108 109
109 110 @LazyProperty
110 111 def committer(self):
111 112 return safe_unicode(self.author)
112 113
113 114 @LazyProperty
114 115 def author(self):
115 116 return safe_unicode(self._remote.ctx_user(self.raw_id))
116 117
117 118 @LazyProperty
118 119 def date(self):
119 120 return utcdate_fromtimestamp(*self._remote.ctx_date(self.raw_id))
120 121
121 122 @LazyProperty
122 123 def status(self):
123 124 """
124 125 Returns modified, added, removed, deleted files for current commit
125 126 """
126 127 return self._remote.ctx_status(self.raw_id)
127 128
128 129 @LazyProperty
129 130 def _file_paths(self):
130 131 return self._remote.ctx_list(self.raw_id)
131 132
132 133 @LazyProperty
133 134 def _dir_paths(self):
134 135 p = list(set(get_dirs_for_path(*self._file_paths)))
135 136 p.insert(0, '')
136 137 return p
137 138
138 139 @LazyProperty
139 140 def _paths(self):
140 141 return self._dir_paths + self._file_paths
141 142
142 143 @LazyProperty
143 144 def id(self):
144 145 if self.last:
145 146 return u'tip'
146 147 return self.short_id
147 148
148 149 @LazyProperty
149 150 def short_id(self):
150 151 return self.raw_id[:12]
151 152
152 153 def _make_commits(self, commit_ids, pre_load=None):
153 154 return [self.repository.get_commit(commit_id=commit_id, pre_load=pre_load)
154 155 for commit_id in commit_ids]
155 156
156 157 @LazyProperty
157 158 def parents(self):
158 159 """
159 160 Returns list of parent commits.
160 161 """
161 162 parents = self._remote.ctx_parents(self.raw_id)
162 163 return self._make_commits(parents)
163 164
164 165 def _get_phase_text(self, phase_id):
165 166 return {
166 167 0: 'public',
167 168 1: 'draft',
168 169 2: 'secret',
169 170 }.get(phase_id) or ''
170 171
171 172 @LazyProperty
172 173 def phase(self):
173 174 phase_id = self._remote.ctx_phase(self.raw_id)
174 175 phase_text = self._get_phase_text(phase_id)
175 176
176 177 return safe_unicode(phase_text)
177 178
178 179 @LazyProperty
179 180 def obsolete(self):
180 181 obsolete = self._remote.ctx_obsolete(self.raw_id)
181 182 return obsolete
182 183
183 184 @LazyProperty
184 185 def hidden(self):
185 186 hidden = self._remote.ctx_hidden(self.raw_id)
186 187 return hidden
187 188
188 189 @LazyProperty
189 190 def children(self):
190 191 """
191 192 Returns list of child commits.
192 193 """
193 194 children = self._remote.ctx_children(self.raw_id)
194 195 return self._make_commits(children)
195 196
196 197 def _fix_path(self, path):
197 198 """
198 199 Mercurial keeps filenodes as str so we need to encode from unicode
199 200 to str.
200 201 """
201 202 return safe_str(super(MercurialCommit, self)._fix_path(path))
202 203
203 204 def _get_kind(self, path):
204 205 path = self._fix_path(path)
205 206 if path in self._file_paths:
206 207 return NodeKind.FILE
207 208 elif path in self._dir_paths:
208 209 return NodeKind.DIR
209 210 else:
210 211 raise CommitError(
211 212 "Node does not exist at the given path '%s'" % (path, ))
212 213
213 214 def _get_filectx(self, path):
214 215 path = self._fix_path(path)
215 216 if self._get_kind(path) != NodeKind.FILE:
216 217 raise CommitError(
217 218 "File does not exist for commit %s at '%s'" % (self.raw_id, path))
218 219 return path
219 220
220 221 def get_file_mode(self, path):
221 222 """
222 223 Returns stat mode of the file at the given ``path``.
223 224 """
224 225 path = self._get_filectx(path)
225 226 if 'x' in self._remote.fctx_flags(self.raw_id, path):
226 227 return base.FILEMODE_EXECUTABLE
227 228 else:
228 229 return base.FILEMODE_DEFAULT
229 230
230 231 def is_link(self, path):
231 232 path = self._get_filectx(path)
232 233 return 'l' in self._remote.fctx_flags(self.raw_id, path)
233 234
234 235 def is_node_binary(self, path):
235 236 path = self._get_filectx(path)
236 237 return self._remote.is_binary(self.raw_id, path)
237 238
238 239 def get_file_content(self, path):
239 240 """
240 241 Returns content of the file at given ``path``.
241 242 """
242 243 path = self._get_filectx(path)
243 244 return self._remote.fctx_node_data(self.raw_id, path)
244 245
245 246 def get_file_content_streamed(self, path):
246 247 path = self._get_filectx(path)
247 248 stream_method = getattr(self._remote, 'stream:fctx_node_data')
248 249 return stream_method(self.raw_id, path)
249 250
250 251 def get_file_size(self, path):
251 252 """
252 253 Returns size of the file at given ``path``.
253 254 """
254 255 path = self._get_filectx(path)
255 256 return self._remote.fctx_size(self.raw_id, path)
256 257
257 258 def get_path_history(self, path, limit=None, pre_load=None):
258 259 """
259 260 Returns the history of the file as a reversed list of `MercurialCommit`
260 261 objects, for which the file at the given ``path`` has been modified.
261 262 """
262 263 path = self._get_filectx(path)
263 264 hist = self._remote.node_history(self.raw_id, path, limit)
264 265 return [
265 266 self.repository.get_commit(commit_id=commit_id, pre_load=pre_load)
266 267 for commit_id in hist]
267 268
268 269 def get_file_annotate(self, path, pre_load=None):
269 270 """
270 271 Returns a generator of four-element tuples with
271 272 lineno, commit_id, a lazy commit loader, and the line content
272 273 """
273 274 result = self._remote.fctx_annotate(self.raw_id, path)
274 275
275 276 for ln_no, commit_id, content in result:
276 277 yield (
277 278 ln_no, commit_id,
278 279 lambda: self.repository.get_commit(commit_id=commit_id, pre_load=pre_load),
279 280 content)
280 281
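# --- editor's note: illustrative sketch, not part of this changeset ---
# Consuming the annotate generator: the third tuple element is a lazy
# commit loader, so the full commit is only materialised for the lines
# we actually inspect. The path and the bytes-content assumption are
# hypothetical.
def _example_annotate(commit, path='README.rst'):
    for line_no, commit_id, commit_loader, content in commit.get_file_annotate(path):
        if b'TODO' in content:          # assuming raw lines come back as bytes
            blame = commit_loader()     # remote fetch happens only here
            print(line_no, commit_id, blame.author)
# -----------------------------------------------------------------------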
281 282 def get_nodes(self, path):
282 283 """
283 284 Returns a combined list of ``DirNode`` and ``FileNode`` objects for the
284 285 state of the commit at the given ``path``. If the node at the given
285 286 ``path`` is not an instance of ``DirNode``, a ``CommitError`` is raised.
286 287 """
287 288
288 289 if self._get_kind(path) != NodeKind.DIR:
289 290 raise CommitError(
290 291 "Directory does not exist for commit %s at '%s'" % (self.raw_id, path))
291 292 path = self._fix_path(path)
292 293
293 294 filenodes = [
294 295 FileNode(f, commit=self) for f in self._file_paths
295 296 if os.path.dirname(f) == path]
296 297 # the old `x and y or z` ternary here always evaluated to the list
297 298 # comprehension (its middle value was falsy), so use it directly
298 299 dirs = [d for d in self._dir_paths
299 300 if d and vcspath.dirname(d) == path]
300 301 dirnodes = [
301 302 DirNode(d, commit=self) for d in dirs
302 303 if os.path.dirname(d) == path]
303 304
304 305 alias = self.repository.alias
305 306 for k, vals in self._submodules.items():
306 307 if vcspath.dirname(k) == path:
307 308 loc = vals[0]
308 309 commit = vals[1]
309 310 dirnodes.append(SubModuleNode(k, url=loc, commit=commit, alias=alias))
310 311
311 312 nodes = dirnodes + filenodes
312 313 for node in nodes:
313 314 if node.path not in self.nodes:
314 315 self.nodes[node.path] = node
315 316 nodes.sort()
316 317
317 318 return nodes
318 319
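# --- editor's note: illustrative sketch, not part of this changeset ---
# Listing the direct children of a directory: entries come back sorted
# as DirNode/FileNode (plus SubModuleNode) objects and are cached in
# commit.nodes. The 'docs' path is hypothetical.
def _example_list_dir(commit):
    for node in commit.get_nodes('docs'):
        print('dir ' if node.is_dir() else 'file', node.path)
# -----------------------------------------------------------------------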
319 320 def get_node(self, path, pre_load=None):
320 321 """
321 322 Returns a `Node` object for the given `path`. If there is no node at
322 323 the given `path`, a `NodeDoesNotExistError` is raised.
323 324 """
324 325 path = self._fix_path(path)
325 326
326 327 if path not in self.nodes:
327 328 if path in self._file_paths:
328 329 node = FileNode(path, commit=self, pre_load=pre_load)
329 330 elif path in self._dir_paths:
330 331 if path == '':
331 332 node = RootNode(commit=self)
332 333 else:
333 334 node = DirNode(path, commit=self)
334 335 else:
335 336 raise self.no_node_at_path(path)
336 337
337 338 # cache node
338 339 self.nodes[path] = node
339 340 return self.nodes[path]
340 341
341 342 def get_largefile_node(self, path):
342 343 pointer_spec = self._remote.is_large_file(self.raw_id, path)
343 344 if pointer_spec:
344 345 # the content of the regular FileNode for that file is the hash of the largefile
345 346 file_id = self.get_file_content(path).strip()
346 347
347 348 if self._remote.in_largefiles_store(file_id):
348 349 lf_path = self._remote.store_path(file_id)
349 350 return LargeFileNode(lf_path, commit=self, org_path=path)
350 351 elif self._remote.in_user_cache(file_id):
351 352 lf_path = self._remote.store_path(file_id)
352 353 self._remote.link(file_id, path)
353 354 return LargeFileNode(lf_path, commit=self, org_path=path)
354 355
355 356 @LazyProperty
356 357 def _submodules(self):
357 358 """
358 359 Returns a dictionary with submodule information from substate file
359 360 of hg repository.
360 361 """
361 362 return self._remote.ctx_substate(self.raw_id)
362 363
363 364 @LazyProperty
364 365 def affected_files(self):
365 366 """
366 367 Gets fast-accessible file changes for the given commit
367 368 """
368 369 return self._remote.ctx_files(self.raw_id)
369 370
370 371 @property
371 372 def added(self):
372 373 """
373 374 Returns list of added ``FileNode`` objects.
374 375 """
375 376 return AddedFileNodesGenerator(self.added_paths, self)
376 377
377 378 @LazyProperty
378 379 def added_paths(self):
379 380 return list(self.status[1])
380 381
381 382 @property
382 383 def changed(self):
383 384 """
384 385 Returns list of modified ``FileNode`` objects.
385 386 """
386 387 return ChangedFileNodesGenerator(self.changed_paths, self)
387 388
388 389 @LazyProperty
389 390 def changed_paths(self):
390 391 return list(self.status[0])
391 392
392 393 @property
393 394 def removed(self):
394 395 """
395 396 Returns list of removed ``FileNode`` objects.
396 397 """
397 398 return RemovedFileNodesGenerator(self.removed_paths, self)
398 399
399 400 @LazyProperty
400 401 def removed_paths(self):
401 402 return list(self.status[2])
@@ -1,600 +1,600 b''
1 1 # -*- coding: utf-8 -*-
2 2
3 3 # Copyright (C) 2010-2020 RhodeCode GmbH
4 4 #
5 5 # This program is free software: you can redistribute it and/or modify
6 6 # it under the terms of the GNU Affero General Public License, version 3
7 7 # (only), as published by the Free Software Foundation.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU Affero General Public License
15 15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 16 #
17 17 # This program is dual-licensed. If you wish to learn more about the
18 18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20 20
21 21 """
22 22 permissions model for RhodeCode
23 23 """
24 24 import collections
25 25 import logging
26 26 import traceback
27 27
28 28 from sqlalchemy.exc import DatabaseError
29 29
30 30 from rhodecode import events
31 31 from rhodecode.model import BaseModel
32 32 from rhodecode.model.db import (
33 33 User, Permission, UserToPerm, UserRepoToPerm, UserRepoGroupToPerm,
34 34 UserUserGroupToPerm, UserGroup, UserGroupToPerm, UserToRepoBranchPermission)
35 35 from rhodecode.lib.utils2 import str2bool, safe_int
36 36
37 37 log = logging.getLogger(__name__)
38 38
39 39
40 40 class PermissionModel(BaseModel):
41 41 """
42 42 Permissions model for RhodeCode
43 43 """
44 44 FORKING_DISABLED = 'hg.fork.none'
45 45 FORKING_ENABLED = 'hg.fork.repository'
46 46
47 47 cls = Permission
48 48 global_perms = {
49 49 'default_repo_create': None,
50 50 # special case for create repos on write access to group
51 51 'default_repo_create_on_write': None,
52 52 'default_repo_group_create': None,
53 53 'default_user_group_create': None,
54 54 'default_fork_create': None,
55 55 'default_inherit_default_permissions': None,
56 56 'default_register': None,
57 57 'default_password_reset': None,
58 58 'default_extern_activate': None,
59 59
60 60 # object permissions below
61 61 'default_repo_perm': None,
62 62 'default_group_perm': None,
63 63 'default_user_group_perm': None,
64 64
65 65 # branch
66 66 'default_branch_perm': None,
67 67 }
68 68
69 69 def set_global_permission_choices(self, c_obj, gettext_translator):
70 70 _ = gettext_translator
71 71
72 72 c_obj.repo_perms_choices = [
73 73 ('repository.none', _('None'),),
74 74 ('repository.read', _('Read'),),
75 75 ('repository.write', _('Write'),),
76 76 ('repository.admin', _('Admin'),)]
77 77
78 78 c_obj.group_perms_choices = [
79 79 ('group.none', _('None'),),
80 80 ('group.read', _('Read'),),
81 81 ('group.write', _('Write'),),
82 82 ('group.admin', _('Admin'),)]
83 83
84 84 c_obj.user_group_perms_choices = [
85 85 ('usergroup.none', _('None'),),
86 86 ('usergroup.read', _('Read'),),
87 87 ('usergroup.write', _('Write'),),
88 88 ('usergroup.admin', _('Admin'),)]
89 89
90 90 c_obj.branch_perms_choices = [
91 91 ('branch.none', _('Protected/No Access'),),
92 92 ('branch.merge', _('Web merge'),),
93 93 ('branch.push', _('Push'),),
94 94 ('branch.push_force', _('Force Push'),)]
95 95
96 96 c_obj.register_choices = [
97 97 ('hg.register.none', _('Disabled')),
98 98 ('hg.register.manual_activate', _('Allowed with manual account activation')),
99 ('hg.register.auto_activate', _('Allowed with automatic account activation')),]
99 ('hg.register.auto_activate', _('Allowed with automatic account activation'))]
100 100
101 101 c_obj.password_reset_choices = [
102 102 ('hg.password_reset.enabled', _('Allow password recovery')),
103 103 ('hg.password_reset.hidden', _('Hide password recovery link')),
104 ('hg.password_reset.disabled', _('Disable password recovery')),]
104 ('hg.password_reset.disabled', _('Disable password recovery'))]
105 105
106 106 c_obj.extern_activate_choices = [
107 107 ('hg.extern_activate.manual', _('Manual activation of external account')),
108 ('hg.extern_activate.auto', _('Automatic activation of external account')),]
108 ('hg.extern_activate.auto', _('Automatic activation of external account'))]
109 109
110 110 c_obj.repo_create_choices = [
111 111 ('hg.create.none', _('Disabled')),
112 112 ('hg.create.repository', _('Enabled'))]
113 113
114 114 c_obj.repo_create_on_write_choices = [
115 115 ('hg.create.write_on_repogroup.false', _('Disabled')),
116 116 ('hg.create.write_on_repogroup.true', _('Enabled'))]
117 117
118 118 c_obj.user_group_create_choices = [
119 119 ('hg.usergroup.create.false', _('Disabled')),
120 120 ('hg.usergroup.create.true', _('Enabled'))]
121 121
122 122 c_obj.repo_group_create_choices = [
123 123 ('hg.repogroup.create.false', _('Disabled')),
124 124 ('hg.repogroup.create.true', _('Enabled'))]
125 125
126 126 c_obj.fork_choices = [
127 127 (self.FORKING_DISABLED, _('Disabled')),
128 128 (self.FORKING_ENABLED, _('Enabled'))]
129 129
130 130 c_obj.inherit_default_permission_choices = [
131 131 ('hg.inherit_default_perms.false', _('Disabled')),
132 132 ('hg.inherit_default_perms.true', _('Enabled'))]
133 133
134 134 def get_default_perms(self, object_perms, suffix):
135 135 defaults = {}
136 136 for perm in object_perms:
137 137 # perms
138 138 if perm.permission.permission_name.startswith('repository.'):
139 139 defaults['default_repo_perm' + suffix] = perm.permission.permission_name
140 140
141 141 if perm.permission.permission_name.startswith('group.'):
142 142 defaults['default_group_perm' + suffix] = perm.permission.permission_name
143 143
144 144 if perm.permission.permission_name.startswith('usergroup.'):
145 145 defaults['default_user_group_perm' + suffix] = perm.permission.permission_name
146 146
147 147 # branch
148 148 if perm.permission.permission_name.startswith('branch.'):
149 149 defaults['default_branch_perm' + suffix] = perm.permission.permission_name
150 150
151 151 # creation of objects
152 152 if perm.permission.permission_name.startswith('hg.create.write_on_repogroup'):
153 153 defaults['default_repo_create_on_write' + suffix] = perm.permission.permission_name
154 154
155 155 elif perm.permission.permission_name.startswith('hg.create.'):
156 156 defaults['default_repo_create' + suffix] = perm.permission.permission_name
157 157
158 158 if perm.permission.permission_name.startswith('hg.fork.'):
159 159 defaults['default_fork_create' + suffix] = perm.permission.permission_name
160 160
161 161 if perm.permission.permission_name.startswith('hg.inherit_default_perms.'):
162 162 defaults['default_inherit_default_permissions' + suffix] = perm.permission.permission_name
163 163
164 164 if perm.permission.permission_name.startswith('hg.repogroup.'):
165 165 defaults['default_repo_group_create' + suffix] = perm.permission.permission_name
166 166
167 167 if perm.permission.permission_name.startswith('hg.usergroup.'):
168 168 defaults['default_user_group_create' + suffix] = perm.permission.permission_name
169 169
170 170 # registration and external account activation
171 171 if perm.permission.permission_name.startswith('hg.register.'):
172 172 defaults['default_register' + suffix] = perm.permission.permission_name
173 173
174 174 if perm.permission.permission_name.startswith('hg.password_reset.'):
175 175 defaults['default_password_reset' + suffix] = perm.permission.permission_name
176 176
177 177 if perm.permission.permission_name.startswith('hg.extern_activate.'):
178 178 defaults['default_extern_activate' + suffix] = perm.permission.permission_name
179 179
180 180 return defaults
181 181
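# --- editor's note: illustrative sketch, not part of this changeset ---
# `get_default_perms` folds raw permission rows into the `global_perms`
# key space; `suffix` lets callers namespace the keys. `obj_perms` is
# any iterable of rows exposing .permission.permission_name, e.g. a
# UserToPerm query result; the returned value below is hypothetical.
def _example_default_perms(model, obj_perms):
    defaults = model.get_default_perms(obj_perms, suffix='')
    return defaults.get('default_repo_perm')  # e.g. 'repository.read'
# -----------------------------------------------------------------------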
182 182 def _make_new_user_perm(self, user, perm_name):
183 183 log.debug('Creating new user permission:%s', perm_name)
184 184 new = UserToPerm()
185 185 new.user = user
186 186 new.permission = Permission.get_by_key(perm_name)
187 187 return new
188 188
189 189 def _make_new_user_group_perm(self, user_group, perm_name):
190 190 log.debug('Creating new user group permission:%s', perm_name)
191 191 new = UserGroupToPerm()
192 192 new.users_group = user_group
193 193 new.permission = Permission.get_by_key(perm_name)
194 194 return new
195 195
196 196 def _keep_perm(self, perm_name, keep_fields):
197 197 def get_pat(field_name):
198 198 return {
199 199 # global perms
200 200 'default_repo_create': 'hg.create.',
201 201 # special case for create repos on write access to group
202 202 'default_repo_create_on_write': 'hg.create.write_on_repogroup.',
203 203 'default_repo_group_create': 'hg.repogroup.create.',
204 204 'default_user_group_create': 'hg.usergroup.create.',
205 205 'default_fork_create': 'hg.fork.',
206 206 'default_inherit_default_permissions': 'hg.inherit_default_perms.',
207 207
208 208 # application perms
209 209 'default_register': 'hg.register.',
210 210 'default_password_reset': 'hg.password_reset.',
211 211 'default_extern_activate': 'hg.extern_activate.',
212 212
213 213 # object permissions below
214 214 'default_repo_perm': 'repository.',
215 215 'default_group_perm': 'group.',
216 216 'default_user_group_perm': 'usergroup.',
217 217 # branch
218 218 'default_branch_perm': 'branch.',
219 219
220 220 }[field_name]
221 221 for field in keep_fields:
222 222 pat = get_pat(field)
223 223 if perm_name.startswith(pat):
224 224 return True
225 225 return False
226 226
227 227 def _clear_object_perm(self, object_perms, preserve=None):
228 228 preserve = preserve or []
229 229 _deleted = []
230 230 for perm in object_perms:
231 231 perm_name = perm.permission.permission_name
232 232 if not self._keep_perm(perm_name, keep_fields=preserve):
233 233 _deleted.append(perm_name)
234 234 self.sa.delete(perm)
235 235 return _deleted
236 236
237 237 def _clear_user_perms(self, user_id, preserve=None):
238 238 perms = self.sa.query(UserToPerm)\
239 239 .filter(UserToPerm.user_id == user_id)\
240 240 .all()
241 241 return self._clear_object_perm(perms, preserve=preserve)
242 242
243 243 def _clear_user_group_perms(self, user_group_id, preserve=None):
244 244 perms = self.sa.query(UserGroupToPerm)\
245 245 .filter(UserGroupToPerm.users_group_id == user_group_id)\
246 246 .all()
247 247 return self._clear_object_perm(perms, preserve=preserve)
248 248
249 def _set_new_object_perms(self, obj_type, object, form_result, preserve=None):
249 def _set_new_object_perms(self, obj_type, to_object, form_result, preserve=None):
250 250 # clear current entries, to make this function idempotent
251 251 # it will fix even if we define more permissions or permissions
252 252 # are somehow missing
253 253 preserve = preserve or []
254 254 _global_perms = self.global_perms.copy()
255 255 if obj_type not in ['user', 'user_group']:
256 256 raise ValueError("obj_type must be one of 'user' or 'user_group'")
257 257 global_perms = len(_global_perms)
258 258 default_user_perms = len(Permission.DEFAULT_USER_PERMISSIONS)
259 259 if global_perms != default_user_perms:
260 260 raise Exception(
261 261 'Inconsistent permissions definition. Got {} vs {}'.format(
262 262 global_perms, default_user_perms))
263 263
264 264 if obj_type == 'user':
265 self._clear_user_perms(object.user_id, preserve)
265 self._clear_user_perms(to_object.user_id, preserve)
266 266 if obj_type == 'user_group':
267 self._clear_user_group_perms(object.users_group_id, preserve)
267 self._clear_user_group_perms(to_object.users_group_id, preserve)
268 268
269 269 # now kill the keys that we want to preserve from the form.
270 270 for key in preserve:
271 271 del _global_perms[key]
272 272
273 273 for k in _global_perms.copy():
274 274 _global_perms[k] = form_result[k]
275 275
276 276 # at that stage we validate all are passed inside form_result
277 277 for _perm_key, perm_value in _global_perms.items():
278 278 if perm_value is None:
279 279 raise ValueError('Missing permission for %s' % (_perm_key,))
280 280
281 281 if obj_type == 'user':
282 282 p = self._make_new_user_perm(to_object, perm_value)
283 283 self.sa.add(p)
284 284 if obj_type == 'user_group':
285 285 p = self._make_new_user_group_perm(to_object, perm_value)
286 286 self.sa.add(p)
287 287
288 288 def _set_new_user_perms(self, user, form_result, preserve=None):
289 289 return self._set_new_object_perms(
290 290 'user', user, form_result, preserve)
291 291
292 292 def _set_new_user_group_perms(self, user_group, form_result, preserve=None):
293 293 return self._set_new_object_perms(
294 294 'user_group', user_group, form_result, preserve)
295 295
296 296 def set_new_user_perms(self, user, form_result):
297 297 # calculate what to preserve from what is given in form_result
298 298 preserve = set(self.global_perms.keys()).difference(set(form_result.keys()))
299 299 return self._set_new_user_perms(user, form_result, preserve)
300 300
301 301 def set_new_user_group_perms(self, user_group, form_result):
302 302 # calculate what to preserve from what is given in form_result
303 303 preserve = set(self.global_perms.keys()).difference(set(form_result.keys()))
304 304 return self._set_new_user_group_perms(user_group, form_result, preserve)
305 305
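# --- editor's note: illustrative sketch, not part of this changeset ---
# Partial forms are safe: any `global_perms` key absent from form_result
# is computed into `preserve` and left untouched, so a form submitting
# only registration settings does not reset repo/group defaults. The
# permission values are examples from the choices defined above.
def _example_partial_form(model, default_user):
    form_result = {
        'default_register': 'hg.register.manual_activate',
        'default_password_reset': 'hg.password_reset.enabled',
        'default_extern_activate': 'hg.extern_activate.auto',
    }
    model.set_new_user_perms(default_user, form_result)
    model.sa.commit()  # callers commit the session themselves
# -----------------------------------------------------------------------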
306 306 def create_permissions(self):
307 307 """
308 308 Create permissions for whole system
309 309 """
310 310 for p in Permission.PERMS:
311 311 if not Permission.get_by_key(p[0]):
312 312 new_perm = Permission()
313 313 new_perm.permission_name = p[0]
314 314 new_perm.permission_longname = p[0] # translation err with p[1]
315 315 self.sa.add(new_perm)
316 316
317 317 def _create_default_object_permission(self, obj_type, obj, obj_perms,
318 318 force=False):
319 319 if obj_type not in ['user', 'user_group']:
320 320 raise ValueError("obj_type must be one of 'user' or 'user_group'")
321 321
322 322 def _get_group(perm_name):
323 323 return '.'.join(perm_name.split('.')[:1])
324 324
325 325 defined_perms_groups = list(map(  # py3: map() is a one-shot iterator
326 326 _get_group, (x.permission.permission_name for x in obj_perms)))
327 327 log.debug('GOT ALREADY DEFINED:%s', obj_perms)
328 328
329 329 if force:
330 330 self._clear_object_perm(obj_perms)
331 331 self.sa.commit()
332 332 defined_perms_groups = []
333 333 # for every default permission that needs to be created, we check if
334 334 # its group is already defined; if it's not, we create the default perm
335 335 for perm_name in Permission.DEFAULT_USER_PERMISSIONS:
336 336 gr = _get_group(perm_name)
337 337 if gr not in defined_perms_groups:
338 338 log.debug('GR:%s not found, creating permission %s',
339 339 gr, perm_name)
340 340 if obj_type == 'user':
341 341 new_perm = self._make_new_user_perm(obj, perm_name)
342 342 self.sa.add(new_perm)
343 343 if obj_type == 'user_group':
344 344 new_perm = self._make_new_user_group_perm(obj, perm_name)
345 345 self.sa.add(new_perm)
346 346
347 347 def create_default_user_permissions(self, user, force=False):
348 348 """
349 349 Creates only missing default permissions for user, if force is set it
350 350 resets the default permissions for that user
351 351
352 352 :param user:
353 353 :param force:
354 354 """
355 355 user = self._get_user(user)
356 356 obj_perms = UserToPerm.query().filter(UserToPerm.user == user).all()
357 357 return self._create_default_object_permission(
358 358 'user', user, obj_perms, force)
359 359
360 360 def create_default_user_group_permissions(self, user_group, force=False):
361 361 """
362 362 Creates only missing default permissions for user group, if force is
363 363 set it resets the default permissions for that user group
364 364
365 365 :param user_group:
366 366 :param force:
367 367 """
368 368 user_group = self._get_user_group(user_group)
369 369 obj_perms = UserGroupToPerm.query().filter(UserGroupToPerm.users_group == user_group).all()
370 370 return self._create_default_object_permission(
371 371 'user_group', user_group, obj_perms, force)
372 372
373 373 def update_application_permissions(self, form_result):
374 374 if 'perm_user_id' in form_result:
375 375 perm_user = User.get(safe_int(form_result['perm_user_id']))
376 376 else:
377 377 # used mostly to do lookup for default user
378 378 perm_user = User.get_by_username(form_result['perm_user_name'])
379 379
380 380 try:
381 381 # stage 1 set anonymous access
382 382 if perm_user.username == User.DEFAULT_USER:
383 383 perm_user.active = str2bool(form_result['anonymous'])
384 384 self.sa.add(perm_user)
385 385
386 386 # stage 2 reset defaults and set them from form data
387 387 self._set_new_user_perms(perm_user, form_result, preserve=[
388 388 'default_repo_perm',
389 389 'default_group_perm',
390 390 'default_user_group_perm',
391 391 'default_branch_perm',
392 392
393 393 'default_repo_group_create',
394 394 'default_user_group_create',
395 395 'default_repo_create_on_write',
396 396 'default_repo_create',
397 397 'default_fork_create',
398 'default_inherit_default_permissions',])
398 'default_inherit_default_permissions'])
399 399
400 400 self.sa.commit()
401 401 except (DatabaseError,):
402 402 log.error(traceback.format_exc())
403 403 self.sa.rollback()
404 404 raise
405 405
406 406 def update_user_permissions(self, form_result):
407 407 if 'perm_user_id' in form_result:
408 408 perm_user = User.get(safe_int(form_result['perm_user_id']))
409 409 else:
410 410 # used mostly to do lookup for default user
411 411 perm_user = User.get_by_username(form_result['perm_user_name'])
412 412 try:
413 413 # stage 2 reset defaults and set them from form data
414 414 self._set_new_user_perms(perm_user, form_result, preserve=[
415 415 'default_repo_perm',
416 416 'default_group_perm',
417 417 'default_user_group_perm',
418 418 'default_branch_perm',
419 419
420 420 'default_register',
421 421 'default_password_reset',
422 422 'default_extern_activate'])
423 423 self.sa.commit()
424 424 except (DatabaseError,):
425 425 log.error(traceback.format_exc())
426 426 self.sa.rollback()
427 427 raise
428 428
429 429 def update_user_group_permissions(self, form_result):
430 430 if 'perm_user_group_id' in form_result:
431 431 perm_user_group = UserGroup.get(safe_int(form_result['perm_user_group_id']))
432 432 else:
433 433 # used mostly to do lookup for default user
434 434 perm_user_group = UserGroup.get_by_group_name(form_result['perm_user_group_name'])
435 435 try:
436 436 # stage 2 reset defaults and set them from form data
437 437 self._set_new_user_group_perms(perm_user_group, form_result, preserve=[
438 438 'default_repo_perm',
439 439 'default_group_perm',
440 440 'default_user_group_perm',
441 441 'default_branch_perm',
442 442
443 443 'default_register',
444 444 'default_password_reset',
445 445 'default_extern_activate'])
446 446 self.sa.commit()
447 447 except (DatabaseError,):
448 448 log.error(traceback.format_exc())
449 449 self.sa.rollback()
450 450 raise
451 451
452 452 def update_object_permissions(self, form_result):
453 453 if 'perm_user_id' in form_result:
454 454 perm_user = User.get(safe_int(form_result['perm_user_id']))
455 455 else:
456 456 # used mostly to do lookup for default user
457 457 perm_user = User.get_by_username(form_result['perm_user_name'])
458 458 try:
459 459
460 460 # stage 2 reset defaults and set them from form data
461 461 self._set_new_user_perms(perm_user, form_result, preserve=[
462 462 'default_repo_group_create',
463 463 'default_user_group_create',
464 464 'default_repo_create_on_write',
465 465 'default_repo_create',
466 466 'default_fork_create',
467 467 'default_inherit_default_permissions',
468 468 'default_branch_perm',
469 469
470 470 'default_register',
471 471 'default_password_reset',
472 472 'default_extern_activate'])
473 473
474 474 # overwrite default repo permissions
475 475 if form_result['overwrite_default_repo']:
476 476 _def_name = form_result['default_repo_perm'].split('repository.')[-1]
477 477 _def = Permission.get_by_key('repository.' + _def_name)
478 478 for r2p in self.sa.query(UserRepoToPerm)\
479 479 .filter(UserRepoToPerm.user == perm_user)\
480 480 .all():
481 481 # don't reset PRIVATE repositories
482 482 if not r2p.repository.private:
483 483 r2p.permission = _def
484 484 self.sa.add(r2p)
485 485
486 486 # overwrite default repo group permissions
487 487 if form_result['overwrite_default_group']:
488 488 _def_name = form_result['default_group_perm'].split('group.')[-1]
489 489 _def = Permission.get_by_key('group.' + _def_name)
490 490 for g2p in self.sa.query(UserRepoGroupToPerm)\
491 491 .filter(UserRepoGroupToPerm.user == perm_user)\
492 492 .all():
493 493 g2p.permission = _def
494 494 self.sa.add(g2p)
495 495
496 496 # overwrite default user group permissions
497 497 if form_result['overwrite_default_user_group']:
498 498 _def_name = form_result['default_user_group_perm'].split('usergroup.')[-1]
499 499 # user groups
500 500 _def = Permission.get_by_key('usergroup.' + _def_name)
501 501 for g2p in self.sa.query(UserUserGroupToPerm)\
502 502 .filter(UserUserGroupToPerm.user == perm_user)\
503 503 .all():
504 504 g2p.permission = _def
505 505 self.sa.add(g2p)
506 506
507 507 # COMMIT
508 508 self.sa.commit()
509 509 except (DatabaseError,):
510 510 log.exception('Failed to set default object permissions')
511 511 self.sa.rollback()
512 512 raise
513 513
514 514 def update_branch_permissions(self, form_result):
515 515 if 'perm_user_id' in form_result:
516 516 perm_user = User.get(safe_int(form_result['perm_user_id']))
517 517 else:
518 518 # used mostly to do lookup for default user
519 519 perm_user = User.get_by_username(form_result['perm_user_name'])
520 520 try:
521 521
522 522 # stage 2 reset defaults and set them from form data
523 523 self._set_new_user_perms(perm_user, form_result, preserve=[
524 524 'default_repo_perm',
525 525 'default_group_perm',
526 526 'default_user_group_perm',
527 527
528 528 'default_repo_group_create',
529 529 'default_user_group_create',
530 530 'default_repo_create_on_write',
531 531 'default_repo_create',
532 532 'default_fork_create',
533 533 'default_inherit_default_permissions',
534 534
535 535 'default_register',
536 536 'default_password_reset',
537 537 'default_extern_activate'])
538 538
539 539 # overwrite default branch permissions
540 540 if form_result['overwrite_default_branch']:
541 541 _def_name = \
542 542 form_result['default_branch_perm'].split('branch.')[-1]
543 543
544 544 _def = Permission.get_by_key('branch.' + _def_name)
545 545
546 546 user_perms = UserToRepoBranchPermission.query()\
547 547 .join(UserToRepoBranchPermission.user_repo_to_perm)\
548 548 .filter(UserRepoToPerm.user == perm_user).all()
549 549
550 550 for g2p in user_perms:
551 551 g2p.permission = _def
552 552 self.sa.add(g2p)
553 553
554 554 # COMMIT
555 555 self.sa.commit()
556 556 except (DatabaseError,):
557 557 log.exception('Failed to set default branch permissions')
558 558 self.sa.rollback()
559 559 raise
560 560
561 561 def get_users_with_repo_write(self, db_repo):
562 562 write_plus = ['repository.write', 'repository.admin']
563 563 default_user_id = User.get_default_user_id()
564 564 user_write_permissions = collections.OrderedDict()
565 565
566 566 # write or higher and DEFAULT user for inheritance
567 567 for perm in db_repo.permissions():
568 568 if perm.permission in write_plus or perm.user_id == default_user_id:
569 569 user_write_permissions[perm.user_id] = perm
570 570 return user_write_permissions
571 571
572 572 def get_user_groups_with_repo_write(self, db_repo):
573 573 write_plus = ['repository.write', 'repository.admin']
574 574 user_group_write_permissions = collections.OrderedDict()
575 575
576 576 # write or higher and DEFAULT user for inheritance
577 577 for p in db_repo.permission_user_groups():
578 578 if p.permission in write_plus:
579 579 user_group_write_permissions[p.users_group_id] = p
580 580 return user_group_write_permissions
581 581
582 582 def trigger_permission_flush(self, affected_user_ids=None):
583 583 affected_user_ids = affected_user_ids or User.get_all_user_ids()
584 584 events.trigger(events.UserPermissionsChange(affected_user_ids))
585 585
586 586 def flush_user_permission_caches(self, changes, affected_user_ids=None):
587 587 affected_user_ids = affected_user_ids or []
588 588
589 589 for change in changes['added'] + changes['updated'] + changes['deleted']:
590 590 if change['type'] == 'user':
591 591 affected_user_ids.append(change['id'])
592 592 if change['type'] == 'user_group':
593 593 user_group = UserGroup.get(safe_int(change['id']))
594 594 if user_group:
595 595 group_members_ids = [x.user_id for x in user_group.members]
596 596 affected_user_ids.extend(group_members_ids)
597 597
598 598 self.trigger_permission_flush(affected_user_ids)
599 599
600 600 return affected_user_ids
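# --- editor's note: illustrative sketch, not part of this changeset ---
# Shape of the `changes` payload consumed above: each entry carries a
# 'type' ('user' or 'user_group') and a database id; user groups are
# expanded to their member ids before UserPermissionsChange fires. The
# ids below are hypothetical.
def _example_flush(model):
    changes = {
        'added': [{'type': 'user', 'id': 2}],
        'updated': [{'type': 'user_group', 'id': 7}],
        'deleted': [],
    }
    return model.flush_user_permission_caches(changes)
# -----------------------------------------------------------------------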
@@ -1,1028 +1,1028 b''
1 1 # -*- coding: utf-8 -*-
2 2
3 3 # Copyright (C) 2010-2020 RhodeCode GmbH
4 4 #
5 5 # This program is free software: you can redistribute it and/or modify
6 6 # it under the terms of the GNU Affero General Public License, version 3
7 7 # (only), as published by the Free Software Foundation.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU Affero General Public License
15 15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 16 #
17 17 # This program is dual-licensed. If you wish to learn more about the
18 18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20 20
21 21 """
22 22 Scm model for RhodeCode
23 23 """
24 24
25 25 import os.path
26 26 import traceback
27 27 import logging
28 from io import StringIO
28 import io
29 29
30 30 from sqlalchemy import func
31 31 from zope.cachedescriptors.property import Lazy as LazyProperty
32 32
33 33 import rhodecode
34 34 from rhodecode.lib.vcs import get_backend
35 35 from rhodecode.lib.vcs.exceptions import RepositoryError, NodeNotChangedError
36 36 from rhodecode.lib.vcs.nodes import FileNode
37 37 from rhodecode.lib.vcs.backends.base import EmptyCommit
38 38 from rhodecode.lib import helpers as h, rc_cache
39 39 from rhodecode.lib.auth import (
40 40 HasRepoPermissionAny, HasRepoGroupPermissionAny,
41 41 HasUserGroupPermissionAny)
42 42 from rhodecode.lib.exceptions import NonRelativePathError, IMCCommitError
43 43 from rhodecode.lib import hooks_utils
44 44 from rhodecode.lib.utils import (
45 45 get_filesystem_repos, make_db_config)
46 46 from rhodecode.lib.utils2 import (safe_str, safe_unicode)
47 47 from rhodecode.lib.system_info import get_system_info
48 48 from rhodecode.model import BaseModel
49 49 from rhodecode.model.db import (
50 50 or_, false,
51 51 Repository, CacheKey, UserFollowing, UserLog, User, RepoGroup,
52 52 PullRequest, FileStore)
53 53 from rhodecode.model.settings import VcsSettingsModel
54 54 from rhodecode.model.validation_schema.validators import url_validator, InvalidCloneUrl
55 55
56 56 log = logging.getLogger(__name__)
57 57
58 58
59 59 class UserTemp(object):
60 60 def __init__(self, user_id):
61 61 self.user_id = user_id
62 62
63 63 def __repr__(self):
64 64 return "<%s('id:%s')>" % (self.__class__.__name__, self.user_id)
65 65
66 66
67 67 class RepoTemp(object):
68 68 def __init__(self, repo_id):
69 69 self.repo_id = repo_id
70 70
71 71 def __repr__(self):
72 72 return "<%s('id:%s')>" % (self.__class__.__name__, self.repo_id)
73 73
74 74
75 75 class SimpleCachedRepoList(object):
76 76 """
77 77 Lighter version of repo iteration without the scm initialisation,
78 78 and with cache usage
79 79 """
80 80 def __init__(self, db_repo_list, repos_path, order_by=None, perm_set=None):
81 81 self.db_repo_list = db_repo_list
82 82 self.repos_path = repos_path
83 83 self.order_by = order_by
84 84 self.reversed = (order_by or '').startswith('-')
85 85 if not perm_set:
86 86 perm_set = ['repository.read', 'repository.write',
87 87 'repository.admin']
88 88 self.perm_set = perm_set
89 89
90 90 def __len__(self):
91 91 return len(self.db_repo_list)
92 92
93 93 def __repr__(self):
94 94 return '<%s (%s)>' % (self.__class__.__name__, self.__len__())
95 95
96 96 def __iter__(self):
97 97 for dbr in self.db_repo_list:
98 98 # check permission at this level
99 99 has_perm = HasRepoPermissionAny(*self.perm_set)(
100 100 dbr.repo_name, 'SimpleCachedRepoList check')
101 101 if not has_perm:
102 102 continue
103 103
104 104 tmp_d = {
105 105 'name': dbr.repo_name,
106 106 'dbrepo': dbr.get_dict(),
107 107 'dbrepo_fork': dbr.fork.get_dict() if dbr.fork else {}
108 108 }
109 109 yield tmp_d
110 110
111 111
112 112 class _PermCheckIterator(object):
113 113
114 114 def __init__(
115 115 self, obj_list, obj_attr, perm_set, perm_checker,
116 116 extra_kwargs=None):
117 117 """
118 118 Creates an iterator from the given list of objects, additionally
119 119 checking permissions on them against the perm_set var
120 120
121 121 :param obj_list: list of db objects
122 122 :param obj_attr: attribute of object to pass into perm_checker
123 123 :param perm_set: list of permissions to check
124 124 :param perm_checker: callable to check permissions against
125 125 """
126 126 self.obj_list = obj_list
127 127 self.obj_attr = obj_attr
128 128 self.perm_set = perm_set
129 129 self.perm_checker = perm_checker(*self.perm_set)
130 130 self.extra_kwargs = extra_kwargs or {}
131 131
132 132 def __len__(self):
133 133 return len(self.obj_list)
134 134
135 135 def __repr__(self):
136 136 return '<%s (%s)>' % (self.__class__.__name__, self.__len__())
137 137
138 138 def __iter__(self):
139 139 for db_obj in self.obj_list:
140 140 # check permission at this level
141 141 # NOTE(marcink): the __dict__.get() is ~4x faster than getattr()
142 142 name = db_obj.__dict__.get(self.obj_attr, None)
143 143 if not self.perm_checker(name, self.__class__.__name__, **self.extra_kwargs):
144 144 continue
145 145
146 146 yield db_obj
147 147
148 148
149 149 class RepoList(_PermCheckIterator):
150 150
151 151 def __init__(self, db_repo_list, perm_set=None, extra_kwargs=None):
152 152 if not perm_set:
153 153 perm_set = ['repository.read', 'repository.write', 'repository.admin']
154 154
155 155 super(RepoList, self).__init__(
156 156 obj_list=db_repo_list,
157 157 obj_attr='_repo_name', perm_set=perm_set,
158 158 perm_checker=HasRepoPermissionAny,
159 159 extra_kwargs=extra_kwargs)
160 160
161 161
162 162 class RepoGroupList(_PermCheckIterator):
163 163
164 164 def __init__(self, db_repo_group_list, perm_set=None, extra_kwargs=None):
165 165 if not perm_set:
166 166 perm_set = ['group.read', 'group.write', 'group.admin']
167 167
168 168 super(RepoGroupList, self).__init__(
169 169 obj_list=db_repo_group_list,
170 170 obj_attr='_group_name', perm_set=perm_set,
171 171 perm_checker=HasRepoGroupPermissionAny,
172 172 extra_kwargs=extra_kwargs)
173 173
174 174
175 175 class UserGroupList(_PermCheckIterator):
176 176
177 177 def __init__(self, db_user_group_list, perm_set=None, extra_kwargs=None):
178 178 if not perm_set:
179 179 perm_set = ['usergroup.read', 'usergroup.write', 'usergroup.admin']
180 180
181 181 super(UserGroupList, self).__init__(
182 182 obj_list=db_user_group_list,
183 183 obj_attr='users_group_name', perm_set=perm_set,
184 184 perm_checker=HasUserGroupPermissionAny,
185 185 extra_kwargs=extra_kwargs)
186 186
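# --- editor's note: illustrative sketch, not part of this changeset ---
# Wrapping a raw query result in RepoList yields only the repositories
# the current auth context can read; the check runs lazily per row in
# _PermCheckIterator.__iter__. Assumes an authenticated request context
# so HasRepoPermissionAny can resolve the user.
def _example_visible_repos(db_repo_list):
    visible = RepoList(db_repo_list, perm_set=['repository.read',
                                               'repository.write',
                                               'repository.admin'])
    return [repo.repo_name for repo in visible]
# -----------------------------------------------------------------------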
187 187
188 188 class ScmModel(BaseModel):
189 189 """
190 190 Generic Scm Model
191 191 """
192 192
193 193 @LazyProperty
194 194 def repos_path(self):
195 195 """
196 196 Gets the repositories root path from database
197 197 """
198 198
199 199 settings_model = VcsSettingsModel(sa=self.sa)
200 200 return settings_model.get_repos_location()
201 201
202 202 def repo_scan(self, repos_path=None):
203 203 """
204 204 Listing of repositories in the given path. This path should not be a
205 205 repository itself. Return a dictionary of repository objects
206 206
207 207 :param repos_path: path to directory containing repositories
208 208 """
209 209
210 210 if repos_path is None:
211 211 repos_path = self.repos_path
212 212
213 213 log.info('scanning for repositories in %s', repos_path)
214 214
215 215 config = make_db_config()
216 216 config.set('extensions', 'largefiles', '')
217 217 repos = {}
218 218
219 219 for name, path in get_filesystem_repos(repos_path, recursive=True):
220 220 # names need to be decomposed and put back together using '/',
221 221 # since this is the internal storage separator for rhodecode
222 222 name = Repository.normalize_repo_name(name)
223 223
224 224 try:
225 225 if name in repos:
226 226 raise RepositoryError('Duplicate repository name %s '
227 227 'found in %s' % (name, path))
228 228 elif path[0] in rhodecode.BACKENDS:
229 229 backend = get_backend(path[0])
230 230 repos[name] = backend(path[1], config=config,
231 231 with_wire={"cache": False})
232 232 except OSError:
233 233 continue
234 234 except RepositoryError:
235 235 log.exception('Failed to create a repo')
236 236 continue
237 237
238 238 log.debug('found %s paths with repositories', len(repos))
239 239 return repos
240 240
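# --- editor's note: illustrative sketch, not part of this changeset ---
# repo_scan maps normalised names ('group/repo') to bare backend
# instances with wire caching disabled; the path below is hypothetical.
def _example_scan():
    repos = ScmModel().repo_scan('/srv/repos')
    for name, backend in sorted(repos.items()):
        print(name, backend.alias)
# -----------------------------------------------------------------------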
241 241 def get_repos(self, all_repos=None, sort_key=None):
242 242 """
243 243 Get all repositories from the db and for each repo create its
244 244 backend instance, filling that backend with information from the database
245 245
246 246 :param all_repos: list of repository names as strings
247 247 give specific repositories list, good for filtering
248 248
249 249 :param sort_key: initial sorting of repositories
250 250 """
251 251 if all_repos is None:
252 252 all_repos = self.sa.query(Repository)\
253 253 .filter(Repository.group_id == None)\
254 254 .order_by(func.lower(Repository.repo_name)).all()
255 255 repo_iter = SimpleCachedRepoList(
256 256 all_repos, repos_path=self.repos_path, order_by=sort_key)
257 257 return repo_iter
258 258
259 259 def get_repo_groups(self, all_groups=None):
260 260 if all_groups is None:
261 261 all_groups = RepoGroup.query()\
262 262 .filter(RepoGroup.group_parent_id == None).all()
263 263 return [x for x in RepoGroupList(all_groups)]
264 264
265 265 def mark_for_invalidation(self, repo_name, delete=False):
266 266 """
267 267 Mark caches of this repo invalid in the database. `delete` flag
268 268 removes the cache entries
269 269
270 270 :param repo_name: the repo_name for which caches should be marked
271 271 invalid, or deleted
272 272 :param delete: delete the entry keys instead of setting bool
273 273 flag on them, and also purge caches used by the dogpile
274 274 """
275 275 repo = Repository.get_by_repo_name(repo_name)
276 276
277 277 if repo:
278 278 invalidation_namespace = CacheKey.REPO_INVALIDATION_NAMESPACE.format(
279 279 repo_id=repo.repo_id)
280 280 CacheKey.set_invalidate(invalidation_namespace, delete=delete)
281 281
282 282 repo_id = repo.repo_id
283 283 config = repo._config
284 284 config.set('extensions', 'largefiles', '')
285 285 repo.update_commit_cache(config=config, cs_cache=None)
286 286 if delete:
287 287 cache_namespace_uid = 'cache_repo.{}'.format(repo_id)
288 288 rc_cache.clear_cache_namespace(
289 289 'cache_repo', cache_namespace_uid, invalidate=True)
290 290
291 291 def toggle_following_repo(self, follow_repo_id, user_id):
292 292
293 293 f = self.sa.query(UserFollowing)\
294 294 .filter(UserFollowing.follows_repo_id == follow_repo_id)\
295 295 .filter(UserFollowing.user_id == user_id).scalar()
296 296
297 297 if f is not None:
298 298 try:
299 299 self.sa.delete(f)
300 300 return
301 301 except Exception:
302 302 log.error(traceback.format_exc())
303 303 raise
304 304
305 305 try:
306 306 f = UserFollowing()
307 307 f.user_id = user_id
308 308 f.follows_repo_id = follow_repo_id
309 309 self.sa.add(f)
310 310 except Exception:
311 311 log.error(traceback.format_exc())
312 312 raise
313 313
314 314 def toggle_following_user(self, follow_user_id, user_id):
315 315 f = self.sa.query(UserFollowing)\
316 316 .filter(UserFollowing.follows_user_id == follow_user_id)\
317 317 .filter(UserFollowing.user_id == user_id).scalar()
318 318
319 319 if f is not None:
320 320 try:
321 321 self.sa.delete(f)
322 322 return
323 323 except Exception:
324 324 log.error(traceback.format_exc())
325 325 raise
326 326
327 327 try:
328 328 f = UserFollowing()
329 329 f.user_id = user_id
330 330 f.follows_user_id = follow_user_id
331 331 self.sa.add(f)
332 332 except Exception:
333 333 log.error(traceback.format_exc())
334 334 raise
335 335
336 336 def is_following_repo(self, repo_name, user_id, cache=False):
337 337 r = self.sa.query(Repository)\
338 338 .filter(Repository.repo_name == repo_name).scalar()
339 339
340 340 f = self.sa.query(UserFollowing)\
341 341 .filter(UserFollowing.follows_repository == r)\
342 342 .filter(UserFollowing.user_id == user_id).scalar()
343 343
344 344 return f is not None
345 345
346 346 def is_following_user(self, username, user_id, cache=False):
347 347 u = User.get_by_username(username)
348 348
349 349 f = self.sa.query(UserFollowing)\
350 350 .filter(UserFollowing.follows_user == u)\
351 351 .filter(UserFollowing.user_id == user_id).scalar()
352 352
353 353 return f is not None
354 354
355 355 def get_followers(self, repo):
356 356 repo = self._get_repo(repo)
357 357
358 358 return self.sa.query(UserFollowing)\
359 359 .filter(UserFollowing.follows_repository == repo).count()
360 360
361 361 def get_forks(self, repo):
362 362 repo = self._get_repo(repo)
363 363 return self.sa.query(Repository)\
364 364 .filter(Repository.fork == repo).count()
365 365
366 366 def get_pull_requests(self, repo):
367 367 repo = self._get_repo(repo)
368 368 return self.sa.query(PullRequest)\
369 369 .filter(PullRequest.target_repo == repo)\
370 370 .filter(PullRequest.status != PullRequest.STATUS_CLOSED).count()
371 371
372 372 def get_artifacts(self, repo):
373 373 repo = self._get_repo(repo)
374 374 return self.sa.query(FileStore)\
375 375 .filter(FileStore.repo == repo)\
376 376 .filter(or_(FileStore.hidden == None, FileStore.hidden == false())).count()
377 377
378 378 def mark_as_fork(self, repo, fork, user):
379 379 repo = self._get_repo(repo)
380 380 fork = self._get_repo(fork)
381 381 if fork and repo.repo_id == fork.repo_id:
382 382 raise Exception("Cannot set repository as fork of itself")
383 383
384 384 if fork and repo.repo_type != fork.repo_type:
385 385 raise RepositoryError(
386 386 "Cannot set repository as fork of repository with other type")
387 387
388 388 repo.fork = fork
389 389 self.sa.add(repo)
390 390 return repo
391 391
392 392 def pull_changes(self, repo, username, remote_uri=None, validate_uri=True):
393 393 dbrepo = self._get_repo(repo)
394 394 remote_uri = remote_uri or dbrepo.clone_uri
395 395 if not remote_uri:
396 396 raise Exception("This repository doesn't have a clone uri")
397 397
398 398 repo = dbrepo.scm_instance(cache=False)
399 399 repo.config.clear_section('hooks')
400 400
401 401 try:
402 402 # NOTE(marcink): add extra validation so we skip invalid urls
403 403 # this is because these tasks can be executed via the scheduler without
404 404 # proper validation of remote_uri
405 405 if validate_uri:
406 406 config = make_db_config(clear_session=False)
407 407 url_validator(remote_uri, dbrepo.repo_type, config)
408 408 except InvalidCloneUrl:
409 409 raise
410 410
411 411 repo_name = dbrepo.repo_name
412 412 try:
413 413 # TODO: we need to make sure those operations call proper hooks !
414 414 repo.fetch(remote_uri)
415 415
416 416 self.mark_for_invalidation(repo_name)
417 417 except Exception:
418 418 log.error(traceback.format_exc())
419 419 raise
420 420
421 421 def push_changes(self, repo, username, remote_uri=None, validate_uri=True):
422 422 dbrepo = self._get_repo(repo)
423 423 remote_uri = remote_uri or dbrepo.push_uri
424 424 if not remote_uri:
425 425 raise Exception("This repository doesn't have a push uri")
426 426
427 427 repo = dbrepo.scm_instance(cache=False)
428 428 repo.config.clear_section('hooks')
429 429
430 430 try:
431 431 # NOTE(marcink): add extra validation so we skip invalid urls
432 432 # this is because these tasks can be executed via the scheduler without
433 433 # proper validation of remote_uri
434 434 if validate_uri:
435 435 config = make_db_config(clear_session=False)
436 436 url_validator(remote_uri, dbrepo.repo_type, config)
437 437 except InvalidCloneUrl:
438 438 raise
439 439
440 440 try:
441 441 repo.push(remote_uri)
442 442 except Exception:
443 443 log.error(traceback.format_exc())
444 444 raise
445 445
446 446 def commit_change(self, repo, repo_name, commit, user, author, message,
447 447 content, f_path):
448 448 """
449 449 Commits changes
450 450
451 451 :param repo: SCM instance
452 452
453 453 """
454 454 user = self._get_user(user)
455 455
456 456 # decoding here ensures we have properly encoded values;
457 457 # in any other case this will raise an exception and deny the commit
458 458 content = safe_str(content)
459 459 path = safe_str(f_path)
460 460 # message and author needs to be unicode
461 461 # proper backend should then translate that into required type
462 462 message = safe_unicode(message)
463 463 author = safe_unicode(author)
464 464 imc = repo.in_memory_commit
465 465 imc.change(FileNode(path, content, mode=commit.get_file_mode(f_path)))
466 466 try:
467 467 # TODO: handle pre-push action !
468 468 tip = imc.commit(
469 469 message=message, author=author, parents=[commit],
470 470 branch=commit.branch)
471 471 except Exception as e:
472 472 log.error(traceback.format_exc())
473 473 raise IMCCommitError(str(e))
474 474 finally:
475 475 # always clear caches, if commit fails we want fresh object also
476 476 self.mark_for_invalidation(repo_name)
477 477
478 478 # We trigger the post-push action
479 479 hooks_utils.trigger_post_push_hook(
480 480 username=user.username, action='push_local', hook_type='post_push',
481 481 repo_name=repo_name, repo_type=repo.alias, commit_ids=[tip.raw_id])
482 482 return tip
483 483
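# --- editor's note: illustrative sketch, not part of this changeset ---
# Editing one existing file on top of `commit`: cache invalidation runs
# in the finally block even if the in-memory commit fails, while the
# post-push hook fires only on success. All argument values below are
# hypothetical.
def _example_edit_file(scm_repo, repo_name, commit, user):
    return ScmModel().commit_change(
        repo=scm_repo, repo_name=repo_name, commit=commit, user=user,
        author=user.full_contact, message='fix typo in docs',
        content='new file body\n', f_path='docs/index.rst')
# -----------------------------------------------------------------------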
484 484 def _sanitize_path(self, f_path):
485 485 if f_path.startswith('/') or f_path.startswith('./') or '../' in f_path:
486 486 raise NonRelativePathError('%s is not a relative path' % f_path)
487 487 if f_path:
488 488 f_path = os.path.normpath(f_path)
489 489 return f_path
490 490
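# --- editor's note: illustrative sketch, not part of this changeset ---
# Only repo-relative paths survive sanitisation; absolute paths and
# parent-directory escapes are rejected before any filesystem access.
def _example_sanitize(model):
    model._sanitize_path('docs/readme.rst')    # ok, returns normalised path
    model._sanitize_path('../../etc/passwd')   # raises NonRelativePathError
# -----------------------------------------------------------------------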
491 491 def get_dirnode_metadata(self, request, commit, dir_node):
492 492 if not dir_node.is_dir():
493 493 return []
494 494
495 495 data = []
496 496 for node in dir_node:
497 497 if not node.is_file():
498 498 # we skip file-nodes
499 499 continue
500 500
501 501 last_commit = node.last_commit
502 502 last_commit_date = last_commit.date
503 503 data.append({
504 504 'name': node.name,
505 505 'size': h.format_byte_size_binary(node.size),
506 506 'modified_at': h.format_date(last_commit_date),
507 507 'modified_ts': last_commit_date.isoformat(),
508 508 'revision': last_commit.revision,
509 509 'short_id': last_commit.short_id,
510 510 'message': h.escape(last_commit.message),
511 511 'author': h.escape(last_commit.author),
512 512 'user_profile': h.gravatar_with_user(
513 513 request, last_commit.author),
514 514 })
515 515
516 516 return data
517 517
518 518 def get_nodes(self, repo_name, commit_id, root_path='/', flat=True,
519 519 extended_info=False, content=False, max_file_bytes=None):
520 520 """
521 521 recursively walk the root dir and return a set of all paths in that
522 522 dir, based on the repository walk function
523 523
524 524 :param repo_name: name of repository
525 525 :param commit_id: commit id for which to list nodes
526 526 :param root_path: root path to list
527 527 :param flat: return as a list, if False returns a dict with description
528 528 :param extended_info: show additional info such as md5, binary, size etc
529 529 :param content: add nodes content to the return data
530 530 :param max_file_bytes: will not return file contents over this limit
531 531
532 532 """
533 533 _files = list()
534 534 _dirs = list()
535 535 try:
536 536 _repo = self._get_repo(repo_name)
537 537 commit = _repo.scm_instance().get_commit(commit_id=commit_id)
538 538 root_path = root_path.lstrip('/')
539 539 for __, dirs, files in commit.walk(root_path):
540 540
541 541 for f in files:
542 542 _content = None
543 543 _data = f_name = f.unicode_path
544 544
545 545 if not flat:
546 546 _data = {
547 547 "name": h.escape(f_name),
548 548 "type": "file",
549 549 }
550 550 if extended_info:
551 551 _data.update({
552 552 "md5": f.md5,
553 553 "binary": f.is_binary,
554 554 "size": f.size,
555 555 "extension": f.extension,
556 556 "mimetype": f.mimetype,
557 557 "lines": f.lines()[0]
558 558 })
559 559
560 560 if content:
561 561 over_size_limit = (max_file_bytes is not None
562 562 and f.size > max_file_bytes)
563 563 full_content = None
564 564 if not f.is_binary and not over_size_limit:
565 565 full_content = safe_str(f.content)
566 566
567 567 _data.update({
568 568 "content": full_content,
569 569 })
570 570 _files.append(_data)
571 571
572 572 for d in dirs:
573 573 _data = d_name = d.unicode_path
574 574 if not flat:
575 575 _data = {
576 576 "name": h.escape(d_name),
577 577 "type": "dir",
578 578 }
579 579 if extended_info:
580 580 _data.update({
581 581 "md5": None,
582 582 "binary": None,
583 583 "size": None,
584 584 "extension": None,
585 585 })
586 586 if content:
587 587 _data.update({
588 588 "content": None
589 589 })
590 590 _dirs.append(_data)
591 591 except RepositoryError:
592 592 log.exception("Exception in get_nodes")
593 593 raise
594 594
595 595 return _dirs, _files
596 596
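# --- editor's note: illustrative sketch, not part of this changeset ---
# flat=False switches the return entries from bare unicode paths to
# dicts; extended_info adds md5/size/mimetype, and `content` inlines
# file bodies up to max_file_bytes. Arguments are hypothetical.
def _example_tree(repo_name, commit_id):
    dirs, files = ScmModel().get_nodes(
        repo_name, commit_id, root_path='/', flat=False,
        extended_info=True, content=False)
    return len(dirs), len(files)
# -----------------------------------------------------------------------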
597 597 def get_quick_filter_nodes(self, repo_name, commit_id, root_path='/'):
598 598 """
599 599 Generate files for quick filter in files view
600 600 """
601 601
602 602 _files = list()
603 603 _dirs = list()
604 604 try:
605 605 _repo = self._get_repo(repo_name)
606 606 commit = _repo.scm_instance().get_commit(commit_id=commit_id)
607 607 root_path = root_path.lstrip('/')
608 608 for __, dirs, files in commit.walk(root_path):
609 609
610 610 for f in files:
611 611
612 612 _data = {
613 613 "name": h.escape(f.unicode_path),
614 614 "type": "file",
615 615 }
616 616
617 617 _files.append(_data)
618 618
619 619 for d in dirs:
620 620
621 621 _data = {
622 622 "name": h.escape(d.unicode_path),
623 623 "type": "dir",
624 624 }
625 625
626 626 _dirs.append(_data)
627 627 except RepositoryError:
628 628 log.exception("Exception in get_quick_filter_nodes")
629 629 raise
630 630
631 631 return _dirs, _files
632 632
633 633 def get_node(self, repo_name, commit_id, file_path,
634 634 extended_info=False, content=False, max_file_bytes=None, cache=True):
635 635 """
636 636 retrieve a single node from a commit
637 637 """
638 638 try:
639 639
640 640 _repo = self._get_repo(repo_name)
641 641 commit = _repo.scm_instance().get_commit(commit_id=commit_id)
642 642
643 643 file_node = commit.get_node(file_path)
644 644 if file_node.is_dir():
645 645 raise RepositoryError('The given path is a directory')
646 646
647 647 _content = None
648 648 f_name = file_node.unicode_path
649 649
650 650 file_data = {
651 651 "name": h.escape(f_name),
652 652 "type": "file",
653 653 }
654 654
655 655 if extended_info:
656 656 file_data.update({
657 657 "extension": file_node.extension,
658 658 "mimetype": file_node.mimetype,
659 659 })
660 660
661 661 if cache:
662 662 md5 = file_node.md5
663 663 is_binary = file_node.is_binary
664 664 size = file_node.size
665 665 else:
666 666 is_binary, md5, size, _content = file_node.metadata_uncached()
667 667
668 668 file_data.update({
669 669 "md5": md5,
670 670 "binary": is_binary,
671 671 "size": size,
672 672 })
673 673
674 674 if content and cache:
675 675 # get content + cache
676 676 size = file_node.size
677 677 over_size_limit = (max_file_bytes is not None and size > max_file_bytes)
678 678 full_content = None
679 679 all_lines = 0
680 680 if not file_node.is_binary and not over_size_limit:
681 681 full_content = safe_unicode(file_node.content)
682 682 all_lines, empty_lines = file_node.count_lines(full_content)
683 683
684 684 file_data.update({
685 685 "content": full_content,
686 686 "lines": all_lines
687 687 })
688 688 elif content:
689 689 # get content *without* cache
690 690 if _content is None:
691 691 is_binary, md5, size, _content = file_node.metadata_uncached()
692 692
693 693 over_size_limit = (max_file_bytes is not None and size > max_file_bytes)
694 694 full_content = None
695 695 all_lines = 0
696 696 if not is_binary and not over_size_limit:
697 697 full_content = safe_unicode(_content)
698 698 all_lines, empty_lines = file_node.count_lines(full_content)
699 699
700 700 file_data.update({
701 701 "content": full_content,
702 702 "lines": all_lines
703 703 })
704 704
705 705 except RepositoryError:
706 706 log.exception("Exception in get_node")
707 707 raise
708 708
709 709 return file_data
710 710
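# Sketch of the dict shape returned by get_node() above (values are
# hypothetical), with `extended_info` and `content` enabled:
#
#   {'name': 'docs/index.rst', 'type': 'file', 'extension': 'rst',
#    'mimetype': 'text/x-rst', 'md5': '<hex digest>', 'binary': False,
#    'size': 512, 'content': '...', 'lines': 20}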
711 711 def get_fts_data(self, repo_name, commit_id, root_path='/'):
712 712 """
713 713 Fetches the node tree for use in full-text search
714 714 """
715 715
716 716 tree_info = list()
717 717
718 718 try:
719 719 _repo = self._get_repo(repo_name)
720 720 commit = _repo.scm_instance().get_commit(commit_id=commit_id)
721 721 root_path = root_path.lstrip('/')
722 722 for __, dirs, files in commit.walk(root_path):
723 723
724 724 for f in files:
725 725 is_binary, md5, size, _content = f.metadata_uncached()
726 726 _data = {
727 727 "name": f.unicode_path,
728 728 "md5": md5,
729 729 "extension": f.extension,
730 730 "binary": is_binary,
731 731 "size": size
732 732 }
733 733
734 734 tree_info.append(_data)
735 735
736 736 except RepositoryError:
737 737 log.exception("Exception in get_fts_data")
738 738 raise
739 739
740 740 return tree_info
741 741
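# Sketch of a single tree_info entry built by get_fts_data() above
# (hypothetical values); one flat dict per file:
#
#   {'name': 'setup.py', 'md5': '<hex digest>', 'extension': 'py',
#    'binary': False, 'size': 1024}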
742 742 def create_nodes(self, user, repo, message, nodes, parent_commit=None,
743 743 author=None, trigger_push_hook=True):
744 744 """
745 745 Commits multiple given nodes into `repo`
746 746
747 747 :param user: RhodeCode User object or user_id, the committer
748 748 :param repo: RhodeCode Repository object
749 749 :param message: commit message
750 750 :param nodes: mapping {filename:{'content':content},...}
751 751 :param parent_commit: parent commit; if empty, this is the
752 752 initial commit
753 753 :param author: author of the commit; can be different from the
754 754 committer (git only)
755 755 :param trigger_push_hook: trigger push hooks
756 756
757 757 :returns: new committed commit
758 758 """
759 759
760 760 user = self._get_user(user)
761 761 scm_instance = repo.scm_instance(cache=False)
762 762
763 763 processed_nodes = []
764 764 for f_path in nodes:
765 765 f_path = self._sanitize_path(f_path)
766 766 content = nodes[f_path]['content']
767 767 f_path = safe_str(f_path)
768 768 # decoding here ensures we have properly encoded values;
769 769 # in any other case this will throw an exception and deny the commit
770 770 if isinstance(content, (str, bytes)):
771 771 content = safe_str(content)
772 772 elif hasattr(content, 'read'):  # file-like; py3 removed `file`/cStringIO
773 773 content = content.read()
774 774 else:
775 775 raise Exception('Content is of unrecognized type %s' % (
776 776 type(content)
777 777 ))
778 778 processed_nodes.append((f_path, content))
779 779
780 780 message = safe_unicode(message)
781 781 commiter = user.full_contact
782 782 author = safe_unicode(author) if author else commiter
783 783
784 784 imc = scm_instance.in_memory_commit
785 785
786 786 if not parent_commit:
787 787 parent_commit = EmptyCommit(alias=scm_instance.alias)
788 788
789 789 if isinstance(parent_commit, EmptyCommit):
790 790 # EmptyCommit means we're editing an empty repository
791 791 parents = None
792 792 else:
793 793 parents = [parent_commit]
794 794 # add multiple nodes
795 795 for path, content in processed_nodes:
796 796 imc.add(FileNode(path, content=content))
797 797 # TODO: handle pre push scenario
798 798 tip = imc.commit(message=message,
799 799 author=author,
800 800 parents=parents,
801 801 branch=parent_commit.branch)
802 802
803 803 self.mark_for_invalidation(repo.repo_name)
804 804 if trigger_push_hook:
805 805 hooks_utils.trigger_post_push_hook(
806 806 username=user.username, action='push_local',
807 807 repo_name=repo.repo_name, repo_type=scm_instance.alias,
808 808 hook_type='post_push',
809 809 commit_ids=[tip.raw_id])
810 810 return tip
811 811
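# Usage sketch for create_nodes() (hypothetical names, assuming the
# enclosing ScmModel class); `nodes` maps file paths to dicts carrying
# the raw content, as described in the docstring:
#
#   tip = ScmModel().create_nodes(
#       user=2, repo=repo_obj, message='add readme',
#       nodes={'README.rst': {'content': 'Hello'}})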
812 812 def update_nodes(self, user, repo, message, nodes, parent_commit=None,
813 813 author=None, trigger_push_hook=True):
814 814 user = self._get_user(user)
815 815 scm_instance = repo.scm_instance(cache=False)
816 816
817 817 message = safe_unicode(message)
818 818 commiter = user.full_contact
819 819 author = safe_unicode(author) if author else commiter
820 820
821 821 imc = scm_instance.in_memory_commit
822 822
823 823 if not parent_commit:
824 824 parent_commit = EmptyCommit(alias=scm_instance.alias)
825 825
826 826 if isinstance(parent_commit, EmptyCommit):
827 827 # EmptyCommit means we're editing an empty repository
828 828 parents = None
829 829 else:
830 830 parents = [parent_commit]
831 831
832 832 # add multiple nodes
833 833 for _filename, data in nodes.items():
834 834 # new filename, can be renamed from the old one; also sanitize
835 835 # the path against relative-path tricks like ../../ etc.
836 836 filename = self._sanitize_path(data['filename'])
837 837 old_filename = self._sanitize_path(_filename)
838 838 content = data['content']
839 839 file_mode = data.get('mode')
840 840 filenode = FileNode(old_filename, content=content, mode=file_mode)
841 841 op = data['op']
842 842 if op == 'add':
843 843 imc.add(filenode)
844 844 elif op == 'del':
845 845 imc.remove(filenode)
846 846 elif op == 'mod':
847 847 if filename != old_filename:
848 848 # TODO: handle renames more efficiently; needs vcs lib changes
849 849 imc.remove(filenode)
850 850 imc.add(FileNode(filename, content=content, mode=file_mode))
851 851 else:
852 852 imc.change(filenode)
853 853
854 854 try:
855 855 # TODO: handle pre push scenario commit changes
856 856 tip = imc.commit(message=message,
857 857 author=author,
858 858 parents=parents,
859 859 branch=parent_commit.branch)
860 860 except NodeNotChangedError:
861 861 raise
862 862 except Exception as e:
863 863 log.exception("Unexpected exception during call to imc.commit")
864 864 raise IMCCommitError(str(e))
865 865 finally:
866 866 # always clear caches; if the commit fails we still want a fresh object
867 867 self.mark_for_invalidation(repo.repo_name)
868 868
869 869 if trigger_push_hook:
870 870 hooks_utils.trigger_post_push_hook(
871 871 username=user.username, action='push_local', hook_type='post_push',
872 872 repo_name=repo.repo_name, repo_type=scm_instance.alias,
873 873 commit_ids=[tip.raw_id])
874 874
875 875 return tip
876 876
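# Usage sketch for update_nodes() (hypothetical names); keys are the *old*
# filenames, `op` selects add/del/mod, and a changed `filename` performs a
# rename (implemented above as remove + add):
#
#   tip = ScmModel().update_nodes(
#       user=2, repo=repo_obj, message='rename readme',
#       nodes={'README.txt': {'filename': 'README.rst',
#                             'content': 'Hello', 'op': 'mod'}})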
877 877 def delete_nodes(self, user, repo, message, nodes, parent_commit=None,
878 878 author=None, trigger_push_hook=True):
879 879 """
880 880 Deletes multiple given nodes from `repo`
881 881
882 882 :param user: RhodeCode User object or user_id, the committer
883 883 :param repo: RhodeCode Repository object
884 884 :param message: commit message
885 885 :param nodes: mapping {filename:{'content':content},...}
886 886 :param parent_commit: parent commit; if empty, this is the initial
887 887 commit
888 888 :param author: author of the commit; can be different from the
889 889 committer (git only)
890 890 :param trigger_push_hook: trigger push hooks
891 891
892 892 :returns: new commit after deletion
893 893 """
894 894
895 895 user = self._get_user(user)
896 896 scm_instance = repo.scm_instance(cache=False)
897 897
898 898 processed_nodes = []
899 899 for f_path in nodes:
900 900 f_path = self._sanitize_path(f_path)
901 901 # content can be empty, but for compatibility this accepts the same
902 902 # dict structure as add_nodes
903 903 content = nodes[f_path].get('content')
904 904 processed_nodes.append((f_path, content))
905 905
906 906 message = safe_unicode(message)
907 907 commiter = user.full_contact
908 908 author = safe_unicode(author) if author else commiter
909 909
910 910 imc = scm_instance.in_memory_commit
911 911
912 912 if not parent_commit:
913 913 parent_commit = EmptyCommit(alias=scm_instance.alias)
914 914
915 915 if isinstance(parent_commit, EmptyCommit):
916 916 # EmptyCommit means we're editing an empty repository
917 917 parents = None
918 918 else:
919 919 parents = [parent_commit]
920 920 # add multiple nodes
921 921 for path, content in processed_nodes:
922 922 imc.remove(FileNode(path, content=content))
923 923
924 924 # TODO: handle pre push scenario
925 925 tip = imc.commit(message=message,
926 926 author=author,
927 927 parents=parents,
928 928 branch=parent_commit.branch)
929 929
930 930 self.mark_for_invalidation(repo.repo_name)
931 931 if trigger_push_hook:
932 932 hooks_utils.trigger_post_push_hook(
933 933 username=user.username, action='push_local', hook_type='post_push',
934 934 repo_name=repo.repo_name, repo_type=scm_instance.alias,
935 935 commit_ids=[tip.raw_id])
936 936 return tip
937 937
938 938 def strip(self, repo, commit_id, branch):
939 939 scm_instance = repo.scm_instance(cache=False)
940 940 scm_instance.config.clear_section('hooks')
941 941 scm_instance.strip(commit_id, branch)
942 942 self.mark_for_invalidation(repo.repo_name)
943 943
944 944 def get_unread_journal(self):
945 945 return self.sa.query(UserLog).count()
946 946
947 947 @classmethod
948 948 def backend_landing_ref(cls, repo_type):
949 949 """
950 950 Return a default landing ref based on a repository type.
951 951 """
952 952
953 953 landing_ref = {
954 954 'hg': ('branch:default', 'default'),
955 955 'git': ('branch:master', 'master'),
956 956 'svn': ('rev:tip', 'latest tip'),
957 957 'default': ('rev:tip', 'latest tip'),
958 958 }
959 959
960 960 return landing_ref.get(repo_type) or landing_ref['default']
961 961
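# For example, backend_landing_ref('git') returns ('branch:master',
# 'master'); any unknown backend falls back to ('rev:tip', 'latest tip').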
962 962 def get_repo_landing_revs(self, translator, repo=None):
963 963 """
964 964 Generates select options with tags, branches and bookmarks (hg only),
965 965 grouped by type
966 966
967 967 :param repo:
968 968 """
969 969 from rhodecode.lib.vcs.backends.git import GitRepository
970 970
971 971 _ = translator
972 972 repo = self._get_repo(repo)
973 973
974 974 if repo:
975 975 repo_type = repo.repo_type
976 976 else:
977 977 repo_type = 'default'
978 978
979 979 default_landing_ref, landing_ref_lbl = self.backend_landing_ref(repo_type)
980 980
981 981 default_ref_options = [
982 982 [default_landing_ref, landing_ref_lbl]
983 983 ]
984 984 default_choices = [
985 985 default_landing_ref
986 986 ]
987 987
988 988 if not repo:
989 989 # presented at NEW repo creation
990 990 return default_choices, default_ref_options
991 991
992 992 repo = repo.scm_instance()
993 993
994 994 ref_options = [(default_landing_ref, landing_ref_lbl)]
995 995 choices = [default_landing_ref]
996 996
997 997 # branches
998 998 branch_group = [(u'branch:%s' % safe_unicode(b), safe_unicode(b)) for b in repo.branches]
999 999 if not branch_group:
1000 1000 # new repo, or a repo without any branches yet?
1001 1001 branch_group = default_ref_options
1002 1002
1003 1003 branches_group = (branch_group, _("Branches"))
1004 1004 ref_options.append(branches_group)
1005 1005 choices.extend([x[0] for x in branches_group[0]])
1006 1006
1007 1007 # bookmarks for HG
1008 1008 if repo.alias == 'hg':
1009 1009 bookmarks_group = (
1010 1010 [(u'book:%s' % safe_unicode(b), safe_unicode(b))
1011 1011 for b in repo.bookmarks],
1012 1012 _("Bookmarks"))
1013 1013 ref_options.append(bookmarks_group)
1014 1014 choices.extend([x[0] for x in bookmarks_group[0]])
1015 1015
1016 1016 # tags
1017 1017 tags_group = (
1018 1018 [(u'tag:%s' % safe_unicode(t), safe_unicode(t))
1019 1019 for t in repo.tags],
1020 1020 _("Tags"))
1021 1021 ref_options.append(tags_group)
1022 1022 choices.extend([x[0] for x in tags_group[0]])
1023 1023
1024 1024 return choices, ref_options
1025 1025
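# Sketch of the (choices, ref_options) pair built above, for a hypothetical
# git repo with branches `master` and `develop` (note the default ref
# appears again inside its own group):
#
#   choices ~ ['branch:master', 'branch:master', 'branch:develop', ...]
#   ref_options ~ [('branch:master', 'master'),
#                  ([('branch:master', 'master'),
#                    ('branch:develop', 'develop')], 'Branches'),
#                  ([...tags...], 'Tags')]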
1026 1026 def get_server_info(self, environ=None):
1027 1027 server_info = get_system_info(environ)
1028 1028 return server_info
@@ -1,1115 +1,1115 b''
1 1 # -*- coding: utf-8 -*-
2 2
3 3 # Copyright (C) 2010-2020 RhodeCode GmbH
4 4 #
5 5 # This program is free software: you can redistribute it and/or modify
6 6 # it under the terms of the GNU Affero General Public License, version 3
7 7 # (only), as published by the Free Software Foundation.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU Affero General Public License
15 15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 16 #
17 17 # This program is dual-licensed. If you wish to learn more about the
18 18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20 20
21 21 """
22 22 Set of generic validators
23 23 """
24 24
25 25
26 26 import os
27 27 import re
28 28 import logging
29 29 import collections
30 30
31 31 import formencode
32 32 import ipaddress
33 33 from formencode.validators import (
34 34 UnicodeString, OneOf, Int, Number, Regex, Email, Bool, StringBoolean, Set,
35 35 NotEmpty, IPAddress, CIDR, String, FancyValidator
36 36 )
37 37
38 38 from sqlalchemy.sql.expression import true
39 39 from sqlalchemy.util import OrderedSet
40 40
41 41 from rhodecode.authentication import (
42 42 legacy_plugin_prefix, _import_legacy_plugin)
43 43 from rhodecode.authentication.base import loadplugin
44 44 from rhodecode.apps._base import ADMIN_PREFIX
45 45 from rhodecode.lib.auth import HasRepoGroupPermissionAny, HasPermissionAny
46 46 from rhodecode.lib.utils import repo_name_slug, make_db_config
47 47 from rhodecode.lib.utils2 import safe_int, str2bool, aslist, md5, safe_unicode
48 48 from rhodecode.lib.vcs.backends.git.repository import GitRepository
49 49 from rhodecode.lib.vcs.backends.hg.repository import MercurialRepository
50 50 from rhodecode.lib.vcs.backends.svn.repository import SubversionRepository
51 51 from rhodecode.model.db import (
52 52 RepoGroup, Repository, UserGroup, User, ChangesetStatus, Gist)
53 53 from rhodecode.model.settings import VcsSettingsModel
54 54
55 55 # silence warnings and pylint
56 56 UnicodeString, OneOf, Int, Number, Regex, Email, Bool, StringBoolean, Set, \
57 57 NotEmpty, IPAddress, CIDR, String, FancyValidator
58 58
59 59 log = logging.getLogger(__name__)
60 60
61 61
62 62 class _Missing(object):
63 63 pass
64 64
65 65
66 66 Missing = _Missing()
67 67
68 68
69 69 def M(self, key, state, **kwargs):
70 70 """
71 71 returns string from self.message based on given key,
72 72 passed kw params are used to substitute %(named)s params inside
73 73 translated strings
74 74
75 75 :param key:
76 76 :param state:
77 77 """
78 78
79 79 #state._ = staticmethod(_)
80 80 # inject validator into state object
81 81 return self.message(key, state, **kwargs)
82 82
83 83
84 84 def UniqueList(localizer, convert=None):
85 85 _ = localizer
86 86
87 87 class _validator(formencode.FancyValidator):
88 88 """
89 89 Unique, order-preserving list validator
90 90 """
91 91 messages = {
92 'empty': _(u'Value cannot be an empty list'),
93 'missing_value': _(u'Value cannot be an empty list'),
92 'empty': _('Value cannot be an empty list'),
93 'missing_value': _('Value cannot be an empty list'),
94 94 }
95 95
96 96 def _to_python(self, value, state):
97 97 ret_val = []
98 98
99 99 def make_unique(value):
100 100 seen = []
101 101 return [c for c in value if not (c in seen or seen.append(c))]
102 102
103 103 if isinstance(value, list):
104 104 ret_val = make_unique(value)
105 105 elif isinstance(value, set):
106 106 ret_val = make_unique(list(value))
107 107 elif isinstance(value, tuple):
108 108 ret_val = make_unique(list(value))
109 109 elif value is None:
110 110 ret_val = []
111 111 else:
112 112 ret_val = [value]
113 113
114 114 if convert:
115 115 ret_val = list(map(convert, ret_val))  # py3: map() is lazy
116 116 return ret_val
117 117
118 118 def empty_value(self, value):
119 119 return []
120 120 return _validator
121 121
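# Behaviour sketch (order-preserving dedup; assumes formencode's usual
# to_python() -> _to_python()/empty_value() dispatch):
#
#   v = UniqueList(localizer)()
#   v.to_python(['a', 'b', 'a'])   # -> ['a', 'b']
#   v.to_python(None)              # -> []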
122 122
123 123 def UniqueListFromString(localizer):
124 124 _ = localizer
125 125
126 126 class _validator(UniqueList(localizer)):
127 127 def _to_python(self, value, state):
128 128 if isinstance(value, str):
129 129 value = aslist(value, ',')
130 130 return super(_validator, self)._to_python(value, state)
131 131 return _validator
132 132
133 133
134 134 def ValidSvnPattern(localizer, section, repo_name=None):
135 135 _ = localizer
136 136
137 137 class _validator(formencode.validators.FancyValidator):
138 138 messages = {
139 'pattern_exists': _(u'Pattern already exists'),
139 'pattern_exists': _('Pattern already exists'),
140 140 }
141 141
142 142 def validate_python(self, value, state):
143 143 if not value:
144 144 return
145 145 model = VcsSettingsModel(repo=repo_name)
146 146 ui_settings = model.get_svn_patterns(section=section)
147 147 for entry in ui_settings:
148 148 if value == entry.value:
149 149 msg = M(self, 'pattern_exists', state)
150 150 raise formencode.Invalid(msg, value, state)
151 151 return _validator
152 152
153 153
154 154 def ValidUsername(localizer, edit=False, old_data=None):
155 155 _ = localizer
156 156 old_data = old_data or {}
157 157
158 158 class _validator(formencode.validators.FancyValidator):
159 159 messages = {
160 'username_exists': _(u'Username "%(username)s" already exists'),
160 'username_exists': _('Username "%(username)s" already exists'),
161 161 'system_invalid_username':
162 _(u'Username "%(username)s" is forbidden'),
162 _('Username "%(username)s" is forbidden'),
163 163 'invalid_username':
164 _(u'Username may only contain alphanumeric characters '
165 u'underscores, periods or dashes and must begin with '
166 u'alphanumeric character or underscore')
164 _('Username may only contain alphanumeric characters, '
165 'underscores, periods or dashes and must begin with an '
166 'alphanumeric character or underscore')
167 167 }
168 168
169 169 def validate_python(self, value, state):
170 170 if value in ['default', 'new_user']:
171 171 msg = M(self, 'system_invalid_username', state, username=value)
172 172 raise formencode.Invalid(msg, value, state)
173 173 # check if user is unique
174 174 old_un = None
175 175 if edit:
176 176 old_un = User.get(old_data.get('user_id')).username
177 177
178 178 if old_un != value or not edit:
179 179 if User.get_by_username(value, case_insensitive=True):
180 180 msg = M(self, 'username_exists', state, username=value)
181 181 raise formencode.Invalid(msg, value, state)
182 182
183 183 if (re.match(r'^[\w]{1}[\w\-\.]{0,254}$', value)
184 184 is None):
185 185 msg = M(self, 'invalid_username', state)
186 186 raise formencode.Invalid(msg, value, state)
187 187 return _validator
188 188
189 189
190 190 def ValidRepoUser(localizer, allow_disabled=False):
191 191 _ = localizer
192 192
193 193 class _validator(formencode.validators.FancyValidator):
194 194 messages = {
195 'invalid_username': _(u'Username %(username)s is not valid'),
196 'disabled_username': _(u'Username %(username)s is disabled')
195 'invalid_username': _('Username %(username)s is not valid'),
196 'disabled_username': _('Username %(username)s is disabled')
197 197 }
198 198
199 199 def validate_python(self, value, state):
200 200 try:
201 201 user = User.query().filter(User.username == value).one()
202 202 except Exception:
203 203 msg = M(self, 'invalid_username', state, username=value)
204 204 raise formencode.Invalid(
205 205 msg, value, state, error_dict={'username': msg}
206 206 )
207 207 if user and (not allow_disabled and not user.active):
208 208 msg = M(self, 'disabled_username', state, username=value)
209 209 raise formencode.Invalid(
210 210 msg, value, state, error_dict={'username': msg}
211 211 )
212 212 return _validator
213 213
214 214
215 215 def ValidUserGroup(localizer, edit=False, old_data=None):
216 216 _ = localizer
217 217 old_data = old_data or {}
218 218
219 219 class _validator(formencode.validators.FancyValidator):
220 220 messages = {
221 'invalid_group': _(u'Invalid user group name'),
222 'group_exist': _(u'User group `%(usergroup)s` already exists'),
221 'invalid_group': _('Invalid user group name'),
222 'group_exist': _('User group `%(usergroup)s` already exists'),
223 223 'invalid_usergroup_name':
224 _(u'user group name may only contain alphanumeric '
225 u'characters underscores, periods or dashes and must begin '
226 u'with alphanumeric character')
224 _('user group name may only contain alphanumeric '
225 'characters, underscores, periods or dashes and must begin '
226 'with an alphanumeric character')
227 227 }
228 228
229 229 def validate_python(self, value, state):
230 230 if value in ['default']:
231 231 msg = M(self, 'invalid_group', state)
232 232 raise formencode.Invalid(
233 233 msg, value, state, error_dict={'users_group_name': msg}
234 234 )
235 235 # check if group is unique
236 236 old_ugname = None
237 237 if edit:
238 238 old_id = old_data.get('users_group_id')
239 239 old_ugname = UserGroup.get(old_id).users_group_name
240 240
241 241 if old_ugname != value or not edit:
242 242 is_existing_group = UserGroup.get_by_group_name(
243 243 value, case_insensitive=True)
244 244 if is_existing_group:
245 245 msg = M(self, 'group_exist', state, usergroup=value)
246 246 raise formencode.Invalid(
247 247 msg, value, state, error_dict={'users_group_name': msg}
248 248 )
249 249
250 250 if re.match(r'^[a-zA-Z0-9]{1}[a-zA-Z0-9\-\_\.]+$', value) is None:
251 251 msg = M(self, 'invalid_usergroup_name', state)
252 252 raise formencode.Invalid(
253 253 msg, value, state, error_dict={'users_group_name': msg}
254 254 )
255 255 return _validator
256 256
257 257
258 258 def ValidRepoGroup(localizer, edit=False, old_data=None, can_create_in_root=False):
259 259 _ = localizer
260 260 old_data = old_data or {}
261 261
262 262 class _validator(formencode.validators.FancyValidator):
263 263 messages = {
264 'group_parent_id': _(u'Cannot assign this group as parent'),
265 'group_exists': _(u'Group "%(group_name)s" already exists'),
266 'repo_exists': _(u'Repository with name "%(group_name)s" '
267 u'already exists'),
268 'permission_denied': _(u"no permission to store repository group"
269 u"in this location"),
264 'group_parent_id': _('Cannot assign this group as parent'),
265 'group_exists': _('Group "%(group_name)s" already exists'),
266 'repo_exists': _('Repository with name "%(group_name)s" '
267 'already exists'),
268 'permission_denied': _("no permission to store repository group "
269 "in this location"),
270 270 'permission_denied_root': _(
271 u"no permission to store repository group "
272 u"in root location")
271 "no permission to store repository group "
272 "in root location")
273 273 }
274 274
275 275 def _to_python(self, value, state):
276 276 group_name = repo_name_slug(value.get('group_name', ''))
277 277 group_parent_id = safe_int(value.get('group_parent_id'))
278 278 gr = RepoGroup.get(group_parent_id)
279 279 if gr:
280 280 parent_group_path = gr.full_path
281 281 # value needs to be aware of the full group name in order to
282 282 # check the db key; this is just the actual name to store in
283 283 # the database
284 284 group_name_full = (
285 285 parent_group_path + RepoGroup.url_sep() + group_name)
286 286 else:
287 287 group_name_full = group_name
288 288
289 289 value['group_name'] = group_name
290 290 value['group_name_full'] = group_name_full
291 291 value['group_parent_id'] = group_parent_id
292 292 return value
293 293
294 294 def validate_python(self, value, state):
295 295
296 296 old_group_name = None
297 297 group_name = value.get('group_name')
298 298 group_name_full = value.get('group_name_full')
299 299 group_parent_id = safe_int(value.get('group_parent_id'))
300 300 if group_parent_id == -1:
301 301 group_parent_id = None
302 302
303 303 group_obj = RepoGroup.get(old_data.get('group_id'))
304 304 parent_group_changed = False
305 305 if edit:
306 306 old_group_name = group_obj.group_name
307 307 old_group_parent_id = group_obj.group_parent_id
308 308
309 309 if group_parent_id != old_group_parent_id:
310 310 parent_group_changed = True
311 311
312 312 # TODO: mikhail: the following if statement is not reached
313 313 # since group_parent_id's OneOf validation fails before.
314 314 # Can be removed.
315 315
316 316 # check against setting a parent of self
317 317 parent_of_self = (
318 318 old_data['group_id'] == group_parent_id
319 319 if group_parent_id else False
320 320 )
321 321 if parent_of_self:
322 322 msg = M(self, 'group_parent_id', state)
323 323 raise formencode.Invalid(
324 324 msg, value, state, error_dict={'group_parent_id': msg}
325 325 )
326 326
327 327 # group we're moving current group inside
328 328 child_group = None
329 329 if group_parent_id:
330 330 child_group = RepoGroup.query().filter(
331 331 RepoGroup.group_id == group_parent_id).scalar()
332 332
333 333 # do a special check that we cannot move a group to one of
334 334 # its children
335 335 if edit and child_group:
336 336 parents = [x.group_id for x in child_group.parents]
337 337 move_to_children = old_data['group_id'] in parents
338 338 if move_to_children:
339 339 msg = M(self, 'group_parent_id', state)
340 340 raise formencode.Invalid(
341 341 msg, value, state, error_dict={'group_parent_id': msg})
342 342
343 343 # Check if we have permission to store in the parent.
344 344 # Only check if the parent group changed.
345 345 if parent_group_changed:
346 346 if child_group is None:
347 347 if not can_create_in_root:
348 348 msg = M(self, 'permission_denied_root', state)
349 349 raise formencode.Invalid(
350 350 msg, value, state,
351 351 error_dict={'group_parent_id': msg})
352 352 else:
353 353 valid = HasRepoGroupPermissionAny('group.admin')
354 354 forbidden = not valid(
355 355 child_group.group_name, 'can create group validator')
356 356 if forbidden:
357 357 msg = M(self, 'permission_denied', state)
358 358 raise formencode.Invalid(
359 359 msg, value, state,
360 360 error_dict={'group_parent_id': msg})
361 361
362 362 # if we changed the name or it's a new group, check for existing names
363 363 # or repositories with the same name
364 364 if old_group_name != group_name_full or not edit:
365 365 # check group
366 366 gr = RepoGroup.get_by_group_name(group_name_full)
367 367 if gr:
368 368 msg = M(self, 'group_exists', state, group_name=group_name)
369 369 raise formencode.Invalid(
370 370 msg, value, state, error_dict={'group_name': msg})
371 371
372 372 # check for same repo
373 373 repo = Repository.get_by_repo_name(group_name_full)
374 374 if repo:
375 375 msg = M(self, 'repo_exists', state, group_name=group_name)
376 376 raise formencode.Invalid(
377 377 msg, value, state, error_dict={'group_name': msg})
378 378 return _validator
379 379
380 380
381 381 def ValidPassword(localizer):
382 382 _ = localizer
383 383
384 384 class _validator(formencode.validators.FancyValidator):
385 385 messages = {
386 386 'invalid_password':
387 _(u'Invalid characters (non-ascii) in password')
387 _('Invalid characters (non-ascii) in password')
388 388 }
389 389
390 390 def validate_python(self, value, state):
391 391 try:
392 392 (value or '').decode('ascii')
393 393 except UnicodeError:
394 394 msg = M(self, 'invalid_password', state)
395 395 raise formencode.Invalid(msg, value, state,)
396 396 return _validator
397 397
398 398
399 399 def ValidPasswordsMatch(
400 400 localizer, passwd='new_password',
401 401 passwd_confirmation='password_confirmation'):
402 402 _ = localizer
403 403
404 404 class _validator(formencode.validators.FancyValidator):
405 405 messages = {
406 'password_mismatch': _(u'Passwords do not match'),
406 'password_mismatch': _('Passwords do not match'),
407 407 }
408 408
409 409 def validate_python(self, value, state):
410 410
411 411 pass_val = value.get('password') or value.get(passwd)
412 412 if pass_val != value[passwd_confirmation]:
413 413 msg = M(self, 'password_mismatch', state)
414 414 raise formencode.Invalid(
415 415 msg, value, state,
416 416 error_dict={passwd: msg, passwd_confirmation: msg}
417 417 )
418 418 return _validator
419 419
420 420
421 421 def ValidAuth(localizer):
422 422 _ = localizer
423 423
424 424 class _validator(formencode.validators.FancyValidator):
425 425 messages = {
426 'invalid_password': _(u'invalid password'),
427 'invalid_username': _(u'invalid user name'),
428 'disabled_account': _(u'Your account is disabled')
426 'invalid_password': _('invalid password'),
427 'invalid_username': _('invalid user name'),
428 'disabled_account': _('Your account is disabled')
429 429 }
430 430
431 431 def validate_python(self, value, state):
432 432 from rhodecode.authentication.base import authenticate, HTTP_TYPE
433 433
434 434 password = value['password']
435 435 username = value['username']
436 436
437 437 if not authenticate(username, password, '', HTTP_TYPE,
438 438 skip_missing=True):
439 439 user = User.get_by_username(username)
440 440 if user and not user.active:
441 441 log.warning('user %s is disabled', username)
442 442 msg = M(self, 'disabled_account', state)
443 443 raise formencode.Invalid(
444 444 msg, value, state, error_dict={'username': msg}
445 445 )
446 446 else:
447 447 log.warning('user `%s` failed to authenticate', username)
448 448 msg = M(self, 'invalid_username', state)
449 449 msg2 = M(self, 'invalid_password', state)
450 450 raise formencode.Invalid(
451 451 msg, value, state,
452 452 error_dict={'username': msg, 'password': msg2}
453 453 )
454 454 return _validator
455 455
456 456
457 457 def ValidRepoName(localizer, edit=False, old_data=None):
458 458 old_data = old_data or {}
459 459 _ = localizer
460 460
461 461 class _validator(formencode.validators.FancyValidator):
462 462 messages = {
463 463 'invalid_repo_name':
464 _(u'Repository name %(repo)s is disallowed'),
464 _('Repository name %(repo)s is disallowed'),
465 465 # top level
466 'repository_exists': _(u'Repository with name %(repo)s '
467 u'already exists'),
468 'group_exists': _(u'Repository group with name "%(repo)s" '
469 u'already exists'),
466 'repository_exists': _('Repository with name %(repo)s '
467 'already exists'),
468 'group_exists': _('Repository group with name "%(repo)s" '
469 'already exists'),
470 470 # inside a group
471 'repository_in_group_exists': _(u'Repository with name %(repo)s '
472 u'exists in group "%(group)s"'),
471 'repository_in_group_exists': _('Repository with name %(repo)s '
472 'exists in group "%(group)s"'),
473 473 'group_in_group_exists': _(
474 u'Repository group with name "%(repo)s" '
475 u'exists in group "%(group)s"'),
474 'Repository group with name "%(repo)s" '
475 'exists in group "%(group)s"'),
476 476 }
477 477
478 478 def _to_python(self, value, state):
479 479 repo_name = repo_name_slug(value.get('repo_name', ''))
480 480 repo_group = value.get('repo_group')
481 481 if repo_group:
482 482 gr = RepoGroup.get(repo_group)
483 483 group_path = gr.full_path
484 484 group_name = gr.group_name
485 485 # value needs to be aware of the full group name in order to
486 486 # check the db key; this is just the actual name to store in
487 487 # the database
488 488 repo_name_full = group_path + RepoGroup.url_sep() + repo_name
489 489 else:
490 490 group_name = group_path = ''
491 491 repo_name_full = repo_name
492 492
493 493 value['repo_name'] = repo_name
494 494 value['repo_name_full'] = repo_name_full
495 495 value['group_path'] = group_path
496 496 value['group_name'] = group_name
497 497 return value
498 498
499 499 def validate_python(self, value, state):
500 500
501 501 repo_name = value.get('repo_name')
502 502 repo_name_full = value.get('repo_name_full')
503 503 group_path = value.get('group_path')
504 504 group_name = value.get('group_name')
505 505
506 506 if repo_name in [ADMIN_PREFIX, '']:
507 507 msg = M(self, 'invalid_repo_name', state, repo=repo_name)
508 508 raise formencode.Invalid(
509 509 msg, value, state, error_dict={'repo_name': msg})
510 510
511 511 rename = old_data.get('repo_name') != repo_name_full
512 512 create = not edit
513 513 if rename or create:
514 514
515 515 if group_path:
516 516 if Repository.get_by_repo_name(repo_name_full):
517 517 msg = M(self, 'repository_in_group_exists', state,
518 518 repo=repo_name, group=group_name)
519 519 raise formencode.Invalid(
520 520 msg, value, state, error_dict={'repo_name': msg})
521 521 if RepoGroup.get_by_group_name(repo_name_full):
522 522 msg = M(self, 'group_in_group_exists', state,
523 523 repo=repo_name, group=group_name)
524 524 raise formencode.Invalid(
525 525 msg, value, state, error_dict={'repo_name': msg})
526 526 else:
527 527 if RepoGroup.get_by_group_name(repo_name_full):
528 528 msg = M(self, 'group_exists', state, repo=repo_name)
529 529 raise formencode.Invalid(
530 530 msg, value, state, error_dict={'repo_name': msg})
531 531
532 532 if Repository.get_by_repo_name(repo_name_full):
533 533 msg = M(
534 534 self, 'repository_exists', state, repo=repo_name)
535 535 raise formencode.Invalid(
536 536 msg, value, state, error_dict={'repo_name': msg})
537 537 return value
538 538 return _validator
539 539
540 540
541 541 def ValidForkName(localizer, *args, **kwargs):
542 542 _ = localizer
543 543
544 544 return ValidRepoName(localizer, *args, **kwargs)
545 545
546 546
547 547 def SlugifyName(localizer):
548 548 _ = localizer
549 549
550 550 class _validator(formencode.validators.FancyValidator):
551 551
552 552 def _to_python(self, value, state):
553 553 return repo_name_slug(value)
554 554
555 555 def validate_python(self, value, state):
556 556 pass
557 557 return _validator
558 558
559 559
560 560 def CannotHaveGitSuffix(localizer):
561 561 _ = localizer
562 562
563 563 class _validator(formencode.validators.FancyValidator):
564 564 messages = {
565 565 'has_git_suffix':
566 _(u'Repository name cannot end with .git'),
566 _('Repository name cannot end with .git'),
567 567 }
568 568
569 569 def _to_python(self, value, state):
570 570 return value
571 571
572 572 def validate_python(self, value, state):
573 573 if value and value.endswith('.git'):
574 574 msg = M(
575 575 self, 'has_git_suffix', state)
576 576 raise formencode.Invalid(
577 577 msg, value, state, error_dict={'repo_name': msg})
578 578 return _validator
579 579
580 580
581 581 def ValidCloneUri(localizer):
582 582 _ = localizer
583 583
584 584 class InvalidCloneUrl(Exception):
585 585 allowed_prefixes = ()
586 586
587 587 def url_handler(repo_type, url):
588 588 config = make_db_config(clear_session=False)
589 589 if repo_type == 'hg':
590 590 allowed_prefixes = ('http', 'svn+http', 'git+http')
591 591
592 592 if 'http' in url[:4]:
593 593 # initially check that it's at least a proper URL
594 594 # and that it passes basic auth
595 595 MercurialRepository.check_url(url, config)
596 596 elif 'svn+http' in url[:8]: # svn->hg import
597 597 SubversionRepository.check_url(url, config)
598 598 elif 'git+http' in url[:8]: # git->hg import
599 599 raise NotImplementedError()
600 600 else:
601 601 exc = InvalidCloneUrl('Clone from URI %s not allowed. '
602 602 'Allowed url must start with one of %s'
603 603 % (url, ','.join(allowed_prefixes)))
604 604 exc.allowed_prefixes = allowed_prefixes
605 605 raise exc
606 606
607 607 elif repo_type == 'git':
608 608 allowed_prefixes = ('http', 'svn+http', 'hg+http')
609 609 if 'http' in url[:4]:
610 610 # initially check that it's at least a proper URL
611 611 # and that it passes basic auth
612 612 GitRepository.check_url(url, config)
613 613 elif 'svn+http' in url[:8]: # svn->git import
614 614 raise NotImplementedError()
615 615 elif 'hg+http' in url[:8]: # hg->git import
616 616 raise NotImplementedError()
617 617 else:
618 618 exc = InvalidCloneUrl('Clone from URI %s not allowed. '
619 619 'Allowed url must start with one of %s'
620 620 % (url, ','.join(allowed_prefixes)))
621 621 exc.allowed_prefixes = allowed_prefixes
622 622 raise exc
623 623
624 624 class _validator(formencode.validators.FancyValidator):
625 625 messages = {
626 'clone_uri': _(u'invalid clone url or credentials for %(rtype)s repository'),
626 'clone_uri': _('invalid clone url or credentials for %(rtype)s repository'),
627 627 'invalid_clone_uri': _(
628 u'Invalid clone url, provide a valid clone '
629 u'url starting with one of %(allowed_prefixes)s')
628 'Invalid clone url, provide a valid clone '
629 'url starting with one of %(allowed_prefixes)s')
630 630 }
631 631
632 632 def validate_python(self, value, state):
633 633 repo_type = value.get('repo_type')
634 634 url = value.get('clone_uri')
635 635
636 636 if url:
637 637 try:
638 638 url_handler(repo_type, url)
639 639 except InvalidCloneUrl as e:
640 640 log.warning(e)
641 641 msg = M(self, 'invalid_clone_uri', state, rtype=repo_type,
642 642 allowed_prefixes=','.join(e.allowed_prefixes))
643 643 raise formencode.Invalid(msg, value, state,
644 644 error_dict={'clone_uri': msg})
645 645 except Exception:
646 646 log.exception('Url validation failed')
647 647 msg = M(self, 'clone_uri', state, rtype=repo_type)
648 648 raise formencode.Invalid(msg, value, state,
649 649 error_dict={'clone_uri': msg})
650 650 return _validator
651 651
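# Scheme matrix enforced by url_handler() above:
#
#   hg  target: http(s), svn+http (svn->hg import), git+http (not implemented)
#   git target: http(s); svn+http and hg+http are declared but not implemented
#
# Any other prefix raises InvalidCloneUrl carrying the allowed prefixes,
# which the validator turns into the 'invalid_clone_uri' form error.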
652 652
653 653 def ValidForkType(localizer, old_data=None):
654 654 _ = localizer
655 655 old_data = old_data or {}
656 656
657 657 class _validator(formencode.validators.FancyValidator):
658 658 messages = {
659 'invalid_fork_type': _(u'Fork have to be the same type as parent')
659 'invalid_fork_type': _('Fork has to be the same type as its parent')
660 660 }
661 661
662 662 def validate_python(self, value, state):
663 663 if old_data['repo_type'] != value:
664 664 msg = M(self, 'invalid_fork_type', state)
665 665 raise formencode.Invalid(
666 666 msg, value, state, error_dict={'repo_type': msg}
667 667 )
668 668 return _validator
669 669
670 670
671 671 def CanWriteGroup(localizer, old_data=None):
672 672 _ = localizer
673 673
674 674 class _validator(formencode.validators.FancyValidator):
675 675 messages = {
676 676 'permission_denied': _(
677 u"You do not have the permission "
678 u"to create repositories in this group."),
677 "You do not have the permission "
678 "to create repositories in this group."),
679 679 'permission_denied_root': _(
680 u"You do not have the permission to store repositories in "
681 u"the root location.")
680 "You do not have the permission to store repositories in "
681 "the root location.")
682 682 }
683 683
684 684 def _to_python(self, value, state):
685 685 # root location
686 686 if value in [-1, "-1"]:
687 687 return None
688 688 return value
689 689
690 690 def validate_python(self, value, state):
691 691 gr = RepoGroup.get(value)
692 692 gr_name = gr.group_name if gr else None # None means ROOT location
693 693 # 'create repositories with write permission on group' is set to true
694 694 create_on_write = HasPermissionAny(
695 695 'hg.create.write_on_repogroup.true')()
696 696 group_admin = HasRepoGroupPermissionAny('group.admin')(
697 697 gr_name, 'can write into group validator')
698 698 group_write = HasRepoGroupPermissionAny('group.write')(
699 699 gr_name, 'can write into group validator')
700 700 forbidden = not (group_admin or (group_write and create_on_write))
701 701 can_create_repos = HasPermissionAny(
702 702 'hg.admin', 'hg.create.repository')
703 703 gid = (old_data['repo_group'].get('group_id')
704 704 if (old_data and 'repo_group' in old_data) else None)
705 705 value_changed = gid != safe_int(value)
706 706 new = not old_data
707 707 # only check if we changed the value; there's a case where someone
708 708 # had write permission to a repository revoked after creating it, and
709 709 # we don't need to check permission if the group value in the form
710 710 # box didn't change
711 711 if value_changed or new:
712 712 # parent group need to be existing
713 713 if gr and forbidden:
714 714 msg = M(self, 'permission_denied', state)
715 715 raise formencode.Invalid(
716 716 msg, value, state, error_dict={'repo_type': msg}
717 717 )
718 718 # check if we can write to the root location!
719 719 elif gr is None and not can_create_repos():
720 720 msg = M(self, 'permission_denied_root', state)
721 721 raise formencode.Invalid(
722 722 msg, value, state, error_dict={'repo_type': msg}
723 723 )
724 724 return _validator
725 725
726 726
727 727 def ValidPerms(localizer, type_='repo'):
728 728 _ = localizer
729 729 if type_ == 'repo_group':
730 730 EMPTY_PERM = 'group.none'
731 731 elif type_ == 'repo':
732 732 EMPTY_PERM = 'repository.none'
733 733 elif type_ == 'user_group':
734 734 EMPTY_PERM = 'usergroup.none'
735 735
736 736 class _validator(formencode.validators.FancyValidator):
737 737 messages = {
738 738 'perm_new_member_name':
739 _(u'This username or user group name is not valid')
739 _('This username or user group name is not valid')
740 740 }
741 741
742 742 def _to_python(self, value, state):
743 743 perm_updates = OrderedSet()
744 744 perm_additions = OrderedSet()
745 745 perm_deletions = OrderedSet()
746 746 # build a list of permission to update/delete and new permission
747 747
748 748 # Read the perm_new_member/perm_del_member attributes and group
749 749 # them by their IDs
750 750 new_perms_group = collections.defaultdict(dict)
751 751 del_perms_group = collections.defaultdict(dict)
752 752 for k, v in value.copy().items():
753 753 if k.startswith('perm_del_member'):
755 755 # delete from the original storage so we don't process it later
755 755 del value[k]
756 756 # part is `id`, `type`
757 757 _type, part = k.split('perm_del_member_')
758 758 args = part.split('_')
759 759 if len(args) == 2:
760 760 _key, pos = args
761 761 del_perms_group[pos][_key] = v
762 762 if k.startswith('perm_new_member'):
764 764 # delete from the original storage so we don't process it later
764 764 del value[k]
765 765 # part is `id`, `type`, `perm`
766 766 _type, part = k.split('perm_new_member_')
767 767 args = part.split('_')
768 768 if len(args) == 2:
769 769 _key, pos = args
770 770 new_perms_group[pos][_key] = v
771 771
772 772 # store the deletes
773 773 for k in sorted(del_perms_group.keys()):
774 774 perm_dict = del_perms_group[k]
775 775 del_member = perm_dict.get('id')
776 776 del_type = perm_dict.get('type')
777 777 if del_member and del_type:
778 778 perm_deletions.add(
779 779 (del_member, None, del_type))
780 780
781 781 # store additions in the order they were added in the web form
782 782 for k in sorted(new_perms_group.keys()):
783 783 perm_dict = new_perms_group[k]
784 784 new_member = perm_dict.get('id')
785 785 new_type = perm_dict.get('type')
786 786 new_perm = perm_dict.get('perm')
787 787 if new_member and new_perm and new_type:
788 788 perm_additions.add(
789 789 (new_member, new_perm, new_type))
790 790
791 791 # get updates of permissions
792 792 # (read the existing radio button states)
793 793 default_user_id = User.get_default_user_id()
794 794
795 795 for k, update_value in value.items():
796 796 if k.startswith('u_perm_') or k.startswith('g_perm_'):
797 797 obj_type = k[0]
798 798 obj_id = k[7:]
799 799 update_type = {'u': 'user',
800 800 'g': 'user_group'}[obj_type]
801 801
802 802 if obj_type == 'u' and safe_int(obj_id) == default_user_id:
803 803 if str2bool(value.get('repo_private')):
804 804 # prevent from updating default user permissions
805 805 # when this repository is marked as private
806 806 update_value = EMPTY_PERM
807 807
808 808 perm_updates.add(
809 809 (obj_id, update_value, update_type))
810 810
811 811 value['perm_additions'] = [] # propagated later
812 812 value['perm_updates'] = list(perm_updates)
813 813 value['perm_deletions'] = list(perm_deletions)
814 814
815 815 updates_map = dict(
816 816 (x[0], (x[1], x[2])) for x in value['perm_updates'])
817 817 # make sure Additions don't override updates.
818 818 for member_id, perm, member_type in list(perm_additions):
819 819 if member_id in updates_map:
820 820 perm = updates_map[member_id][0]
821 821 value['perm_additions'].append((member_id, perm, member_type))
822 822
823 823 # on new entries, validate that the users exist and are active!
824 824 # this leaves feedback in the form
825 825 try:
826 826 if member_type == 'user':
827 827 User.query()\
828 828 .filter(User.active == true())\
829 829 .filter(User.user_id == member_id).one()
830 830 if member_type == 'user_group':
831 831 UserGroup.query()\
832 832 .filter(UserGroup.users_group_active == true())\
833 833 .filter(UserGroup.users_group_id == member_id)\
834 834 .one()
835 835
836 836 except Exception:
837 837 log.exception('Updating permission failed: org_exc:')
838 838 msg = M(self, 'perm_new_member_name', state)
839 839 raise formencode.Invalid(
840 840 msg, value, state, error_dict={
841 841 'perm_new_member_name': msg}
842 842 )
843 843 return value
844 844 return _validator
845 845
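# Sketch of the flat form keys _to_python() parses (hypothetical ids):
#
#   perm_new_member_id_1 = '5'                  # new grant: member id ...
#   perm_new_member_type_1 = 'user'             # ... member type ...
#   perm_new_member_perm_1 = 'repository.read'  # ... and permission
#   perm_del_member_id_2 = '7'                  # revocation, paired with
#   perm_del_member_type_2 = 'user_group'       # its member type
#   u_perm_5 = 'repository.write'               # radio-button update, user 5
#   g_perm_3 = 'repository.read'                # same for user group 3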
846 846
847 847 def ValidPath(localizer):
848 848 _ = localizer
849 849
850 850 class _validator(formencode.validators.FancyValidator):
851 851 messages = {
852 'invalid_path': _(u'This is not a valid path')
852 'invalid_path': _('This is not a valid path')
853 853 }
854 854
855 855 def validate_python(self, value, state):
856 856 if not os.path.isdir(value):
857 857 msg = M(self, 'invalid_path', state)
858 858 raise formencode.Invalid(
859 859 msg, value, state, error_dict={'paths_root_path': msg}
860 860 )
861 861 return _validator
862 862
863 863
864 864 def UniqSystemEmail(localizer, old_data=None):
865 865 _ = localizer
866 866 old_data = old_data or {}
867 867
868 868 class _validator(formencode.validators.FancyValidator):
869 869 messages = {
870 'email_taken': _(u'This e-mail address is already taken')
870 'email_taken': _('This e-mail address is already taken')
871 871 }
872 872
873 873 def _to_python(self, value, state):
874 874 return value.lower()
875 875
876 876 def validate_python(self, value, state):
877 877 if (old_data.get('email') or '').lower() != value:
878 878 user = User.get_by_email(value, case_insensitive=True)
879 879 if user:
880 880 msg = M(self, 'email_taken', state)
881 881 raise formencode.Invalid(
882 882 msg, value, state, error_dict={'email': msg}
883 883 )
884 884 return _validator
885 885
886 886
887 887 def ValidSystemEmail(localizer):
888 888 _ = localizer
889 889
890 890 class _validator(formencode.validators.FancyValidator):
891 891 messages = {
892 'non_existing_email': _(u'e-mail "%(email)s" does not exist.')
892 'non_existing_email': _('e-mail "%(email)s" does not exist.')
893 893 }
894 894
895 895 def _to_python(self, value, state):
896 896 return value.lower()
897 897
898 898 def validate_python(self, value, state):
899 899 user = User.get_by_email(value, case_insensitive=True)
900 900 if user is None:
901 901 msg = M(self, 'non_existing_email', state, email=value)
902 902 raise formencode.Invalid(
903 903 msg, value, state, error_dict={'email': msg}
904 904 )
905 905 return _validator
906 906
907 907
908 908 def NotReviewedRevisions(localizer, repo_id):
909 909 _ = localizer
910 910 class _validator(formencode.validators.FancyValidator):
911 911 messages = {
912 912 'rev_already_reviewed':
913 _(u'Revisions %(revs)s are already part of pull request '
914 u'or have set status'),
913 _('Revisions %(revs)s are already part of a pull request '
914 'or have a status set'),
915 915 }
916 916
917 917 def validate_python(self, value, state):
918 918 # check revisions if they are not reviewed, or a part of another
919 919 # pull request
920 920 statuses = ChangesetStatus.query()\
921 921 .filter(ChangesetStatus.revision.in_(value))\
922 922 .filter(ChangesetStatus.repo_id == repo_id)\
923 923 .all()
924 924
925 925 errors = []
926 926 for status in statuses:
927 927 if status.pull_request_id:
928 928 errors.append(['pull_req', status.revision[:12]])
929 929 elif status.status:
930 930 errors.append(['status', status.revision[:12]])
931 931
932 932 if errors:
933 933 revs = ','.join([x[1] for x in errors])
934 934 msg = M(self, 'rev_already_reviewed', state, revs=revs)
935 935 raise formencode.Invalid(
936 936 msg, value, state, error_dict={'revisions': revs})
937 937
938 938 return _validator
939 939
940 940
941 941 def ValidIp(localizer):
942 942 _ = localizer
943 943
944 944 class _validator(CIDR):
945 945 messages = {
946 'badFormat': _(u'Please enter a valid IPv4 or IpV6 address'),
946 'badFormat': _('Please enter a valid IPv4 or IPv6 address'),
947 947 'illegalBits': _(
948 u'The network size (bits) must be within the range '
949 u'of 0-32 (not %(bits)r)'),
948 'The network size (bits) must be within the range '
949 'of 0-32 (not %(bits)r)'),
950 950 }
951 951
952 952 # we override the default to_python() call
953 953 def to_python(self, value, state):
954 954 v = super(_validator, self).to_python(value, state)
955 955 v = safe_unicode(v.strip())
956 956 net = ipaddress.ip_network(address=v, strict=False)
957 957 return str(net)
958 958
959 959 def validate_python(self, value, state):
960 960 try:
961 961 addr = safe_unicode(value.strip())
962 962 # this raises a ValueError if the address is not IPv4 or IPv6
963 963 ipaddress.ip_network(addr, strict=False)
964 964 except ValueError:
965 965 raise formencode.Invalid(self.message('badFormat', state),
966 966 value, state)
967 967 return _validator
968 968
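# Normalisation sketch: to_python() strips the value and expands a bare
# address to a network, e.g. '192.168.0.1' -> '192.168.0.1/32', while
# '2001:db8::/48' is kept as-is; malformed input fails with 'badFormat'.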
969 969
970 970 def FieldKey(localizer):
971 971 _ = localizer
972 972
973 973 class _validator(formencode.validators.FancyValidator):
974 974 messages = {
975 975 'badFormat': _(
976 u'Key name can only consist of letters, '
977 u'underscore, dash or numbers'),
976 'Key name can only consist of letters, '
977 'underscores, dashes or numbers'),
978 978 }
979 979
980 980 def validate_python(self, value, state):
981 981 if not re.match('[a-zA-Z0-9_-]+$', value):
982 982 raise formencode.Invalid(self.message('badFormat', state),
983 983 value, state)
984 984 return _validator
985 985
986 986
987 987 def ValidAuthPlugins(localizer):
988 988 _ = localizer
989 989
990 990 class _validator(formencode.validators.FancyValidator):
991 991 messages = {
992 992 'import_duplicate': _(
993 u'Plugins %(loaded)s and %(next_to_load)s '
994 u'both export the same name'),
993 'Plugins %(loaded)s and %(next_to_load)s '
994 'both export the same name'),
995 995 'missing_includeme': _(
996 u'The plugin "%(plugin_id)s" is missing an includeme '
997 u'function.'),
996 'The plugin "%(plugin_id)s" is missing an includeme '
997 'function.'),
998 998 'import_error': _(
999 u'Can not load plugin "%(plugin_id)s"'),
999 'Cannot load plugin "%(plugin_id)s"'),
1000 1000 'no_plugin': _(
1001 u'No plugin available with ID "%(plugin_id)s"'),
1001 'No plugin available with ID "%(plugin_id)s"'),
1002 1002 }
1003 1003
1004 1004 def _to_python(self, value, state):
1005 1005 # filter empty values
1006 1006 return list(filter(lambda s: s not in [None, ''], value))  # py3: filter() is lazy
1007 1007
1008 1008 def _validate_legacy_plugin_id(self, plugin_id, value, state):
1009 1009 """
1010 1010 Validates that the plugin import works. It also checks that the
1011 1011 plugin has an includeme attribute.
1012 1012 """
1013 1013 try:
1014 1014 plugin = _import_legacy_plugin(plugin_id)
1015 1015 except Exception as e:
1016 1016 log.exception(
1017 1017 'Exception during import of auth legacy plugin "{}"'
1018 1018 .format(plugin_id))
1019 1019 msg = M(self, 'import_error', state, plugin_id=plugin_id)
1020 1020 raise formencode.Invalid(msg, value, state)
1021 1021
1022 1022 if not hasattr(plugin, 'includeme'):
1023 1023 msg = M(self, 'missing_includeme', state, plugin_id=plugin_id)
1024 1024 raise formencode.Invalid(msg, value, state)
1025 1025
1026 1026 return plugin
1027 1027
1028 1028 def _validate_plugin_id(self, plugin_id, value, state):
1029 1029 """
1030 1030 Plugins are already imported during app start up. Therefore this
1031 1031 validation only retrieves the plugin from the plugin registry and
1032 1032 if it returns something not None everything is OK.
1033 1033 """
1034 1034 plugin = loadplugin(plugin_id)
1035 1035
1036 1036 if plugin is None:
1037 1037 msg = M(self, 'no_plugin', state, plugin_id=plugin_id)
1038 1038 raise formencode.Invalid(msg, value, state)
1039 1039
1040 1040 return plugin
1041 1041
1042 1042 def validate_python(self, value, state):
1043 1043 unique_names = {}
1044 1044 for plugin_id in value:
1045 1045
1046 1046 # Validate legacy or normal plugin.
1047 1047 if plugin_id.startswith(legacy_plugin_prefix):
1048 1048 plugin = self._validate_legacy_plugin_id(
1049 1049 plugin_id, value, state)
1050 1050 else:
1051 1051 plugin = self._validate_plugin_id(plugin_id, value, state)
1052 1052
1053 1053 # Only allow unique plugin names.
1054 1054 if plugin.name in unique_names:
1055 1055 msg = M(self, 'import_duplicate', state,
1056 1056 loaded=unique_names[plugin.name],
1057 1057 next_to_load=plugin)
1058 1058 raise formencode.Invalid(msg, value, state)
1059 1059 unique_names[plugin.name] = plugin
1060 1060 return _validator
1061 1061
1062 1062
1063 1063 def ValidPattern(localizer):
1064 1064 _ = localizer
1065 1065
1066 1066 class _validator(formencode.validators.FancyValidator):
1067 1067 messages = {
1068 'bad_format': _(u'Url must start with http or /'),
1068 'bad_format': _('URL must start with http or /'),
1069 1069 }
1070 1070
1071 1071 def _to_python(self, value, state):
1072 1072 patterns = []
1073 1073
1074 1074 prefix = 'new_pattern'
1075 1075 for name, v in value.items():
1076 1076 pattern_name = '_'.join((prefix, 'pattern'))
1077 1077 if name.startswith(pattern_name):
1078 1078 new_item_id = name[len(pattern_name)+1:]
1079 1079
1080 1080 def _field(name):
1081 1081 return '%s_%s_%s' % (prefix, name, new_item_id)
1082 1082
1083 1083 values = {
1084 1084 'issuetracker_pat': value.get(_field('pattern')),
1085 1085 'issuetracker_url': value.get(_field('url')),
1086 1086 'issuetracker_pref': value.get(_field('prefix')),
1087 1087 'issuetracker_desc': value.get(_field('description'))
1088 1088 }
1089 1089 new_uid = md5(values['issuetracker_pat'])
1090 1090
1091 1091 has_required_fields = (
1092 1092 values['issuetracker_pat']
1093 1093 and values['issuetracker_url'])
1094 1094
1095 1095 if has_required_fields:
1096 1096 # validate that the url starts with http or /,
1097 1097 # otherwise it can lead to JS injection,
1098 1098 # e.g. specifying javascript:<malicious code>
1099 1099 if not values['issuetracker_url'].startswith(('http', '/')):
1100 1100 raise formencode.Invalid(
1101 1101 self.message('bad_format', state),
1102 1102 value, state)
1103 1103
1104 1104 settings = [
1105 1105 ('_'.join((key, new_uid)), values[key], 'unicode')
1106 1106 for key in values]
1107 1107 patterns.append(settings)
1108 1108
1109 1109 value['patterns'] = patterns
1110 1110 delete_patterns = value.get('uid') or []
1111 1111 if not isinstance(delete_patterns, (list, tuple)):
1112 1112 delete_patterns = [delete_patterns]
1113 1113 value['delete_patterns'] = delete_patterns
1114 1114 return value
1115 1115 return _validator
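# Sketch of the issue-tracker form keys parsed above (hypothetical id `1`):
#
#   new_pattern_pattern_1     -> issuetracker_pat (required)
#   new_pattern_url_1         -> issuetracker_url (required; http or /)
#   new_pattern_prefix_1      -> issuetracker_pref
#   new_pattern_description_1 -> issuetracker_desc
#
# Each resulting setting key is suffixed with md5(pattern) as a uid.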
@@ -1,398 +1,398 b''
1 1 # -*- coding: utf-8 -*-
2 2
3 3 # Copyright (C) 2010-2020 RhodeCode GmbH
4 4 #
5 5 # This program is free software: you can redistribute it and/or modify
6 6 # it under the terms of the GNU Affero General Public License, version 3
7 7 # (only), as published by the Free Software Foundation.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU Affero General Public License
15 15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 16 #
17 17 # This program is dual-licensed. If you wish to learn more about the
18 18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20 20 import io
21 21 import shlex
22 22
23 23 import math
24 24 import re
25 25 import os
26 26 import datetime
27 27 import logging
28 28 import queue
29 29 import subprocess
30 30
31 31
32 32 from dateutil.parser import parse
33 33 from pyramid.threadlocal import get_current_request
34 34 from pyramid.interfaces import IRoutesMapper
35 35 from pyramid.settings import asbool
36 36 from pyramid.path import AssetResolver
37 37 from threading import Thread
38 38
39 39 from rhodecode.config.jsroutes import generate_jsroutes_content
40 40 from rhodecode.lib.base import get_auth_user
41 41
42 42 import rhodecode
43 43
44 44
45 45 log = logging.getLogger(__name__)
46 46
47 47
48 48 def add_renderer_globals(event):
49 49 from rhodecode.lib import helpers
50 50
51 51 # TODO: When executed in pyramid view context the request is not available
52 52 # in the event. Find a better solution to get the request.
53 53 request = event['request'] or get_current_request()
54 54
55 55 # Add Pyramid translation as '_' to context
56 56 event['_'] = request.translate
57 57 event['_ungettext'] = request.plularize
58 58 event['h'] = helpers
59 59
60 60
61 61 def set_user_lang(event):
62 62 request = event.request
63 63 cur_user = getattr(request, 'user', None)
64 64
65 65 if cur_user:
66 66 user_lang = cur_user.get_instance().user_data.get('language')
67 67 if user_lang:
68 68 log.debug('lang: setting current user:%s language to: %s', cur_user, user_lang)
69 69 event.request._LOCALE_ = user_lang
70 70
71 71
72 72 def update_celery_conf(event):
73 73 from rhodecode.lib.celerylib.loader import set_celery_conf
74 74 log.debug('Setting celery config from new request')
75 75 set_celery_conf(request=event.request, registry=event.request.registry)
76 76
77 77
78 78 def add_request_user_context(event):
79 79 """
80 80 Adds auth user into request context
81 81 """
82 82 request = event.request
83 83 # access req_id as soon as possible
84 84 req_id = request.req_id
85 85
86 86 if hasattr(request, 'vcs_call'):
87 87 # skip vcs calls
88 88 return
89 89
90 90 if hasattr(request, 'rpc_method'):
91 91 # skip api calls
92 92 return
93 93
94 94 auth_user, auth_token = get_auth_user(request)
95 95 request.user = auth_user
96 96 request.user_auth_token = auth_token
97 97 request.environ['rc_auth_user'] = auth_user
98 98 request.environ['rc_auth_user_id'] = auth_user.user_id
99 99 request.environ['rc_req_id'] = req_id
100 100
101 101
102 102 def reset_log_bucket(event):
103 103 """
104 104 reset the log bucket on new request
105 105 """
106 106 request = event.request
107 107 request.req_id_records_init()
108 108
109 109
110 110 def scan_repositories_if_enabled(event):
111 111 """
112 112 This is subscribed to the `pyramid.events.ApplicationCreated` event. It
113 113 does a repository scan if enabled in the settings.
114 114 """
115 115 settings = event.app.registry.settings
116 116 vcs_server_enabled = settings['vcs.server.enable']
117 117 import_on_startup = settings['startup.import_repos']
118 118 if vcs_server_enabled and import_on_startup:
119 119 from rhodecode.model.scm import ScmModel
120 120 from rhodecode.lib.utils import repo2db_mapper, get_rhodecode_base_path
121 121 repositories = ScmModel().repo_scan(get_rhodecode_base_path())
122 122 repo2db_mapper(repositories, remove_obsolete=False)
123 123
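
The docstring says this is subscribed to `pyramid.events.ApplicationCreated`; for reference, such a subscription looks like the sketch below (the registration code itself is assumed, not shown in this diff):

    from pyramid.events import ApplicationCreated

    def includeme(config):
        # run the one-off repository scan once the WSGI app is fully built
        config.add_subscriber(scan_repositories_if_enabled, ApplicationCreated)
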
124 124
125 125 def write_metadata_if_needed(event):
126 126 """
127 127 Writes upgrade metadata
128 128 """
129 129 import rhodecode
130 130 from rhodecode.lib import system_info
131 131 from rhodecode.lib import ext_json
132 132
133 133 fname = '.rcmetadata.json'
134 134 ini_loc = os.path.dirname(rhodecode.CONFIG.get('__file__'))
135 135 metadata_destination = os.path.join(ini_loc, fname)
136 136
137 137 def get_update_age():
138 138 now = datetime.datetime.utcnow()
139 139
140 140 with open(metadata_destination, 'rb') as f:
141 141 data = ext_json.json.loads(f.read())
142 142 if 'created_on' in data:
143 143 update_date = parse(data['created_on'])
144 144 diff = now - update_date
145 145 return diff.total_seconds() / 60.0
146 146
147 147 return 0
148 148
149 149 def write():
150 150 configuration = system_info.SysInfo(
151 151 system_info.rhodecode_config)()['value']
152 152 license_token = configuration['config']['license_token']
153 153
154 154 setup = dict(
155 155 workers=configuration['config']['server:main'].get(
156 156 'workers', '?'),
157 157 worker_type=configuration['config']['server:main'].get(
158 158 'worker_class', 'sync'),
159 159 )
160 160 dbinfo = system_info.SysInfo(system_info.database_info)()['value']
161 161 del dbinfo['url']
162 162
163 163 metadata = dict(
164 164 desc='upgrade metadata info',
165 165 license_token=license_token,
166 166 created_on=datetime.datetime.utcnow().isoformat(),
167 167 usage=system_info.SysInfo(system_info.usage_info)()['value'],
168 168 platform=system_info.SysInfo(system_info.platform_type)()['value'],
169 169 database=dbinfo,
170 170 cpu=system_info.SysInfo(system_info.cpu)()['value'],
171 171 memory=system_info.SysInfo(system_info.memory)()['value'],
172 172 setup=setup
173 173 )
174 174
175 175 with open(metadata_destination, 'wb') as f:
176 176 f.write(ext_json.json.dumps(metadata))
177 177
178 178 settings = event.app.registry.settings
179 179 if settings.get('metadata.skip'):
180 180 return
181 181
182 182 # only write this every 24h, workers restart caused unwanted delays
183 183 try:
184 184 age_in_min = get_update_age()
185 185 except Exception:
186 186 age_in_min = 0
187 187
188 188 if age_in_min and age_in_min < 60 * 24:
189 189 return
190 190
191 191 try:
192 192 write()
193 193 except Exception:
194 194 pass
195 195
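
Note on units: `get_update_age()` returns minutes, so a 24-hour threshold in minutes is `60 * 24`; `60 * 60 * 24` would be 60 days expressed in minutes. A tiny self-contained check:

    import datetime

    created_on = datetime.datetime.utcnow() - datetime.timedelta(hours=30)
    age_in_min = (datetime.datetime.utcnow() - created_on).total_seconds() / 60.0

    # stale after a day: 60 min * 24 h = 1440 minutes
    assert age_in_min > 60 * 24             # 30h old -> should be rewritten
    assert not (age_in_min > 60 * 60 * 24)  # that constant equals 60 days
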
196 196
197 197 def write_usage_data(event):
198 198 import rhodecode
199 199 from rhodecode.lib import system_info
200 200 from rhodecode.lib import ext_json
201 201
202 202 settings = event.app.registry.settings
203 203 instance_tag = settings.get('metadata.write_usage_tag')
204 204 if not settings.get('metadata.write_usage'):
205 205 return
206 206
207 207 def get_update_age(dest_file):
208 208 now = datetime.datetime.utcnow()
209 209
210 210 with open(dest_file, 'rb') as f:
211 211 data = ext_json.json.loads(f.read())
212 212 if 'created_on' in data:
213 213 update_date = parse(data['created_on'])
214 214 diff = now - update_date
215 215 return math.ceil(diff.total_seconds() / 60.0)
216 216
217 217 return 0
218 218
219 219 utc_date = datetime.datetime.utcnow()
220 220 hour_quarter = int(math.ceil((utc_date.hour + utc_date.minute/60.0) / 6.))
221 221 fname = '.rc_usage_{date.year}{date.month:02d}{date.day:02d}_{hour}.json'.format(
222 222 date=utc_date, hour=hour_quarter)
223 223 ini_loc = os.path.dirname(rhodecode.CONFIG.get('__file__'))
224 224
225 225 usage_dir = os.path.join(ini_loc, '.rcusage')
226 226 if not os.path.isdir(usage_dir):
227 227 os.makedirs(usage_dir)
228 228 usage_metadata_destination = os.path.join(usage_dir, fname)
229 229
230 230 try:
231 231 age_in_min = get_update_age(usage_metadata_destination)
232 232 except Exception:
233 233 age_in_min = 0
234 234
235 235 # write every 6th hour
236 236 if age_in_min and age_in_min < 60 * 6:
237 237 log.debug('Usage file created %s minutes ago, skipping (threshold: %s minutes)...',
238 238 age_in_min, 60 * 6)
239 239 return
240 240
241 241 def write(dest_file):
242 242 configuration = system_info.SysInfo(system_info.rhodecode_config)()['value']
243 243 license_token = configuration['config']['license_token']
244 244
245 245 metadata = dict(
246 246 desc='Usage data',
247 247 instance_tag=instance_tag,
248 248 license_token=license_token,
249 249 created_on=datetime.datetime.utcnow().isoformat(),
250 250 usage=system_info.SysInfo(system_info.usage_info)()['value'],
251 251 )
252 252
253 253 with open(dest_file, 'wb') as f:
254 254 f.write(ext_json.json.dumps(metadata, indent=2, sort_keys=True))
255 255
256 256 try:
257 257 log.debug('Writing usage file at: %s', usage_metadata_destination)
258 258 write(usage_metadata_destination)
259 259 except Exception:
260 260 pass
261 261
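
For reference, the `hour_quarter` bucket above splits a UTC day into 6-hour slots, so only a handful of usage files are written per day:

    import datetime
    import math

    utc_date = datetime.datetime(2020, 5, 17, 13, 45)
    hour_quarter = int(math.ceil((utc_date.hour + utc_date.minute / 60.0) / 6.))
    fname = '.rc_usage_{date.year}{date.month:02d}{date.day:02d}_{hour}.json'.format(
        date=utc_date, hour=hour_quarter)

    assert hour_quarter == 3                       # 13:45 falls in the third slot
    assert fname == '.rc_usage_20200517_3.json'
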
262 262
263 263 def write_js_routes_if_enabled(event):
264 264 registry = event.app.registry
265 265
266 266 mapper = registry.queryUtility(IRoutesMapper)
267 _argument_prog = re.compile('\{(.*?)\}|:\((.*)\)')
267 _argument_prog = re.compile(r'\{(.*?)\}|:\((.*)\)')
268 268
269 269 def _extract_route_information(route):
270 270 """
271 271 Convert a route into tuple(name, path, args), eg:
272 272 ('show_user', '/profile/%(username)s', ['username'])
273 273 """
274 274
275 275 routepath = route.pattern
276 276 pattern = route.pattern
277 277
278 278 def replace(matchobj):
279 279 if matchobj.group(1):
280 280 return "%%(%s)s" % matchobj.group(1).split(':')[0]
281 281 else:
282 282 return "%%(%s)s" % matchobj.group(2)
283 283
284 284 routepath = _argument_prog.sub(replace, routepath)
285 285
286 286 if not routepath.startswith('/'):
287 287 routepath = '/'+routepath
288 288
289 289 return (
290 290 route.name,
291 291 routepath,
292 292 [(arg[0].split(':')[0] if arg[0] != '' else arg[1])
293 293 for arg in _argument_prog.findall(pattern)]
294 294 )
295 295
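
The raw-string fix above does not change matching behavior (it only silences the py3 invalid-escape warning); the pattern rewrites both `{name}` and `{name:regex}` placeholders, e.g.:

    import re

    _argument_prog = re.compile(r'\{(.*?)\}|:\((.*)\)')

    def replace(matchobj):
        if matchobj.group(1):
            return "%%(%s)s" % matchobj.group(1).split(':')[0]
        return "%%(%s)s" % matchobj.group(2)

    assert _argument_prog.sub(replace, '/profile/{username}') == '/profile/%(username)s'
    assert _argument_prog.sub(replace, '/c/{commit_id:.*}') == '/c/%(commit_id)s'
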
296 296 def get_routes():
297 297 # pyramid routes
298 298 for route in mapper.get_routes():
299 299 if not route.name.startswith('__'):
300 300 yield _extract_route_information(route)
301 301
302 302 if asbool(registry.settings.get('generate_js_files', 'false')):
303 303 static_path = AssetResolver().resolve('rhodecode:public').abspath()
304 304 jsroutes = get_routes()
305 305 jsroutes_file_content = generate_jsroutes_content(jsroutes)
306 306 jsroutes_file_path = os.path.join(
307 307 static_path, 'js', 'rhodecode', 'routes.js')
308 308
309 309 try:
310 310 with io.open(jsroutes_file_path, 'w', encoding='utf-8') as f:
311 311 f.write(jsroutes_file_content)
312 312 except Exception:
313 313 log.exception('Failed to write routes.js into %s', jsroutes_file_path)
314 314
315 315
316 316 class Subscriber(object):
317 317 """
318 318 Base class for subscribers to the pyramid event system.
319 319 """
320 320 def __call__(self, event):
321 321 self.run(event)
322 322
323 323 def run(self, event):
324 324 raise NotImplementedError('Subclass has to implement this.')
325 325
326 326
327 327 class AsyncSubscriber(Subscriber):
328 328 """
329 329 Subscriber that handles the execution of events in a separate task to not
330 330 block the execution of the code which triggers the event. It puts the
331 331 received events into a queue from which the worker process takes them in
332 332 order.
333 333 """
334 334 def __init__(self):
335 335 self._stop = False
336 336 self._eventq = queue.Queue()
337 337 self._worker = self.create_worker()
338 338 self._worker.start()
339 339
340 340 def __call__(self, event):
341 341 self._eventq.put(event)
342 342
343 343 def create_worker(self):
344 344 worker = Thread(target=self.do_work)
345 345 worker.daemon = True
346 346 return worker
347 347
348 348 def stop_worker(self):
349 349 self._stop = True
350 350 self._eventq.put(None)
351 351 self._worker.join()
352 352
353 353 def do_work(self):
354 354 while not self._stop:
355 355 event = self._eventq.get()
356 356 if event is not None:
357 357 self.run(event)
358 358
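
A minimal concrete subscriber, to illustrate the contract (`__call__` only enqueues; `run()` executes on the worker thread, in FIFO order):

    class LoggingSubscriber(AsyncSubscriber):
        """Hypothetical example subscriber; it only logs the received event."""

        def run(self, event):
            # executed by the background worker thread
            log.debug('processed event %r', event)
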
359 359
360 360 class AsyncSubprocessSubscriber(AsyncSubscriber):
361 361 """
362 362 Subscriber that uses the subprocess module to execute a command if an
363 363 event is received. Events are handled asynchronously::
364 364
365 365 subscriber = AsyncSubprocessSubscriber('ls -la', timeout=10)
366 366 subscriber(dummyEvent) # running __call__(event)
367 367
368 368 """
369 369
370 370 def __init__(self, cmd, timeout=None):
371 371 if not isinstance(cmd, (list, tuple)):
372 372 cmd = shlex.split(cmd)
373 373 super(AsyncSubprocessSubscriber, self).__init__()
374 374 self._cmd = cmd
375 375 self._timeout = timeout
376 376
377 377 def run(self, event):
378 378 cmd = self._cmd
379 379 timeout = self._timeout
380 380 log.debug('Executing command %s.', cmd)
381 381
382 382 try:
383 383 output = subprocess.check_output(
384 384 cmd, timeout=timeout, stderr=subprocess.STDOUT)
385 385 log.debug('Command finished %s', cmd)
386 386 if output:
387 387 log.debug('Command output: %s', output)
388 388 except subprocess.TimeoutExpired as e:
389 389 log.exception('Timeout while executing command.')
390 390 if e.output:
391 391 log.error('Command output: %s', e.output)
392 392 except subprocess.CalledProcessError as e:
393 393 log.exception('Error while executing command.')
394 394 if e.output:
395 395 log.error('Command output: %s', e.output)
396 396 except Exception:
397 397 log.exception(
398 398 'Exception while executing command %s.', cmd)
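
Usage sketch tying the pieces together (the event object is a stand-in; `stop_worker()` raises the stop flag and pushes a sentinel so `do_work()` can exit and the join returns):

    subscriber = AsyncSubprocessSubscriber('echo hello', timeout=10)
    subscriber(object())      # enqueue; the worker thread runs `echo hello`
    subscriber.stop_worker()  # set the stop flag, unblock the queue, join the thread
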
@@ -1,1404 +1,1404 b''
1 1 <%namespace name="base" file="/base/base.mako"/>
2 2 <%namespace name="commentblock" file="/changeset/changeset_file_comment.mako"/>
3 3
4 4 <%def name="diff_line_anchor(commit, filename, line, type)"><%
5 5 return '%s_%s_%i' % (h.md5_safe(commit+filename), type, line)
6 6 %></%def>
7 7
8 8 <%def name="action_class(action)">
9 9 <%
10 10 return {
11 11 '-': 'cb-deletion',
12 12 '+': 'cb-addition',
13 13 ' ': 'cb-context',
14 14 }.get(action, 'cb-empty')
15 15 %>
16 16 </%def>
17 17
18 18 <%def name="op_class(op_id)">
19 19 <%
20 20 return {
21 21 DEL_FILENODE: 'deletion', # file deleted
22 22 BIN_FILENODE: 'warning' # binary diff hidden
23 23 }.get(op_id, 'addition')
24 24 %>
25 25 </%def>
26 26
27 27
28 28
29 29 <%def name="render_diffset(diffset, commit=None,
30 30
31 31 # collapse all file diff entries when there are more than this amount of files in the diff
32 32 collapse_when_files_over=20,
33 33
34 34 # collapse lines in the diff when more than this amount of lines changed in the file diff
35 35 lines_changed_limit=500,
36 36
37 37 # add a ruler at to the output
38 38 ruler_at_chars=0,
39 39
40 40 # show inline comments
41 41 use_comments=False,
42 42
43 43 # disable new comments
44 44 disable_new_comments=False,
45 45
46 46 # special file-comments that were deleted in previous versions
47 47 # it's used for showing outdated comments for deleted files in a PR
48 48 deleted_files_comments=None,
49 49
50 50 # for cache purpose
51 51 inline_comments=None,
52 52
53 53 # additional menu for PRs
54 54 pull_request_menu=None,
55 55
56 56 # show/hide todo next to comments
57 57 show_todos=True,
58 58
59 59 )">
60 60
61 61 <%
62 62 diffset_container_id = h.md5(diffset.target_ref)
63 63 collapse_all = len(diffset.files) > collapse_when_files_over
64 64 active_pattern_entries = h.get_active_pattern_entries(getattr(c, 'repo_name', None))
65 65 from rhodecode.lib.diffs import NEW_FILENODE, DEL_FILENODE, \
66 66 MOD_FILENODE, RENAMED_FILENODE, CHMOD_FILENODE, BIN_FILENODE, COPIED_FILENODE
67 67 %>
68 68
69 69 %if use_comments:
70 70
71 71 ## Template for injecting comments
72 72 <div id="cb-comments-inline-container-template" class="js-template">
73 73 ${inline_comments_container([])}
74 74 </div>
75 75
76 76 <div class="js-template" id="cb-comment-inline-form-template">
77 77 <div class="comment-inline-form ac">
78 78 %if not c.rhodecode_user.is_default:
79 79 ## render template for inline comments
80 80 ${commentblock.comment_form(form_type='inline')}
81 81 %endif
82 82 </div>
83 83 </div>
84 84
85 85 %endif
86 86
87 87 %if c.user_session_attrs["diffmode"] == 'sideside':
88 88 <style>
89 89 .wrapper {
90 90 max-width: 1600px !important;
91 91 }
92 92 </style>
93 93 %endif
94 94
95 95 %if ruler_at_chars:
96 96 <style>
97 97 .diff table.cb .cb-content:after {
98 98 content: "";
99 99 border-left: 1px solid blue;
100 100 position: absolute;
101 101 top: 0;
102 102 height: 18px;
103 103 opacity: .2;
104 104 z-index: 10;
105 105 /* +5 to account for diff action (+/-) */
106 106 left: ${ruler_at_chars + 5}ch; }
107 107 </style>
108 108 %endif
109 109
110 110 <div class="diffset ${disable_new_comments and 'diffset-comments-disabled'}">
111 111
112 112 <div style="height: 20px; line-height: 20px">
113 113 ## expand/collapse action
114 114 <div class="pull-left">
115 115 <a class="${'collapsed' if collapse_all else ''}" href="#expand-files" onclick="toggleExpand(this, '${diffset_container_id}'); return false">
116 116 % if collapse_all:
117 117 <i class="icon-plus-squared-alt icon-no-margin"></i>${_('Expand all files')}
118 118 % else:
119 119 <i class="icon-minus-squared-alt icon-no-margin"></i>${_('Collapse all files')}
120 120 % endif
121 121 </a>
122 122
123 123 </div>
124 124
125 125 ## todos
126 126 % if show_todos and getattr(c, 'at_version', None):
127 127 <div class="pull-right">
128 128 <i class="icon-flag-filled" style="color: #949494">TODOs:</i>
129 129 ${_('not available in this view')}
130 130 </div>
131 131 % elif show_todos:
132 132 <div class="pull-right">
133 133 <div class="comments-number" style="padding-left: 10px">
134 134 % if hasattr(c, 'unresolved_comments') and hasattr(c, 'resolved_comments'):
135 135 <i class="icon-flag-filled" style="color: #949494">TODOs:</i>
136 136 % if c.unresolved_comments:
137 137 <a href="#show-todos" onclick="$('#todo-box').toggle(); return false">
138 138 ${_('{} unresolved').format(len(c.unresolved_comments))}
139 139 </a>
140 140 % else:
141 141 ${_('0 unresolved')}
142 142 % endif
143 143
144 144 ${_('{} Resolved').format(len(c.resolved_comments))}
145 145 % endif
146 146 </div>
147 147 </div>
148 148 % endif
149 149
150 150 ## ## comments
151 151 ## <div class="pull-right">
152 152 ## <div class="comments-number" style="padding-left: 10px">
153 153 ## % if hasattr(c, 'comments') and hasattr(c, 'inline_cnt'):
154 154 ## <i class="icon-comment" style="color: #949494">COMMENTS:</i>
155 155 ## % if c.comments:
156 156 ## <a href="#comments">${_ungettext("{} General", "{} General", len(c.comments)).format(len(c.comments))}</a>,
157 157 ## % else:
158 158 ## ${_('0 General')}
159 159 ## % endif
160 160 ##
161 161 ## % if c.inline_cnt:
162 162 ## <a href="#" onclick="return Rhodecode.comments.nextComment();"
163 163 ## id="inline-comments-counter">${_ungettext("{} Inline", "{} Inline", c.inline_cnt).format(c.inline_cnt)}
164 164 ## </a>
165 165 ## % else:
166 166 ## ${_('0 Inline')}
167 167 ## % endif
168 168 ## % endif
169 169 ##
170 170 ## % if pull_request_menu:
171 171 ## <%
172 172 ## outdated_comm_count_ver = pull_request_menu['outdated_comm_count_ver']
173 173 ## %>
174 174 ##
175 175 ## % if outdated_comm_count_ver:
176 176 ## <a href="#" onclick="showOutdated(); Rhodecode.comments.nextOutdatedComment(); return false;">
177 177 ## (${_("{} Outdated").format(outdated_comm_count_ver)})
178 178 ## </a>
179 179 ## <a href="#" class="showOutdatedComments" onclick="showOutdated(this); return false;"> | ${_('show outdated')}</a>
180 180 ## <a href="#" class="hideOutdatedComments" style="display: none" onclick="hideOutdated(this); return false;"> | ${_('hide outdated')}</a>
181 181 ## % else:
182 182 ## (${_("{} Outdated").format(outdated_comm_count_ver)})
183 183 ## % endif
184 184 ##
185 185 ## % endif
186 186 ##
187 187 ## </div>
188 188 ## </div>
189 189
190 190 </div>
191 191
192 192 % if diffset.limited_diff:
193 193 <div class="diffset-heading ${(diffset.limited_diff and 'diffset-heading-warning' or '')}">
194 194 <h2 class="clearinner">
195 195 ${_('The requested changes are too big and content was truncated.')}
196 196 <a href="${h.current_route_path(request, fulldiff=1)}" onclick="return confirm('${_("Showing a big diff might take some time and resources, continue?")}')">${_('Show full diff')}</a>
197 197 </h2>
198 198 </div>
199 199 % endif
200 200
201 201 <div id="todo-box">
202 202 % if hasattr(c, 'unresolved_comments') and c.unresolved_comments:
203 203 % for co in c.unresolved_comments:
204 204 <a class="permalink" href="#comment-${co.comment_id}"
205 205 onclick="Rhodecode.comments.scrollToComment($('#comment-${co.comment_id}'))">
206 206 <i class="icon-flag-filled-red"></i>
207 207 ${co.comment_id}</a>${('' if loop.last else ',')}
208 208 % endfor
209 209 % endif
210 210 </div>
211 211 %if diffset.has_hidden_changes:
212 212 <p class="empty_data">${_('Some changes may be hidden')}</p>
213 213 %elif not diffset.files:
214 214 <p class="empty_data">${_('No files')}</p>
215 215 %endif
216 216
217 217 <div class="filediffs">
218 218
219 219 ## initial value could be marked as False later on
220 220 <% over_lines_changed_limit = False %>
221 221 %for i, filediff in enumerate(diffset.files):
222 222
223 223 %if filediff.source_file_path and filediff.target_file_path:
224 224 %if filediff.source_file_path != filediff.target_file_path:
225 225 ## file was renamed, or copied
226 226 %if RENAMED_FILENODE in filediff.patch['stats']['ops']:
227 227 <%
228 228 final_file_name = h.literal(u'{} <i class="icon-angle-left"></i> <del>{}</del>'.format(filediff.target_file_path, filediff.source_file_path))
229 229 final_path = filediff.target_file_path
230 230 %>
231 231 %elif COPIED_FILENODE in filediff.patch['stats']['ops']:
232 232 <%
233 233 final_file_name = h.literal(u'{} <i class="icon-angle-left"></i> {}'.format(filediff.target_file_path, filediff.source_file_path))
234 234 final_path = filediff.target_file_path
235 235 %>
236 236 %endif
237 237 %else:
238 238 ## file was modified
239 239 <%
240 240 final_file_name = filediff.source_file_path
241 241 final_path = final_file_name
242 242 %>
243 243 %endif
244 244 %else:
245 245 %if filediff.source_file_path:
246 246 ## file was deleted
247 247 <%
248 248 final_file_name = filediff.source_file_path
249 249 final_path = final_file_name
250 250 %>
251 251 %else:
252 252 ## file was added
253 253 <%
254 254 final_file_name = filediff.target_file_path
255 255 final_path = final_file_name
256 256 %>
257 257 %endif
258 258 %endif
259 259
260 260 <%
261 261 lines_changed = filediff.patch['stats']['added'] + filediff.patch['stats']['deleted']
262 262 over_lines_changed_limit = lines_changed > lines_changed_limit
263 263 %>
264 264 ## anchor with support of sticky header
265 265 <div class="anchor" id="a_${h.FID(filediff.raw_id, filediff.patch['filename'])}"></div>
266 266
267 267 <input ${(collapse_all and 'checked' or '')} class="filediff-collapse-state collapse-${diffset_container_id}" id="filediff-collapse-${id(filediff)}" type="checkbox" onchange="updateSticky();">
268 268 <div
269 269 class="filediff"
270 270 data-f-path="${filediff.patch['filename']}"
271 271 data-anchor-id="${h.FID(filediff.raw_id, filediff.patch['filename'])}"
272 272 >
273 273 <label for="filediff-collapse-${id(filediff)}" class="filediff-heading">
274 274 <%
275 275 file_comments = (get_inline_comments(inline_comments, filediff.patch['filename']) or {}).values()
276 276 total_file_comments = [_c for _c in h.itertools.chain.from_iterable(file_comments) if not (_c.outdated or _c.draft)]
277 277 %>
278 278 <div class="filediff-collapse-indicator icon-"></div>
279 279
280 280 ## Comments/Options PILL
281 281 <span class="pill-group pull-right">
282 282 <span class="pill" op="comments">
283 283 <i class="icon-comment"></i> ${len(total_file_comments)}
284 284 </span>
285 285
286 286 <details class="details-reset details-inline-block">
287 287 <summary class="noselect">
288 288 <i class="pill icon-options cursor-pointer" op="options"></i>
289 289 </summary>
290 290 <details-menu class="details-dropdown">
291 291
292 292 <div class="dropdown-item">
293 293 <span>${final_path}</span>
294 294 <span class="pull-right icon-clipboard clipboard-action" data-clipboard-text="${final_path}" title="Copy file path"></span>
295 295 </div>
296 296
297 297 <div class="dropdown-divider"></div>
298 298
299 299 <div class="dropdown-item">
300 300 <% permalink = request.current_route_url(_anchor='a_{}'.format(h.FID(filediff.raw_id, filediff.patch['filename']))) %>
301 301 <a href="${permalink}">¶ permalink</a>
302 302 <span class="pull-right icon-clipboard clipboard-action" data-clipboard-text="${permalink}" title="Copy permalink"></span>
303 303 </div>
304 304
305 305
306 306 </details-menu>
307 307 </details>
308 308
309 309 </span>
310 310
311 311 ${diff_ops(final_file_name, filediff)}
312 312
313 313 </label>
314 314
315 315 ${diff_menu(filediff, use_comments=use_comments)}
316 316 <table id="file-${h.safeid(h.safe_unicode(filediff.patch['filename']))}" data-f-path="${filediff.patch['filename']}" data-anchor-id="${h.FID(filediff.raw_id, filediff.patch['filename'])}" class="code-visible-block cb cb-diff-${c.user_session_attrs["diffmode"]} code-highlight ${(over_lines_changed_limit and 'cb-collapsed' or '')}">
317 317
318 318 ## new/deleted/empty content case
319 319 % if not filediff.hunks:
320 320 ## Comment container, on "fakes" hunk that contains all data to render comments
321 321 ${render_hunk_lines(filediff, c.user_session_attrs["diffmode"], filediff.hunk_ops, use_comments=use_comments, inline_comments=inline_comments, active_pattern_entries=active_pattern_entries)}
322 322 % endif
323 323
324 324 %if filediff.limited_diff:
325 325 <tr class="cb-warning cb-collapser">
326 326 <td class="cb-text" ${(c.user_session_attrs["diffmode"] == 'unified' and 'colspan=4' or 'colspan=6')}>
327 327 ${_('The requested commit or file is too big and content was truncated.')} <a href="${h.current_route_path(request, fulldiff=1)}" onclick="return confirm('${_("Showing a big diff might take some time and resources, continue?")}')">${_('Show full diff')}</a>
328 328 </td>
329 329 </tr>
330 330 %else:
331 331 %if over_lines_changed_limit:
332 332 <tr class="cb-warning cb-collapser">
333 333 <td class="cb-text" ${(c.user_session_attrs["diffmode"] == 'unified' and 'colspan=4' or 'colspan=6')}>
334 334 ${_('This diff has been collapsed as it changes many lines, (%i lines changed)' % lines_changed)}
335 335 <a href="#" class="cb-expand"
336 336 onclick="$(this).closest('table').removeClass('cb-collapsed'); updateSticky(); return false;">${_('Show them')}
337 337 </a>
338 338 <a href="#" class="cb-collapse"
339 339 onclick="$(this).closest('table').addClass('cb-collapsed'); updateSticky(); return false;">${_('Hide them')}
340 340 </a>
341 341 </td>
342 342 </tr>
343 343 %endif
344 344 %endif
345 345
346 346 % for hunk in filediff.hunks:
347 347 <tr class="cb-hunk">
348 348 <td ${(c.user_session_attrs["diffmode"] == 'unified' and 'colspan=3' or '')}>
349 349 ## TODO: dan: add ajax loading of more context here
350 350 ## <a href="#">
351 351 <i class="icon-more"></i>
352 352 ## </a>
353 353 </td>
354 354 <td ${(c.user_session_attrs["diffmode"] == 'sideside' and 'colspan=5' or '')}>
355 355 @@
356 356 -${hunk.source_start},${hunk.source_length}
357 357 +${hunk.target_start},${hunk.target_length}
358 358 ${hunk.section_header}
359 359 </td>
360 360 </tr>
361 361
362 362 ${render_hunk_lines(filediff, c.user_session_attrs["diffmode"], hunk, use_comments=use_comments, inline_comments=inline_comments, active_pattern_entries=active_pattern_entries)}
363 363 % endfor
364 364
365 365 <% unmatched_comments = (inline_comments or {}).get(filediff.patch['filename'], {}) %>
366 366
367 367 ## outdated comments that do not fit into currently displayed lines
368 368 % for lineno, comments in unmatched_comments.items():
369 369
370 370 %if c.user_session_attrs["diffmode"] == 'unified':
371 371 % if loop.index == 0:
372 372 <tr class="cb-hunk">
373 373 <td colspan="3"></td>
374 374 <td>
375 375 <div>
376 376 ${_('Unmatched/outdated inline comments below')}
377 377 </div>
378 378 </td>
379 379 </tr>
380 380 % endif
381 381 <tr class="cb-line">
382 382 <td class="cb-data cb-context"></td>
383 383 <td class="cb-lineno cb-context"></td>
384 384 <td class="cb-lineno cb-context"></td>
385 385 <td class="cb-content cb-context">
386 386 ${inline_comments_container(comments, active_pattern_entries=active_pattern_entries)}
387 387 </td>
388 388 </tr>
389 389 %elif c.user_session_attrs["diffmode"] == 'sideside':
390 390 % if loop.index == 0:
391 391 <tr class="cb-comment-info">
392 392 <td colspan="2"></td>
393 393 <td class="cb-line">
394 394 <div>
395 395 ${_('Unmatched/outdated inline comments below')}
396 396 </div>
397 397 </td>
398 398 <td colspan="2"></td>
399 399 <td class="cb-line">
400 400 <div>
401 401 ${_('Unmatched/outdated comments below')}
402 402 </div>
403 403 </td>
404 404 </tr>
405 405 % endif
406 406 <tr class="cb-line">
407 407 <td class="cb-data cb-context"></td>
408 408 <td class="cb-lineno cb-context"></td>
409 409 <td class="cb-content cb-context">
410 410 % if lineno.startswith('o'):
411 411 ${inline_comments_container(comments, active_pattern_entries=active_pattern_entries)}
412 412 % endif
413 413 </td>
414 414
415 415 <td class="cb-data cb-context"></td>
416 416 <td class="cb-lineno cb-context"></td>
417 417 <td class="cb-content cb-context">
418 418 % if lineno.startswith('n'):
419 419 ${inline_comments_container(comments, active_pattern_entries=active_pattern_entries)}
420 420 % endif
421 421 </td>
422 422 </tr>
423 423 %endif
424 424
425 425 % endfor
426 426
427 427 </table>
428 428 </div>
429 429 %endfor
430 430
431 431 ## outdated comments that are made for a file that has been deleted
432 432 % for filename, comments_dict in (deleted_files_comments or {}).items():
433 433
434 434 <%
435 435 display_state = 'display: none'
436 436 open_comments_in_file = [x for x in comments_dict['comments'] if x.outdated is False]
437 437 if open_comments_in_file:
438 438 display_state = ''
439 439 fid = str(id(filename))
440 440 %>
441 441 <div class="filediffs filediff-outdated" style="${display_state}">
442 442 <input ${(collapse_all and 'checked' or '')} class="filediff-collapse-state collapse-${diffset_container_id}" id="filediff-collapse-${id(filename)}" type="checkbox" onchange="updateSticky();">
443 443 <div class="filediff" data-f-path="${filename}" id="a_${h.FID(fid, filename)}">
444 444 <label for="filediff-collapse-${id(filename)}" class="filediff-heading">
445 445 <div class="filediff-collapse-indicator icon-"></div>
446 446
447 447 <span class="pill">
448 448 ## file was deleted
449 449 ${filename}
450 450 </span>
451 451 <span class="pill-group pull-left" >
452 452 ## file op, doesn't need translation
453 453 <span class="pill" op="removed">unresolved comments</span>
454 454 </span>
455 455 <a class="pill filediff-anchor" href="#a_${h.FID(fid, filename)}"></a>
456 456 <span class="pill-group pull-right">
457 457 <span class="pill" op="deleted">
458 458 % if comments_dict['stats'] >0:
459 459 -${comments_dict['stats']}
460 460 % else:
461 461 ${comments_dict['stats']}
462 462 % endif
463 463 </span>
464 464 </span>
465 465 </label>
466 466
467 467 <table class="cb cb-diff-${c.user_session_attrs["diffmode"]} code-highlight ${(over_lines_changed_limit and 'cb-collapsed' or '')}">
468 468 <tr>
469 469 % if c.user_session_attrs["diffmode"] == 'unified':
470 470 <td></td>
471 471 %endif
472 472
473 473 <td></td>
474 474 <td class="cb-text cb-${op_class(BIN_FILENODE)}" ${(c.user_session_attrs["diffmode"] == 'unified' and 'colspan=4' or 'colspan=5')}>
475 475 <strong>${_('This file was removed from diff during updates to this pull-request.')}</strong><br/>
476 476 ${_('There are still outdated/unresolved comments attached to it.')}
477 477 </td>
478 478 </tr>
479 479 %if c.user_session_attrs["diffmode"] == 'unified':
480 480 <tr class="cb-line">
481 481 <td class="cb-data cb-context"></td>
482 482 <td class="cb-lineno cb-context"></td>
483 483 <td class="cb-lineno cb-context"></td>
484 484 <td class="cb-content cb-context">
485 485 ${inline_comments_container(comments_dict['comments'], active_pattern_entries=active_pattern_entries)}
486 486 </td>
487 487 </tr>
488 488 %elif c.user_session_attrs["diffmode"] == 'sideside':
489 489 <tr class="cb-line">
490 490 <td class="cb-data cb-context"></td>
491 491 <td class="cb-lineno cb-context"></td>
492 492 <td class="cb-content cb-context"></td>
493 493
494 494 <td class="cb-data cb-context"></td>
495 495 <td class="cb-lineno cb-context"></td>
496 496 <td class="cb-content cb-context">
497 497 ${inline_comments_container(comments_dict['comments'], active_pattern_entries=active_pattern_entries)}
498 498 </td>
499 499 </tr>
500 500 %endif
501 501 </table>
502 502 </div>
503 503 </div>
504 504 % endfor
505 505
506 506 </div>
507 507 </div>
508 508 </%def>
509 509
510 510 <%def name="diff_ops(file_name, filediff)">
511 511 <%
512 512 from rhodecode.lib.diffs import NEW_FILENODE, DEL_FILENODE, \
513 513 MOD_FILENODE, RENAMED_FILENODE, CHMOD_FILENODE, BIN_FILENODE, COPIED_FILENODE
514 514 %>
515 515 <span class="pill">
516 516 <i class="icon-file-text"></i>
517 517 ${file_name}
518 518 </span>
519 519
520 520 <span class="pill-group pull-right">
521 521
522 522 ## ops pills
523 523 %if filediff.limited_diff:
524 524 <span class="pill tooltip" op="limited" title="The stats for this diff are not complete">limited diff</span>
525 525 %endif
526 526
527 527 %if NEW_FILENODE in filediff.patch['stats']['ops']:
528 528 <span class="pill" op="created">created</span>
529 529 %if filediff['target_mode'].startswith('120'):
530 530 <span class="pill" op="symlink">symlink</span>
531 531 %else:
532 532 <span class="pill" op="mode">${nice_mode(filediff['target_mode'])}</span>
533 533 %endif
534 534 %endif
535 535
536 536 %if RENAMED_FILENODE in filediff.patch['stats']['ops']:
537 537 <span class="pill" op="renamed">renamed</span>
538 538 %endif
539 539
540 540 %if COPIED_FILENODE in filediff.patch['stats']['ops']:
541 541 <span class="pill" op="copied">copied</span>
542 542 %endif
543 543
544 544 %if DEL_FILENODE in filediff.patch['stats']['ops']:
545 545 <span class="pill" op="removed">removed</span>
546 546 %endif
547 547
548 548 %if CHMOD_FILENODE in filediff.patch['stats']['ops']:
549 549 <span class="pill" op="mode">
550 550 ${nice_mode(filediff['source_mode'])}${nice_mode(filediff['target_mode'])}
551 551 </span>
552 552 %endif
553 553
554 554 %if BIN_FILENODE in filediff.patch['stats']['ops']:
555 555 <span class="pill" op="binary">binary</span>
556 556 %if MOD_FILENODE in filediff.patch['stats']['ops']:
557 557 <span class="pill" op="modified">modified</span>
558 558 %endif
559 559 %endif
560 560
561 561 <span class="pill" op="added">${('+' if filediff.patch['stats']['added'] else '')}${filediff.patch['stats']['added']}</span>
562 562 <span class="pill" op="deleted">${((h.safe_int(filediff.patch['stats']['deleted']) or 0) * -1)}</span>
563 563
564 564 </span>
565 565
566 566 </%def>
567 567
568 568 <%def name="nice_mode(filemode)">
569 569 ${(filemode.startswith('100') and filemode[3:] or filemode)}
570 570 </%def>
571 571
572 572 <%def name="diff_menu(filediff, use_comments=False)">
573 573 <div class="filediff-menu">
574 574
575 575 %if filediff.diffset.source_ref:
576 576
577 577 ## FILE BEFORE CHANGES
578 578 %if filediff.operation in ['D', 'M']:
579 579 <a
580 580 class="tooltip"
581 581 href="${h.route_path('repo_files',repo_name=filediff.diffset.target_repo_name,commit_id=filediff.diffset.source_ref,f_path=filediff.source_file_path)}"
582 582 title="${h.tooltip(_('Show file at commit: %(commit_id)s') % {'commit_id': filediff.diffset.source_ref[:12]})}"
583 583 >
584 584 ${_('Show file before')}
585 585 </a> |
586 586 %else:
587 587 <span
588 588 class="tooltip"
589 589 title="${h.tooltip(_('File not present at commit: %(commit_id)s') % {'commit_id': filediff.diffset.source_ref[:12]})}"
590 590 >
591 591 ${_('Show file before')}
592 592 </span> |
593 593 %endif
594 594
595 595 ## FILE AFTER CHANGES
596 596 %if filediff.operation in ['A', 'M']:
597 597 <a
598 598 class="tooltip"
599 599 href="${h.route_path('repo_files',repo_name=filediff.diffset.source_repo_name,commit_id=filediff.diffset.target_ref,f_path=filediff.target_file_path)}"
600 600 title="${h.tooltip(_('Show file at commit: %(commit_id)s') % {'commit_id': filediff.diffset.target_ref[:12]})}"
601 601 >
602 602 ${_('Show file after')}
603 603 </a>
604 604 %else:
605 605 <span
606 606 class="tooltip"
607 607 title="${h.tooltip(_('File not present at commit: %(commit_id)s') % {'commit_id': filediff.diffset.target_ref[:12]})}"
608 608 >
609 609 ${_('Show file after')}
610 610 </span>
611 611 %endif
612 612
613 613 % if use_comments:
614 614 |
615 615 <a href="#" onclick="Rhodecode.comments.toggleDiffComments(this);return toggleElement(this)"
616 616 data-toggle-on="${_('Hide comments')}"
617 617 data-toggle-off="${_('Show comments')}">
618 618 <span class="hide-comment-button">${_('Hide comments')}</span>
619 619 </a>
620 620 % endif
621 621
622 622 %endif
623 623
624 624 </div>
625 625 </%def>
626 626
627 627
628 628 <%def name="inline_comments_container(comments, active_pattern_entries=None, line_no='', f_path='')">
629 629
630 630 <div class="inline-comments">
631 631 %for comment in comments:
632 632 ${commentblock.comment_block(comment, inline=True, active_pattern_entries=active_pattern_entries)}
633 633 %endfor
634 634
635 635 <%
636 636 extra_class = ''
637 637 extra_style = ''
638 638
639 639 if comments and comments[-1].outdated_at_version(c.at_version_num):
640 640 extra_class = ' comment-outdated'
641 641 extra_style = 'display: none;'
642 642
643 643 %>
644 644
645 645 <div class="reply-thread-container-wrapper${extra_class}" style="${extra_style}">
646 646 <div class="reply-thread-container${extra_class}">
647 647 <div class="reply-thread-gravatar">
648 648 % if c.rhodecode_user.username != h.DEFAULT_USER:
649 649 ${base.gravatar(c.rhodecode_user.email, 20, tooltip=True, user=c.rhodecode_user)}
650 650 % endif
651 651 </div>
652 652
653 653 <div class="reply-thread-reply-button">
654 654 % if c.rhodecode_user.username != h.DEFAULT_USER:
655 655 ## initial reply button, some JS logic can append here a FORM to leave a first comment.
656 656 <button class="cb-comment-add-button" onclick="return Rhodecode.comments.createComment(this, '${f_path}', '${line_no}', null)">Reply...</button>
657 657 % endif
658 658 </div>
659 659 ##% endif
660 660 <div class="reply-thread-last"></div>
661 661 </div>
662 662 </div>
663 663 </div>
664 664
665 665 </%def>
666 666
667 667 <%!
668 668
669 669 def get_inline_comments(comments, filename):
670 670 if hasattr(filename, 'unicode_path'):
671 671 filename = filename.unicode_path
672 672
673 if not isinstance(filename, (unicode, str)):
673 if not isinstance(filename, str):
674 674 return None
675 675
676 676 if comments and filename in comments:
677 677 return comments[filename]
678 678
679 679 return None
680 680
681 681 def get_comments_for(diff_type, comments, filename, line_version, line_number):
682 682 if hasattr(filename, 'unicode_path'):
683 683 filename = filename.unicode_path
684 684
685 if not isinstance(filename, (unicode, str)):
685 if not isinstance(filename, str):
686 686 return None
687 687
688 688 file_comments = get_inline_comments(comments, filename)
689 689 if file_comments is None:
690 690 return None
691 691
692 692 line_key = '{}{}'.format(line_version, line_number) ## e.g o37, n12
693 693 if line_key in file_comments:
694 694 data = file_comments.pop(line_key)
695 695 return data
696 696 %>
697 697
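
Review note: the `o`/`n` prefixes above mark old-side and new-side line numbers; a quick illustration of the lookup (the data shapes are simplified stand-ins):

    inline_comments = {'setup.py': {'o37': ['left-side comment'],
                                    'n12': ['right-side comment']}}

    assert get_comments_for('unified', inline_comments, 'setup.py', 'n', 12) == \
        ['right-side comment']
    # the entry is popped, so each comment thread renders only once
    assert get_comments_for('unified', inline_comments, 'setup.py', 'n', 12) is None
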
698 698 <%def name="render_hunk_lines_sideside(filediff, hunk, use_comments=False, inline_comments=None, active_pattern_entries=None)">
699 699
700 700 <% chunk_count = 1 %>
701 701 %for loop_obj, item in h.looper(hunk.sideside):
702 702 <%
703 703 line = item
704 704 i = loop_obj.index
705 705 prev_line = loop_obj.previous
706 706 old_line_anchor, new_line_anchor = None, None
707 707
708 708 if line.original.lineno:
709 709 old_line_anchor = diff_line_anchor(filediff.raw_id, hunk.source_file_path, line.original.lineno, 'o')
710 710 if line.modified.lineno:
711 711 new_line_anchor = diff_line_anchor(filediff.raw_id, hunk.target_file_path, line.modified.lineno, 'n')
712 712
713 713 line_action = line.modified.action or line.original.action
714 714 prev_line_action = prev_line and (prev_line.modified.action or prev_line.original.action)
715 715 %>
716 716
717 717 <tr class="cb-line">
718 718 <td class="cb-data ${action_class(line.original.action)}"
719 719 data-line-no="${line.original.lineno}"
720 720 >
721 721
722 722 <% line_old_comments, line_old_comments_no_drafts = None, None %>
723 723 %if line.original.get_comment_args:
724 724 <%
725 725 line_old_comments = get_comments_for('side-by-side', inline_comments, *line.original.get_comment_args)
726 726 line_old_comments_no_drafts = [c for c in line_old_comments if not c.draft] if line_old_comments else []
727 727 has_outdated = any([x.outdated for x in line_old_comments_no_drafts])
728 728 %>
729 729 %endif
730 730 %if line_old_comments_no_drafts:
731 731 % if has_outdated:
732 732 <i class="tooltip toggle-comment-action icon-comment-toggle" title="${_('Comments including outdated: {}. Click here to toggle them.').format(len(line_old_comments_no_drafts))}" onclick="return Rhodecode.comments.toggleLineComments(this)"></i>
733 733 % else:
734 734 <i class="tooltip toggle-comment-action icon-comment" title="${_('Comments: {}. Click to toggle them.').format(len(line_old_comments_no_drafts))}" onclick="return Rhodecode.comments.toggleLineComments(this)"></i>
735 735 % endif
736 736 %endif
737 737 </td>
738 738 <td class="cb-lineno ${action_class(line.original.action)}"
739 739 data-line-no="${line.original.lineno}"
740 740 %if old_line_anchor:
741 741 id="${old_line_anchor}"
742 742 %endif
743 743 >
744 744 %if line.original.lineno:
745 745 <a name="${old_line_anchor}" href="#${old_line_anchor}">${line.original.lineno}</a>
746 746 %endif
747 747 </td>
748 748
749 749 <% line_no = 'o{}'.format(line.original.lineno) %>
750 750 <td class="cb-content ${action_class(line.original.action)}"
751 751 data-line-no="${line_no}"
752 752 >
753 753 %if use_comments and line.original.lineno:
754 754 ${render_add_comment_button(line_no=line_no, f_path=filediff.patch['filename'])}
755 755 %endif
756 756 <span class="cb-code"><span class="cb-action ${action_class(line.original.action)}"></span>${line.original.content or '' | n}</span>
757 757
758 758 %if use_comments and line.original.lineno and line_old_comments:
759 759 ${inline_comments_container(line_old_comments, active_pattern_entries=active_pattern_entries, line_no=line_no, f_path=filediff.patch['filename'])}
760 760 %endif
761 761
762 762 </td>
763 763 <td class="cb-data ${action_class(line.modified.action)}"
764 764 data-line-no="${line.modified.lineno}"
765 765 >
766 766 <div>
767 767
768 768 <% line_new_comments, line_new_comments_no_drafts = None, None %>
769 769 %if line.modified.get_comment_args:
770 770 <%
771 771 line_new_comments = get_comments_for('side-by-side', inline_comments, *line.modified.get_comment_args)
772 772 line_new_comments_no_drafts = [c for c in line_new_comments if not c.draft] if line_new_comments else []
773 773 has_outdated = any([x.outdated for x in line_new_comments_no_drafts])
774 774 %>
775 775 %endif
776 776
777 777 %if line_new_comments_no_drafts:
778 778 % if has_outdated:
779 779 <i class="tooltip toggle-comment-action icon-comment-toggle" title="${_('Comments including outdated: {}. Click here to toggle them.').format(len(line_new_comments_no_drafts))}" onclick="return Rhodecode.comments.toggleLineComments(this)"></i>
780 780 % else:
781 781 <i class="tooltip toggle-comment-action icon-comment" title="${_('Comments: {}. Click to toggle them.').format(len(line_new_comments_no_drafts))}" onclick="return Rhodecode.comments.toggleLineComments(this)"></i>
782 782 % endif
783 783 %endif
784 784 </div>
785 785 </td>
786 786 <td class="cb-lineno ${action_class(line.modified.action)}"
787 787 data-line-no="${line.modified.lineno}"
788 788 %if new_line_anchor:
789 789 id="${new_line_anchor}"
790 790 %endif
791 791 >
792 792 %if line.modified.lineno:
793 793 <a name="${new_line_anchor}" href="#${new_line_anchor}">${line.modified.lineno}</a>
794 794 %endif
795 795 </td>
796 796
797 797 <% line_no = 'n{}'.format(line.modified.lineno) %>
798 798 <td class="cb-content ${action_class(line.modified.action)}"
799 799 data-line-no="${line_no}"
800 800 >
801 801 %if use_comments and line.modified.lineno:
802 802 ${render_add_comment_button(line_no=line_no, f_path=filediff.patch['filename'])}
803 803 %endif
804 804 <span class="cb-code"><span class="cb-action ${action_class(line.modified.action)}"></span>${line.modified.content or '' | n}</span>
805 805 % if line_action in ['+', '-'] and prev_line_action not in ['+', '-']:
806 806 <div class="nav-chunk" style="visibility: hidden">
807 807 <i class="icon-eye" title="viewing diff hunk-${hunk.index}-${chunk_count}"></i>
808 808 </div>
809 809 <% chunk_count +=1 %>
810 810 % endif
811 811 %if use_comments and line.modified.lineno and line_new_comments:
812 812 ${inline_comments_container(line_new_comments, active_pattern_entries=active_pattern_entries, line_no=line_no, f_path=filediff.patch['filename'])}
813 813 %endif
814 814
815 815 </td>
816 816 </tr>
817 817 %endfor
818 818 </%def>
819 819
820 820
821 821 <%def name="render_hunk_lines_unified(filediff, hunk, use_comments=False, inline_comments=None, active_pattern_entries=None)">
822 822 %for old_line_no, new_line_no, action, content, comments_args in hunk.unified:
823 823
824 824 <%
825 825 old_line_anchor, new_line_anchor = None, None
826 826 if old_line_no:
827 827 old_line_anchor = diff_line_anchor(filediff.raw_id, hunk.source_file_path, old_line_no, 'o')
828 828 if new_line_no:
829 829 new_line_anchor = diff_line_anchor(filediff.raw_id, hunk.target_file_path, new_line_no, 'n')
830 830 %>
831 831 <tr class="cb-line">
832 832 <td class="cb-data ${action_class(action)}">
833 833 <div>
834 834
835 835 <% comments, comments_no_drafts = None, None %>
836 836 %if comments_args:
837 837 <%
838 838 comments = get_comments_for('unified', inline_comments, *comments_args)
839 839 comments_no_drafts = [c for c in comments if not c.draft] if comments else []
840 840 has_outdated = any([x.outdated for x in comments_no_drafts])
841 841 %>
842 842 %endif
843 843
844 844 % if comments_no_drafts:
845 845 % if has_outdated:
846 846 <i class="tooltip toggle-comment-action icon-comment-toggle" title="${_('Comments including outdated: {}. Click here to toggle them.').format(len(comments_no_drafts))}" onclick="return Rhodecode.comments.toggleLineComments(this)"></i>
847 847 % else:
848 848 <i class="tooltip toggle-comment-action icon-comment" title="${_('Comments: {}. Click to toggle them.').format(len(comments_no_drafts))}" onclick="return Rhodecode.comments.toggleLineComments(this)"></i>
849 849 % endif
850 850 % endif
851 851 </div>
852 852 </td>
853 853 <td class="cb-lineno ${action_class(action)}"
854 854 data-line-no="${old_line_no}"
855 855 %if old_line_anchor:
856 856 id="${old_line_anchor}"
857 857 %endif
858 858 >
859 859 %if old_line_anchor:
860 860 <a name="${old_line_anchor}" href="#${old_line_anchor}">${old_line_no}</a>
861 861 %endif
862 862 </td>
863 863 <td class="cb-lineno ${action_class(action)}"
864 864 data-line-no="${new_line_no}"
865 865 %if new_line_anchor:
866 866 id="${new_line_anchor}"
867 867 %endif
868 868 >
869 869 %if new_line_anchor:
870 870 <a name="${new_line_anchor}" href="#${new_line_anchor}">${new_line_no}</a>
871 871 %endif
872 872 </td>
873 873 <% line_no = '{}{}'.format(new_line_no and 'n' or 'o', new_line_no or old_line_no) %>
874 874 <td class="cb-content ${action_class(action)}"
875 875 data-line-no="${line_no}"
876 876 >
877 877 %if use_comments:
878 878 ${render_add_comment_button(line_no=line_no, f_path=filediff.patch['filename'])}
879 879 %endif
880 880 <span class="cb-code"><span class="cb-action ${action_class(action)}"></span> ${content or '' | n}</span>
881 881 %if use_comments and comments:
882 882 ${inline_comments_container(comments, active_pattern_entries=active_pattern_entries, line_no=line_no, f_path=filediff.patch['filename'])}
883 883 %endif
884 884 </td>
885 885 </tr>
886 886 %endfor
887 887 </%def>
888 888
889 889
890 890 <%def name="render_hunk_lines(filediff, diff_mode, hunk, use_comments, inline_comments, active_pattern_entries)">
891 891 % if diff_mode == 'unified':
892 892 ${render_hunk_lines_unified(filediff, hunk, use_comments=use_comments, inline_comments=inline_comments, active_pattern_entries=active_pattern_entries)}
893 893 % elif diff_mode == 'sideside':
894 894 ${render_hunk_lines_sideside(filediff, hunk, use_comments=use_comments, inline_comments=inline_comments, active_pattern_entries=active_pattern_entries)}
895 895 % else:
896 896 <tr class="cb-line">
897 897 <td>unknown diff mode</td>
898 898 </tr>
899 899 % endif
900 900 </%def>
901 901
902 902
903 903 <%def name="render_add_comment_button(line_no='', f_path='')">
904 904 % if not c.rhodecode_user.is_default:
905 905 <button class="btn btn-small btn-primary cb-comment-box-opener" onclick="return Rhodecode.comments.createComment(this, '${f_path}', '${line_no}', null)">
906 906 <span><i class="icon-comment"></i></span>
907 907 </button>
908 908 % endif
909 909 </%def>
910 910
911 911 <%def name="render_diffset_menu(diffset, range_diff_on=None, commit=None, pull_request_menu=None)">
912 912 <% diffset_container_id = h.md5(diffset.target_ref) %>
913 913
914 914 <div id="diff-file-sticky" class="diffset-menu clearinner">
915 915 ## auto adjustable
916 916 <div class="sidebar__inner">
917 917 <div class="sidebar__bar">
918 918 <div class="pull-right">
919 919
920 920 <div class="btn-group" style="margin-right: 5px;">
921 921 <a class="tooltip btn" onclick="scrollDown();return false" title="${_('Scroll to page bottom')}">
922 922 <i class="icon-arrow_down"></i>
923 923 </a>
924 924 <a class="tooltip btn" onclick="scrollUp();return false" title="${_('Scroll to page top')}">
925 925 <i class="icon-arrow_up"></i>
926 926 </a>
927 927 </div>
928 928
929 929 <div class="btn-group">
930 930 <a class="btn tooltip toggle-wide-diff" href="#toggle-wide-diff" onclick="toggleWideDiff(this); return false" title="${h.tooltip(_('Toggle wide diff'))}">
931 931 <i class="icon-wide-mode"></i>
932 932 </a>
933 933 </div>
934 934 <div class="btn-group">
935 935
936 936 <a
937 937 class="btn ${(c.user_session_attrs["diffmode"] == 'sideside' and 'btn-active')} tooltip"
938 938 title="${h.tooltip(_('View diff as side by side'))}"
939 939 href="${h.current_route_path(request, diffmode='sideside')}">
940 940 <span>${_('Side by Side')}</span>
941 941 </a>
942 942
943 943 <a
944 944 class="btn ${(c.user_session_attrs["diffmode"] == 'unified' and 'btn-active')} tooltip"
945 945 title="${h.tooltip(_('View diff as unified'))}" href="${h.current_route_path(request, diffmode='unified')}">
946 946 <span>${_('Unified')}</span>
947 947 </a>
948 948
949 949 % if range_diff_on is True:
950 950 <a
951 951 title="${_('Turn off: Show the diff as commit range')}"
952 952 class="btn btn-primary"
953 953 href="${h.current_route_path(request, **{"range-diff":"0"})}">
954 954 <span>${_('Range Diff')}</span>
955 955 </a>
956 956 % elif range_diff_on is False:
957 957 <a
958 958 title="${_('Show the diff as commit range')}"
959 959 class="btn"
960 960 href="${h.current_route_path(request, **{"range-diff":"1"})}">
961 961 <span>${_('Range Diff')}</span>
962 962 </a>
963 963 % endif
964 964 </div>
965 965 <div class="btn-group">
966 966
967 967 <details class="details-reset details-inline-block">
968 968 <summary class="noselect btn">
969 969 <i class="icon-options cursor-pointer" op="options"></i>
970 970 </summary>
971 971
972 972 <div>
973 973 <details-menu class="details-dropdown" style="top: 35px;">
974 974
975 975 <div class="dropdown-item">
976 976 <div style="padding: 2px 0px">
977 977 % if request.GET.get('ignorews', '') == '1':
978 978 <a href="${h.current_route_path(request, ignorews=0)}">${_('Show whitespace changes')}</a>
979 979 % else:
980 980 <a href="${h.current_route_path(request, ignorews=1)}">${_('Hide whitespace changes')}</a>
981 981 % endif
982 982 </div>
983 983 </div>
984 984
985 985 <div class="dropdown-item">
986 986 <div style="padding: 2px 0px">
987 987 % if request.GET.get('fullcontext', '') == '1':
988 988 <a href="${h.current_route_path(request, fullcontext=0)}">${_('Hide full context diff')}</a>
989 989 % else:
990 990 <a href="${h.current_route_path(request, fullcontext=1)}">${_('Show full context diff')}</a>
991 991 % endif
992 992 </div>
993 993 </div>
994 994
995 995 </details-menu>
996 996 </div>
997 997 </details>
998 998
999 999 </div>
1000 1000 </div>
1001 1001 <div class="pull-left">
1002 1002 <div class="btn-group">
1003 1003 <div class="pull-left">
1004 1004 ${h.hidden('file_filter_{}'.format(diffset_container_id))}
1005 1005 </div>
1006 1006
1007 1007 </div>
1008 1008 </div>
1009 1009 </div>
1010 1010 <div class="fpath-placeholder pull-left">
1011 1011 <i class="icon-file-text"></i>
1012 1012 <strong class="fpath-placeholder-text">
1013 1013 Context file:
1014 1014 </strong>
1015 1015 </div>
1016 1016 <div class="pull-right noselect">
1017 1017 %if commit:
1018 1018 <span>
1019 1019 <code>${h.show_id(commit)}</code>
1020 1020 </span>
1021 1021 %elif pull_request_menu and pull_request_menu.get('pull_request'):
1022 1022 <span>
1023 1023 <code>!${pull_request_menu['pull_request'].pull_request_id}</code>
1024 1024 </span>
1025 1025 %endif
1026 1026 % if commit or pull_request_menu:
1027 1027 <span class="tooltip" title="Navigate to previous or next change inside files." id="diff_nav">Loading diff...:</span>
1028 1028 <span class="cursor-pointer" onclick="scrollToPrevChunk(); return false">
1029 1029 <i class="icon-angle-up"></i>
1030 1030 </span>
1031 1031 <span class="cursor-pointer" onclick="scrollToNextChunk(); return false">
1032 1032 <i class="icon-angle-down"></i>
1033 1033 </span>
1034 1034 % endif
1035 1035 </div>
1036 1036 <div class="sidebar_inner_shadow"></div>
1037 1037 </div>
1038 1038 </div>
1039 1039
1040 1040 % if diffset:
1041 1041 %if diffset.limited_diff:
1042 1042 <% file_placeholder = _ungettext('%(num)s file changed', '%(num)s files changed', diffset.changed_files) % {'num': diffset.changed_files} %>
1043 1043 %else:
1044 1044 <% file_placeholder = h.literal(_ungettext('%(num)s file changed: <span class="op-added">%(linesadd)s inserted</span>, <span class="op-deleted">%(linesdel)s deleted</span>', '%(num)s files changed: <span class="op-added">%(linesadd)s inserted</span>, <span class="op-deleted">%(linesdel)s deleted</span>',
1045 1045 diffset.changed_files) % {'num': diffset.changed_files, 'linesadd': diffset.lines_added, 'linesdel': diffset.lines_deleted}) %>
1046 1046
1047 1047 %endif
1048 1048 ## case on range-diff placeholder needs to be updated
1049 1049 % if range_diff_on is True:
1050 1050 <% file_placeholder = _('Disabled on range diff') %>
1051 1051 % endif
1052 1052
1053 1053 <script type="text/javascript">
1054 1054 var feedFilesOptions = function (query, initialData) {
1055 1055 var data = {results: []};
1056 1056 var isQuery = typeof query.term !== 'undefined';
1057 1057
1058 1058 var section = _gettext('Changed files');
1059 1059 var filteredData = [];
1060 1060
1061 1061 //filter results
1062 1062 $.each(initialData.results, function (idx, value) {
1063 1063
1064 1064 if (!isQuery || query.term.length === 0 || value.text.toUpperCase().indexOf(query.term.toUpperCase()) >= 0) {
1065 1065 filteredData.push({
1066 1066 'id': this.id,
1067 1067 'text': this.text,
1068 1068 "ops": this.ops,
1069 1069 })
1070 1070 }
1071 1071
1072 1072 });
1073 1073
1074 1074 data.results = filteredData;
1075 1075
1076 1076 query.callback(data);
1077 1077 };
1078 1078
1079 1079 var selectionFormatter = function(data, escapeMarkup) {
1080 1080 var container = '<div class="filelist" style="padding-right:100px">{0}</div>';
1081 1081 var tmpl = '<div><strong>{0}</strong></div>'.format(escapeMarkup(data['text']));
1082 1082 var pill = '<div class="pill-group" style="position: absolute; top:7px; right: 0">' +
1083 1083 '<span class="pill" op="added">{0}</span>' +
1084 1084 '<span class="pill" op="deleted">{1}</span>' +
1085 1085 '</div>'
1086 1086 ;
1087 1087 var added = data['ops']['added'];
1088 1088 if (added === 0) {
1089 1089 // don't show +0
1090 1090 added = 0;
1091 1091 } else {
1092 1092 added = '+' + added;
1093 1093 }
1094 1094
1095 1095 var deleted = -1*data['ops']['deleted'];
1096 1096
1097 1097 tmpl += pill.format(added, deleted);
1098 1098 return container.format(tmpl);
1099 1099 };
1100 1100 var formatFileResult = function(result, container, query, escapeMarkup) {
1101 1101 return selectionFormatter(result, escapeMarkup);
1102 1102 };
1103 1103
1104 1104 var formatSelection = function (data, container) {
1105 1105 return '${file_placeholder}'
1106 1106 };
1107 1107
1108 1108 if (window.preloadFileFilterData === undefined) {
1109 1109 window.preloadFileFilterData = {}
1110 1110 }
1111 1111
1112 1112 preloadFileFilterData["${diffset_container_id}"] = {
1113 1113 results: [
1114 1114 % for filediff in diffset.files:
1115 1115 {id:"a_${h.FID(filediff.raw_id, filediff.patch['filename'])}",
1116 1116 text:"${filediff.patch['filename']}",
1117 1117 ops:${h.json.dumps(filediff.patch['stats'])|n}}${('' if loop.last else ',')}
1118 1118 % endfor
1119 1119 ]
1120 1120 };
1121 1121
1122 1122 var diffFileFilterId = "#file_filter_" + "${diffset_container_id}";
1123 1123 var diffFileFilter = $(diffFileFilterId).select2({
1124 1124 'dropdownAutoWidth': true,
1125 1125 'width': 'auto',
1126 1126
1127 1127 containerCssClass: "drop-menu",
1128 1128 dropdownCssClass: "drop-menu-dropdown",
1129 1129 data: preloadFileFilterData["${diffset_container_id}"],
1130 1130 query: function(query) {
1131 1131 feedFilesOptions(query, preloadFileFilterData["${diffset_container_id}"]);
1132 1132 },
1133 1133 initSelection: function(element, callback) {
1134 1134 callback({'init': true});
1135 1135 },
1136 1136 formatResult: formatFileResult,
1137 1137 formatSelection: formatSelection
1138 1138 });
1139 1139
1140 1140 % if range_diff_on is True:
1141 1141 diffFileFilter.select2("enable", false);
1142 1142 % endif
1143 1143
1144 1144 $(diffFileFilterId).on('select2-selecting', function (e) {
1145 1145 var idSelector = e.choice.id;
1146 1146
1147 1147 // expand the container if we quick-select the field
1148 1148 $('#'+idSelector).next().prop('checked', false);
1149 1149 // hide the mast as we later do preventDefault()
1150 1150 $("#select2-drop-mask").click();
1151 1151
1152 1152 window.location.hash = '#'+idSelector;
1153 1153 updateSticky();
1154 1154
1155 1155 e.preventDefault();
1156 1156 });
1157 1157
1158 1158 diffNavText = 'diff navigation:'
1159 1159
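            // getCurrentChunk gathers all .nav-chunk elements belonging to files that
            // are not collapsed (the hidden checkbox preceding each .filediff tracks
            // collapse state) plus the currently selected chunk; 'initial' is true when
            // nothing was selected yet and the first chunk was picked as a start point.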
1160 1160 getCurrentChunk = function () {
1161 1161
1162 1162 var chunksAll = $('.nav-chunk').filter(function () {
1163 1163 return $(this).parents('.filediff').prev().get(0).checked !== true
1164 1164 })
1165 1165 var chunkSelected = $('.nav-chunk.selected');
1166 1166 var initial = false;
1167 1167
1168 1168 if (chunkSelected.length === 0) {
1169 1169 // no initial chunk selected, we pick first
1170 1170 chunkSelected = $(chunksAll.get(0));
1171 1171                     initial = true;
1172 1172 }
1173 1173
1174 1174 return {
1175 1175 'all': chunksAll,
1176 1176 'selected': chunkSelected,
1177 1177 'initial': initial,
1178 1178 }
1179 1179 }
1180 1180
1181 1181 animateDiffNavText = function () {
1182 1182 var $diffNav = $('#diff_nav')
1183 1183
1184 1184 var callback = function () {
1185 1185 $diffNav.animate({'opacity': 1.00}, 200)
1186 1186 };
1187 1187 $diffNav.animate({'opacity': 0.15}, 200, callback);
1188 1188 }
1189 1189
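            // scrollToChunk moves the selection by `moveBy` entries (-1 = previous,
            // +1 = next), scrolls the window so the chunk sits 100px below the viewport
            // top, and flashes the #diff_nav label when no chunk exists in that direction.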
1190 1190 scrollToChunk = function (moveBy) {
1191 1191 var chunk = getCurrentChunk();
1192 1192 var all = chunk.all
1193 1193 var selected = chunk.selected
1194 1194
1195 1195 var curPos = all.index(selected);
1196 1196 var newPos = curPos;
1197 1197 if (!chunk.initial) {
1198 1198                 newPos = curPos + moveBy;
1199 1199 }
1200 1200
1201 1201 var curElem = all.get(newPos);
1202 1202
1203 1203 if (curElem === undefined) {
1204 1204 // end or back
1205 1205 $('#diff_nav').html('no next diff element:')
1206 1206 animateDiffNavText()
1207 1207 return
1208 1208 } else if (newPos < 0) {
1209 1209 $('#diff_nav').html('no previous diff element:')
1210 1210 animateDiffNavText()
1211 1211 return
1212 1212 } else {
1213 1213 $('#diff_nav').html(diffNavText)
1214 1214 }
1215 1215
1216 1216 curElem = $(curElem)
1217 1217 var offset = 100;
1218 1218 $(window).scrollTop(curElem.position().top - offset);
1219 1219
1220 1220 //clear selection
1221 1221 all.removeClass('selected')
1222 1222 curElem.addClass('selected')
1223 1223 }
1224 1224
1225 1225 scrollToPrevChunk = function () {
1226 1226 scrollToChunk(-1)
1227 1227 }
1228 1228 scrollToNextChunk = function () {
1229 1229 scrollToChunk(1)
1230 1230 }
1231 1231
1232 1232 </script>
1233 1233 % endif
1234 1234
1235 1235 <script type="text/javascript">
1236 1236     $('#diff_nav').html('loading diff...'); // placeholder until the full page has loaded
1237 1237
1238 1238 $(document).ready(function () {
1239 1239
1240 1240 var contextPrefix = _gettext('Context file: ');
1241 1241 ## sticky sidebar
1242 1242 var sidebarElement = document.getElementById('diff-file-sticky');
1243 1243 sidebar = new StickySidebar(sidebarElement, {
1244 1244 topSpacing: 0,
1245 1245 bottomSpacing: 0,
1246 1246 innerWrapperSelector: '.sidebar__inner'
1247 1247 });
1248 1248 sidebarElement.addEventListener('affixed.static.stickySidebar', function () {
1249 1249             // reset the file path label so it does not keep a stale value
1250 1250 $('.fpath-placeholder-text').html(contextPrefix + ' - ')
1251 1251 });
1252 1252
1253 1253 updateSticky = function () {
1254 1254 sidebar.updateSticky();
1255 1255 Waypoint.refreshAll();
1256 1256 };
1257 1257
1258 1258 var animateText = function (fPath, anchorId) {
1259 1259 fPath = Select2.util.escapeMarkup(fPath);
1260 1260 $('.fpath-placeholder-text').html(contextPrefix + '<a href="#a_' + anchorId + '">' + fPath + '</a>')
1261 1261 };
1262 1262
1263 1263 ## dynamic file waypoints
1264 1264 var setFPathInfo = function(fPath, anchorId){
1265 1265 animateText(fPath, anchorId)
1266 1266 };
1267 1267
1268 1268 var codeBlock = $('.filediff');
1269 1269
1270 1270 // forward waypoint
1271 1271 codeBlock.waypoint(
1272 1272 function(direction) {
1273 1273 if (direction === "down"){
1274 1274 setFPathInfo($(this.element).data('fPath'), $(this.element).data('anchorId'))
1275 1275 }
1276 1276 }, {
1277 1277 offset: function () {
1278 1278 return 70;
1279 1279 },
1280 1280 context: '.fpath-placeholder'
1281 1281 }
1282 1282 );
1283 1283
1284 1284 // backward waypoint
1285 1285 codeBlock.waypoint(
1286 1286 function(direction) {
1287 1287 if (direction === "up"){
1288 1288 setFPathInfo($(this.element).data('fPath'), $(this.element).data('anchorId'))
1289 1289 }
1290 1290 }, {
1291 1291 offset: function () {
1292 1292 return -this.element.clientHeight + 90;
1293 1293 },
1294 1294 context: '.fpath-placeholder'
1295 1295 }
1296 1296 );
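        // Taken together, the two waypoints keep the sticky "Context file:" label in
        // sync while scrolling: the forward one fires as a file's top passes 70px from
        // the top of the placeholder, the backward one as its bottom scrolls back up.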
1297 1297
1298 1298 toggleWideDiff = function (el) {
1299 1299 updateSticky();
1300 1300 var wide = Rhodecode.comments.toggleWideMode(this);
1301 1301 storeUserSessionAttr('rc_user_session_attr.wide_diff_mode', wide);
1302 1302 if (wide === true) {
1303 1303 $(el).addClass('btn-active');
1304 1304 } else {
1305 1305 $(el).removeClass('btn-active');
1306 1306 }
1307 1307 return null;
1308 1308 };
1309 1309
1310 1310 toggleExpand = function (el, diffsetEl) {
1311 1311             el = $(el);
1312 1312 if (el.hasClass('collapsed')) {
1313 1313 $('.filediff-collapse-state.collapse-{0}'.format(diffsetEl)).prop('checked', false);
1314 1314 el.removeClass('collapsed');
1315 1315 el.html(
1316 1316 '<i class="icon-minus-squared-alt icon-no-margin"></i>' +
1317 1317 _gettext('Collapse all files'));
1318 1318 }
1319 1319 else {
1320 1320 $('.filediff-collapse-state.collapse-{0}'.format(diffsetEl)).prop('checked', true);
1321 1321 el.addClass('collapsed');
1322 1322 el.html(
1323 1323 '<i class="icon-plus-squared-alt icon-no-margin"></i>' +
1324 1324 _gettext('Expand all files'));
1325 1325 }
1326 1326 updateSticky()
1327 1327 };
1328 1328
1329 1329 toggleCommitExpand = function (el) {
1330 1330 var $el = $(el);
1331 1331 var commits = $el.data('toggleCommitsCnt');
1332 1332 var collapseMsg = _ngettext('Collapse {0} commit', 'Collapse {0} commits', commits).format(commits);
1333 1333 var expandMsg = _ngettext('Expand {0} commit', 'Expand {0} commits', commits).format(commits);
1334 1334
1335 1335 if ($el.hasClass('collapsed')) {
1336 1336 $('.compare_select').show();
1337 1337 $('.compare_select_hidden').hide();
1338 1338
1339 1339 $el.removeClass('collapsed');
1340 1340 $el.html(
1341 1341 '<i class="icon-minus-squared-alt icon-no-margin"></i>' +
1342 1342 collapseMsg);
1343 1343 }
1344 1344 else {
1345 1345 $('.compare_select').hide();
1346 1346 $('.compare_select_hidden').show();
1347 1347 $el.addClass('collapsed');
1348 1348 $el.html(
1349 1349 '<i class="icon-plus-squared-alt icon-no-margin"></i>' +
1350 1350 expandMsg);
1351 1351 }
1352 1352 updateSticky();
1353 1353 };
1354 1354
1355 1355 // get stored diff mode and pre-enable it
1356 1356 if (templateContext.session_attrs.wide_diff_mode === "true") {
1357 1357 Rhodecode.comments.toggleWideMode(null);
1358 1358 $('.toggle-wide-diff').addClass('btn-active');
1359 1359 updateSticky();
1360 1360 }
1361 1361
1362 1362 // DIFF NAV //
1363 1363
1364 1364 // element to detect scroll direction of
1365 1365 var $window = $(window);
1366 1366
1367 1367 // initialize last scroll position
1368 1368 var lastScrollY = $window.scrollTop();
1369 1369
1370 1370 $window.on('resize scrollstop', {latency: 350}, function () {
1371 1371 var visibleChunks = $('.nav-chunk').withinviewport({top: 75});
1372 1372
1373 1373 // get current scroll position
1374 1374 var currentScrollY = $window.scrollTop();
1375 1375
1376 1376 // determine current scroll direction
1377 1377 if (currentScrollY > lastScrollY) {
1378 1378                 var y = 'down';
1379 1379 } else if (currentScrollY !== lastScrollY) {
1380 1380                 y = 'up';
1381 1381 }
1382 1382
1383 1383 var pos = -1; // by default we use last element in viewport
1384 1384 if (y === 'down') {
1385 1385 pos = -1;
1386 1386 } else if (y === 'up') {
1387 1387 pos = 0;
1388 1388 }
1389 1389
1390 1390 if (visibleChunks.length > 0) {
1391 1391 $('.nav-chunk').removeClass('selected');
1392 1392 $(visibleChunks.get(pos)).addClass('selected');
1393 1393 }
1394 1394
1395 1395 // update last scroll position to current position
1396 1396 lastScrollY = currentScrollY;
1397 1397
1398 1398 });
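        // The handler above highlights the nav chunk closest to the scroll direction:
        // the last chunk still inside the viewport when scrolling down, the first one
        // when scrolling up.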
1399 1399 $('#diff_nav').html(diffNavText);
1400 1400
1401 1401 });
1402 1402 </script>
1403 1403
1404 1404 </%def>
@@ -1,209 +1,208 b''
1 1 # -*- coding: utf-8 -*-
2 2
3 3 # Copyright (C) 2010-2020 RhodeCode GmbH
4 4 #
5 5 # This program is free software: you can redistribute it and/or modify
6 6 # it under the terms of the GNU Affero General Public License, version 3
7 7 # (only), as published by the Free Software Foundation.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU Affero General Public License
15 15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 16 #
17 17 # This program is dual-licensed. If you wish to learn more about the
18 18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20
21 from io import StringIO
20 import io
22 21
23 22 import pytest
24 23 from mock import patch, Mock
25 24
26 25 from rhodecode.lib.middleware.simplesvn import SimpleSvn, SimpleSvnApp
27 26 from rhodecode.lib.utils import get_rhodecode_base_path
28 27
29 28
30 29 class TestSimpleSvn(object):
31 30 @pytest.fixture(autouse=True)
32 31 def simple_svn(self, baseapp, request_stub):
33 32 base_path = get_rhodecode_base_path()
34 33 self.app = SimpleSvn(
35 34 config={'auth_ret_code': '', 'base_path': base_path},
36 35 registry=request_stub.registry)
37 36
38 37 def test_get_config(self):
39 38 extras = {'foo': 'FOO', 'bar': 'BAR'}
40 39 config = self.app._create_config(extras, repo_name='test-repo')
41 40 assert config == extras
42 41
43 42 @pytest.mark.parametrize(
44 43 'method', ['OPTIONS', 'PROPFIND', 'GET', 'REPORT'])
45 44 def test_get_action_returns_pull(self, method):
46 45 environment = {'REQUEST_METHOD': method}
47 46 action = self.app._get_action(environment)
48 47 assert action == 'pull'
49 48
50 49 @pytest.mark.parametrize(
51 50 'method', [
52 51 'MKACTIVITY', 'PROPPATCH', 'PUT', 'CHECKOUT', 'MKCOL', 'MOVE',
53 52 'COPY', 'DELETE', 'LOCK', 'UNLOCK', 'MERGE'
54 53 ])
55 54 def test_get_action_returns_push(self, method):
56 55 environment = {'REQUEST_METHOD': method}
57 56 action = self.app._get_action(environment)
58 57 assert action == 'push'
59 58
60 59 @pytest.mark.parametrize(
61 60 'path, expected_name', [
62 61 ('/hello-svn', 'hello-svn'),
63 62 ('/hello-svn/', 'hello-svn'),
64 63 ('/group/hello-svn/', 'group/hello-svn'),
65 64 ('/group/hello-svn/!svn/vcc/default', 'group/hello-svn'),
66 65 ])
67 66 def test_get_repository_name(self, path, expected_name):
68 67 environment = {'PATH_INFO': path}
69 68 name = self.app._get_repository_name(environment)
70 69 assert name == expected_name
71 70
72 71 def test_get_repository_name_subfolder(self, backend_svn):
73 72 repo = backend_svn.repo
74 73 environment = {
75 74 'PATH_INFO': '/{}/path/with/subfolders'.format(repo.repo_name)}
76 75 name = self.app._get_repository_name(environment)
77 76 assert name == repo.repo_name
78 77
79 78 def test_create_wsgi_app(self):
80 79 with patch.object(SimpleSvn, '_is_svn_enabled') as mock_method:
81 80 mock_method.return_value = False
82 81 with patch('rhodecode.lib.middleware.simplesvn.DisabledSimpleSvnApp') as (
83 82 wsgi_app_mock):
84 83 config = Mock()
85 84 wsgi_app = self.app._create_wsgi_app(
86 85 repo_path='', repo_name='', config=config)
87 86
88 87 wsgi_app_mock.assert_called_once_with(config)
89 88 assert wsgi_app == wsgi_app_mock()
90 89
91 90 def test_create_wsgi_app_when_enabled(self):
92 91 with patch.object(SimpleSvn, '_is_svn_enabled') as mock_method:
93 92 mock_method.return_value = True
94 93 with patch('rhodecode.lib.middleware.simplesvn.SimpleSvnApp') as (
95 94 wsgi_app_mock):
96 95 config = Mock()
97 96 wsgi_app = self.app._create_wsgi_app(
98 97 repo_path='', repo_name='', config=config)
99 98
100 99 wsgi_app_mock.assert_called_once_with(config)
101 100 assert wsgi_app == wsgi_app_mock()
102 101
103 102
104 103 class TestSimpleSvnApp(object):
105 104 data = '<xml></xml>'
106 105 path = '/group/my-repo'
107 wsgi_input = StringIO(data)
106 wsgi_input = io.StringIO(data)
108 107 environment = {
109 108 'HTTP_DAV': (
110 109 'http://subversion.tigris.org/xmlns/dav/svn/depth,'
111 110 ' http://subversion.tigris.org/xmlns/dav/svn/mergeinfo'),
112 111 'HTTP_USER_AGENT': 'SVN/1.8.11 (x86_64-linux) serf/1.3.8',
113 112 'REQUEST_METHOD': 'OPTIONS',
114 113 'PATH_INFO': path,
115 114 'wsgi.input': wsgi_input,
116 115 'CONTENT_TYPE': 'text/xml',
117 116 'CONTENT_LENGTH': '130'
118 117 }
119 118
120 119 def setup_method(self, method):
121 120 self.host = 'http://localhost/'
122 121 base_path = get_rhodecode_base_path()
123 122 self.app = SimpleSvnApp(
124 123 config={'subversion_http_server_url': self.host,
125 124 'base_path': base_path})
126 125
127 126 def test_get_request_headers_with_content_type(self):
128 127 expected_headers = {
129 128 'Dav': self.environment['HTTP_DAV'],
130 129 'User-Agent': self.environment['HTTP_USER_AGENT'],
131 130 'Content-Type': self.environment['CONTENT_TYPE'],
132 131 'Content-Length': self.environment['CONTENT_LENGTH']
133 132 }
134 133 headers = self.app._get_request_headers(self.environment)
135 134 assert headers == expected_headers
136 135
137 136 def test_get_request_headers_without_content_type(self):
138 137 environment = self.environment.copy()
139 138 environment.pop('CONTENT_TYPE')
140 139 expected_headers = {
141 140 'Dav': environment['HTTP_DAV'],
142 141 'Content-Length': self.environment['CONTENT_LENGTH'],
143 142 'User-Agent': environment['HTTP_USER_AGENT'],
144 143 }
145 144 request_headers = self.app._get_request_headers(environment)
146 145 assert request_headers == expected_headers
147 146
148 147 def test_get_response_headers(self):
149 148 headers = {
150 149 'Connection': 'keep-alive',
151 150 'Keep-Alive': 'timeout=5, max=100',
152 151 'Transfer-Encoding': 'chunked',
153 152 'Content-Encoding': 'gzip',
154 153 'MS-Author-Via': 'DAV',
155 154 'SVN-Supported-Posts': 'create-txn-with-props'
156 155 }
157 156 expected_headers = [
158 157 ('MS-Author-Via', 'DAV'),
159 158 ('SVN-Supported-Posts', 'create-txn-with-props'),
160 159 ]
161 160 response_headers = self.app._get_response_headers(headers)
162 161 assert sorted(response_headers) == sorted(expected_headers)
163 162
164 163 @pytest.mark.parametrize('svn_http_url, path_info, expected_url', [
165 164 ('http://localhost:8200', '/repo_name', 'http://localhost:8200/repo_name'),
166 165 ('http://localhost:8200///', '/repo_name', 'http://localhost:8200/repo_name'),
167 166 ('http://localhost:8200', '/group/repo_name', 'http://localhost:8200/group/repo_name'),
168 167 ('http://localhost:8200/', '/group/repo_name', 'http://localhost:8200/group/repo_name'),
169 168 ('http://localhost:8200/prefix', '/repo_name', 'http://localhost:8200/prefix/repo_name'),
170 169 ('http://localhost:8200/prefix', 'repo_name', 'http://localhost:8200/prefix/repo_name'),
171 170 ('http://localhost:8200/prefix', '/group/repo_name', 'http://localhost:8200/prefix/group/repo_name')
172 171 ])
173 172 def test_get_url(self, svn_http_url, path_info, expected_url):
174 173 url = self.app._get_url(svn_http_url, path_info)
175 174 assert url == expected_url
176 175
177 176 def test_call(self):
178 177 start_response = Mock()
179 178 response_mock = Mock()
180 179 response_mock.headers = {
181 180 'Content-Encoding': 'gzip',
182 181 'MS-Author-Via': 'DAV',
183 182 'SVN-Supported-Posts': 'create-txn-with-props'
184 183 }
185 184 response_mock.status_code = 200
186 185 response_mock.reason = 'OK'
187 186 with patch('rhodecode.lib.middleware.simplesvn.requests.request') as (
188 187 request_mock):
189 188 request_mock.return_value = response_mock
190 189 self.app(self.environment, start_response)
191 190
192 191 expected_url = '{}{}'.format(self.host.strip('/'), self.path)
193 192 expected_request_headers = {
194 193 'Dav': self.environment['HTTP_DAV'],
195 194 'User-Agent': self.environment['HTTP_USER_AGENT'],
196 195 'Content-Type': self.environment['CONTENT_TYPE'],
197 196 'Content-Length': self.environment['CONTENT_LENGTH']
198 197 }
199 198 expected_response_headers = [
200 199 ('SVN-Supported-Posts', 'create-txn-with-props'),
201 200 ('MS-Author-Via', 'DAV'),
202 201 ]
203 202 request_mock.assert_called_once_with(
204 203 self.environment['REQUEST_METHOD'], expected_url,
205 204 data=self.data, headers=expected_request_headers, stream=False)
206 205 response_mock.iter_content.assert_called_once_with(chunk_size=1024)
207 206 args, _ = start_response.call_args
208 207 assert args[0] == '200 OK'
209 208 assert sorted(args[1]) == sorted(expected_response_headers)
@@ -1,342 +1,342 b''
1 1 # -*- coding: utf-8 -*-
2 2
3 3 # Copyright (C) 2010-2020 RhodeCode GmbH
4 4 #
5 5 # This program is free software: you can redistribute it and/or modify
6 6 # it under the terms of the GNU Affero General Public License, version 3
7 7 # (only), as published by the Free Software Foundation.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU Affero General Public License
15 15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 16 #
17 17 # This program is dual-licensed. If you wish to learn more about the
18 18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20 20
21 21 import json
22 22 import logging
23 from io import StringIO
23 import io
24 24
25 25 import mock
26 26 import pytest
27 27
28 28 from rhodecode.lib import hooks_daemon
29 29 from rhodecode.tests.utils import assert_message_in_log
30 30
31 31
32 32 class TestDummyHooksCallbackDaemon(object):
33 33 def test_hooks_module_path_set_properly(self):
34 34 daemon = hooks_daemon.DummyHooksCallbackDaemon()
35 35 assert daemon.hooks_module == 'rhodecode.lib.hooks_daemon'
36 36
37 37 def test_logs_entering_the_hook(self):
38 38 daemon = hooks_daemon.DummyHooksCallbackDaemon()
39 39 with mock.patch.object(hooks_daemon.log, 'debug') as log_mock:
40 40 with daemon as return_value:
41 41 log_mock.assert_called_once_with(
42 42 'Running `%s` callback daemon', 'DummyHooksCallbackDaemon')
43 43 assert return_value == daemon
44 44
45 45 def test_logs_exiting_the_hook(self):
46 46 daemon = hooks_daemon.DummyHooksCallbackDaemon()
47 47 with mock.patch.object(hooks_daemon.log, 'debug') as log_mock:
48 48 with daemon:
49 49 pass
50 50 log_mock.assert_called_with(
51 51 'Exiting `%s` callback daemon', 'DummyHooksCallbackDaemon')
52 52
53 53
54 54 class TestHooks(object):
55 55 def test_hooks_can_be_used_as_a_context_processor(self):
56 56 hooks = hooks_daemon.Hooks()
57 57 with hooks as return_value:
58 58 pass
59 59 assert hooks == return_value
60 60
61 61
62 62 class TestHooksHttpHandler(object):
63 63 def test_read_request_parses_method_name_and_arguments(self):
64 64 data = {
65 65 'method': 'test',
66 66 'extras': {
67 67 'param1': 1,
68 68 'param2': 'a'
69 69 }
70 70 }
71 71 request = self._generate_post_request(data)
72 72 hooks_patcher = mock.patch.object(
73 73 hooks_daemon.Hooks, data['method'], create=True, return_value=1)
74 74
75 75 with hooks_patcher as hooks_mock:
76 76 MockServer(hooks_daemon.HooksHttpHandler, request)
77 77
78 78 hooks_mock.assert_called_once_with(data['extras'])
79 79
80 80 def test_hooks_serialized_result_is_returned(self):
81 81 request = self._generate_post_request({})
82 82 rpc_method = 'test'
83 83 hook_result = {
84 84 'first': 'one',
85 85 'second': 2
86 86 }
87 87 read_patcher = mock.patch.object(
88 88 hooks_daemon.HooksHttpHandler, '_read_request',
89 89 return_value=(rpc_method, {}))
90 90 hooks_patcher = mock.patch.object(
91 91 hooks_daemon.Hooks, rpc_method, create=True,
92 92 return_value=hook_result)
93 93
94 94 with read_patcher, hooks_patcher:
95 95 server = MockServer(hooks_daemon.HooksHttpHandler, request)
96 96
97 97 expected_result = json.dumps(hook_result)
98 98 assert server.request.output_stream.buflist[-1] == expected_result
99 99
100 100 def test_exception_is_returned_in_response(self):
101 101 request = self._generate_post_request({})
102 102 rpc_method = 'test'
103 103 read_patcher = mock.patch.object(
104 104 hooks_daemon.HooksHttpHandler, '_read_request',
105 105 return_value=(rpc_method, {}))
106 106 hooks_patcher = mock.patch.object(
107 107 hooks_daemon.Hooks, rpc_method, create=True,
108 108 side_effect=Exception('Test exception'))
109 109
110 110 with read_patcher, hooks_patcher:
111 111 server = MockServer(hooks_daemon.HooksHttpHandler, request)
112 112
113 113 org_exc = json.loads(server.request.output_stream.buflist[-1])
114 114 expected_result = {
115 115 'exception': 'Exception',
116 116 'exception_traceback': org_exc['exception_traceback'],
117 117 'exception_args': ['Test exception']
118 118 }
119 119 assert org_exc == expected_result
120 120
121 121 def test_log_message_writes_to_debug_log(self, caplog):
122 122 ip_port = ('0.0.0.0', 8888)
123 123 handler = hooks_daemon.HooksHttpHandler(
124 124 MockRequest('POST /'), ip_port, mock.Mock())
125 125 fake_date = '1/Nov/2015 00:00:00'
126 126 date_patcher = mock.patch.object(
127 127 handler, 'log_date_time_string', return_value=fake_date)
128 128 with date_patcher, caplog.at_level(logging.DEBUG):
129 129 handler.log_message('Some message %d, %s', 123, 'string')
130 130
131 131 expected_message = "HOOKS: {} - - [{}] Some message 123, string".format(ip_port, fake_date)
132 132 assert_message_in_log(
133 133 caplog.records, expected_message,
134 134 levelno=logging.DEBUG, module='hooks_daemon')
135 135
136 136 def _generate_post_request(self, data):
137 137 payload = json.dumps(data)
138 138 return 'POST / HTTP/1.0\nContent-Length: {}\n\n{}'.format(
139 139 len(payload), payload)
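    # For example, _generate_post_request({'a': 1}) produces the raw request
    # 'POST / HTTP/1.0\nContent-Length: 8\n\n{"a": 1}'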
140 140
141 141
142 142 class TestThreadedHookCallbackDaemon(object):
143 143 def test_constructor_calls_prepare(self):
144 144 prepare_daemon_patcher = mock.patch.object(
145 145 hooks_daemon.ThreadedHookCallbackDaemon, '_prepare')
146 146 with prepare_daemon_patcher as prepare_daemon_mock:
147 147 hooks_daemon.ThreadedHookCallbackDaemon()
148 148 prepare_daemon_mock.assert_called_once_with()
149 149
150 150 def test_run_is_called_on_context_start(self):
151 151 patchers = mock.patch.multiple(
152 152 hooks_daemon.ThreadedHookCallbackDaemon,
153 153 _run=mock.DEFAULT, _prepare=mock.DEFAULT, __exit__=mock.DEFAULT)
154 154
155 155 with patchers as mocks:
156 156 daemon = hooks_daemon.ThreadedHookCallbackDaemon()
157 157 with daemon as daemon_context:
158 158 pass
159 159 mocks['_run'].assert_called_once_with()
160 160 assert daemon_context == daemon
161 161
162 162 def test_stop_is_called_on_context_exit(self):
163 163 patchers = mock.patch.multiple(
164 164 hooks_daemon.ThreadedHookCallbackDaemon,
165 165 _run=mock.DEFAULT, _prepare=mock.DEFAULT, _stop=mock.DEFAULT)
166 166
167 167 with patchers as mocks:
168 168 daemon = hooks_daemon.ThreadedHookCallbackDaemon()
169 169 with daemon as daemon_context:
170 170 assert mocks['_stop'].call_count == 0
171 171
172 172 mocks['_stop'].assert_called_once_with()
173 173 assert daemon_context == daemon
174 174
175 175
176 176 class TestHttpHooksCallbackDaemon(object):
177 177 def test_hooks_callback_generates_new_port(self, caplog):
178 178 with caplog.at_level(logging.DEBUG):
179 179 daemon = hooks_daemon.HttpHooksCallbackDaemon(host='127.0.0.1', port=8881)
180 180 assert daemon._daemon.server_address == ('127.0.0.1', 8881)
181 181
182 182 with caplog.at_level(logging.DEBUG):
183 183 daemon = hooks_daemon.HttpHooksCallbackDaemon(host=None, port=None)
184 184 assert daemon._daemon.server_address[1] in range(0, 66000)
185 185 assert daemon._daemon.server_address[0] != '127.0.0.1'
186 186
187 187 def test_prepare_inits_daemon_variable(self, tcp_server, caplog):
188 188 with self._tcp_patcher(tcp_server), caplog.at_level(logging.DEBUG):
189 189 daemon = hooks_daemon.HttpHooksCallbackDaemon(host='127.0.0.1', port=8881)
190 190 assert daemon._daemon == tcp_server
191 191
192 192 _, port = tcp_server.server_address
193 193 expected_uri = '{}:{}'.format('127.0.0.1', port)
194 194 msg = 'HOOKS: {} Preparing HTTP callback daemon registering ' \
195 195 'hook object: rhodecode.lib.hooks_daemon.HooksHttpHandler'.format(expected_uri)
196 196 assert_message_in_log(
197 197 caplog.records, msg, levelno=logging.DEBUG, module='hooks_daemon')
198 198
199 199 def test_prepare_inits_hooks_uri_and_logs_it(
200 200 self, tcp_server, caplog):
201 201 with self._tcp_patcher(tcp_server), caplog.at_level(logging.DEBUG):
202 202 daemon = hooks_daemon.HttpHooksCallbackDaemon(host='127.0.0.1', port=8881)
203 203
204 204 _, port = tcp_server.server_address
205 205 expected_uri = '{}:{}'.format('127.0.0.1', port)
206 206 assert daemon.hooks_uri == expected_uri
207 207
208 208 msg = 'HOOKS: {} Preparing HTTP callback daemon registering ' \
209 209 'hook object: rhodecode.lib.hooks_daemon.HooksHttpHandler'.format(expected_uri)
210 210 assert_message_in_log(
211 211 caplog.records, msg,
212 212 levelno=logging.DEBUG, module='hooks_daemon')
213 213
214 214 def test_run_creates_a_thread(self, tcp_server):
215 215 thread = mock.Mock()
216 216
217 217 with self._tcp_patcher(tcp_server):
218 218 daemon = hooks_daemon.HttpHooksCallbackDaemon()
219 219
220 220 with self._thread_patcher(thread) as thread_mock:
221 221 daemon._run()
222 222
223 223 thread_mock.assert_called_once_with(
224 224 target=tcp_server.serve_forever,
225 225 kwargs={'poll_interval': daemon.POLL_INTERVAL})
226 226 assert thread.daemon is True
227 227 thread.start.assert_called_once_with()
228 228
229 229 def test_run_logs(self, tcp_server, caplog):
230 230
231 231 with self._tcp_patcher(tcp_server):
232 232 daemon = hooks_daemon.HttpHooksCallbackDaemon()
233 233
234 234 with self._thread_patcher(mock.Mock()), caplog.at_level(logging.DEBUG):
235 235 daemon._run()
236 236
237 237 assert_message_in_log(
238 238 caplog.records,
239 239 'Running event loop of callback daemon in background thread',
240 240 levelno=logging.DEBUG, module='hooks_daemon')
241 241
242 242 def test_stop_cleans_up_the_connection(self, tcp_server, caplog):
243 243 thread = mock.Mock()
244 244
245 245 with self._tcp_patcher(tcp_server):
246 246 daemon = hooks_daemon.HttpHooksCallbackDaemon()
247 247
248 248 with self._thread_patcher(thread), caplog.at_level(logging.DEBUG):
249 249 with daemon:
250 250 assert daemon._daemon == tcp_server
251 251 assert daemon._callback_thread == thread
252 252
253 253 assert daemon._daemon is None
254 254 assert daemon._callback_thread is None
255 255 tcp_server.shutdown.assert_called_with()
256 256 thread.join.assert_called_once_with()
257 257
258 258 assert_message_in_log(
259 259 caplog.records, 'Waiting for background thread to finish.',
260 260 levelno=logging.DEBUG, module='hooks_daemon')
261 261
262 262 def _tcp_patcher(self, tcp_server):
263 263 return mock.patch.object(
264 264 hooks_daemon, 'TCPServer', return_value=tcp_server)
265 265
266 266 def _thread_patcher(self, thread):
267 267 return mock.patch.object(
268 268 hooks_daemon.threading, 'Thread', return_value=thread)
269 269
270 270
271 271 class TestPrepareHooksDaemon(object):
272 272 @pytest.mark.parametrize('protocol', ('http',))
273 273 def test_returns_dummy_hooks_callback_daemon_when_using_direct_calls(
274 274 self, protocol):
275 275 expected_extras = {'extra1': 'value1'}
276 276 callback, extras = hooks_daemon.prepare_callback_daemon(
277 277 expected_extras.copy(), protocol=protocol,
278 278 host='127.0.0.1', use_direct_calls=True)
279 279 assert isinstance(callback, hooks_daemon.DummyHooksCallbackDaemon)
280 280 expected_extras['hooks_module'] = 'rhodecode.lib.hooks_daemon'
281 281 expected_extras['time'] = extras['time']
282 282 assert 'extra1' in extras
283 283
284 284 @pytest.mark.parametrize('protocol, expected_class', (
285 285 ('http', hooks_daemon.HttpHooksCallbackDaemon),
286 286 ))
287 287 def test_returns_real_hooks_callback_daemon_when_protocol_is_specified(
288 288 self, protocol, expected_class):
289 289 expected_extras = {
290 290 'extra1': 'value1',
291 291 'txn_id': 'txnid2',
292 292 'hooks_protocol': protocol.lower()
293 293 }
294 294 callback, extras = hooks_daemon.prepare_callback_daemon(
295 295 expected_extras.copy(), protocol=protocol, host='127.0.0.1',
296 296 use_direct_calls=False,
297 297 txn_id='txnid2')
298 298 assert isinstance(callback, expected_class)
299 299 extras.pop('hooks_uri')
300 300 expected_extras['time'] = extras['time']
301 301 assert extras == expected_extras
302 302
303 303 @pytest.mark.parametrize('protocol', (
304 304 'invalid',
305 305 'Http',
306 306 'HTTP',
307 307 ))
308 308 def test_raises_on_invalid_protocol(self, protocol):
309 309 expected_extras = {
310 310 'extra1': 'value1',
311 311 'hooks_protocol': protocol.lower()
312 312 }
313 313 with pytest.raises(Exception):
314 314 callback, extras = hooks_daemon.prepare_callback_daemon(
315 315 expected_extras.copy(),
316 316 protocol=protocol, host='127.0.0.1',
317 317 use_direct_calls=False)
318 318
319 319
320 320 class MockRequest(object):
321 321 def __init__(self, request):
322 322 self.request = request
323 self.input_stream = StringIO(b'{}'.format(self.request))
324 self.output_stream = StringIO()
323         self.input_stream = io.StringIO('{}'.format(self.request))
324 self.output_stream = io.StringIO()
325 325
326 326 def makefile(self, mode, *args, **kwargs):
327 327 return self.output_stream if mode == 'wb' else self.input_stream
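    # BaseHTTPRequestHandler calls makefile('rb') to build its rfile and
    # makefile('wb') for its wfile, so the canned request is read back from
    # input_stream and the handler's response lands in output_stream.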
328 328
329 329
330 330 class MockServer(object):
331 331 def __init__(self, handler_cls, request):
332 332 ip_port = ('0.0.0.0', 8888)
333 333 self.request = MockRequest(request)
334 334 self.server_address = ip_port
335 335 self.handler = handler_cls(self.request, ip_port, self)
336 336
337 337
338 338 @pytest.fixture()
339 339 def tcp_server():
340 340 server = mock.Mock()
341 341 server.server_address = ('127.0.0.1', 8881)
342 342 return server
@@ -1,176 +1,176 b''
1 1 # -*- coding: utf-8 -*-
2 2
3 3 # Copyright (C) 2010-2020 RhodeCode GmbH
4 4 #
5 5 # This program is free software: you can redistribute it and/or modify
6 6 # it under the terms of the GNU Affero General Public License, version 3
7 7 # (only), as published by the Free Software Foundation.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU Affero General Public License
15 15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 16 #
17 17 # This program is dual-licensed. If you wish to learn more about the
18 18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20 20
21 21 import datetime
22 22 import os
23 23 import shutil
24 24 import tarfile
25 25 import tempfile
26 26 import zipfile
27 from io import StringIO
27 import io
28 28
29 29 import mock
30 30 import pytest
31 31
32 32 from rhodecode.lib.vcs.backends import base
33 33 from rhodecode.lib.vcs.exceptions import ImproperArchiveTypeError, VCSError
34 34 from rhodecode.lib.vcs.nodes import FileNode
35 35 from rhodecode.tests.vcs.conftest import BackendTestMixin
36 36
37 37
38 38 @pytest.mark.usefixtures("vcs_repository_support")
39 39 class TestArchives(BackendTestMixin):
40 40
41 41 @pytest.fixture(autouse=True)
42 42 def tempfile(self, request):
43 43 self.temp_file = tempfile.mkstemp()[1]
44 44
45 45 @request.addfinalizer
46 46 def cleanup():
47 47 os.remove(self.temp_file)
48 48
49 49 @classmethod
50 50 def _get_commits(cls):
51 51 start_date = datetime.datetime(2010, 1, 1, 20)
52 52 yield {
53 53 'message': 'Initial Commit',
54 54 'author': 'Joe Doe <joe.doe@example.com>',
55 55 'date': start_date + datetime.timedelta(hours=12),
56 56 'added': [
57 57 FileNode('executable_0o100755', '...', mode=0o100755),
58 58 FileNode('executable_0o100500', '...', mode=0o100500),
59 59 FileNode('not_executable', '...', mode=0o100644),
60 60 ],
61 61 }
62 62 for x in range(5):
63 63 yield {
64 64 'message': 'Commit %d' % x,
65 65 'author': 'Joe Doe <joe.doe@example.com>',
66 66 'date': start_date + datetime.timedelta(hours=12 * x),
67 67 'added': [
68 68 FileNode('%d/file_%d.txt' % (x, x), content='Foobar %d' % x),
69 69 ],
70 70 }
71 71
72 72 @pytest.mark.parametrize('compressor', ['gz', 'bz2'])
73 73 def test_archive_tar(self, compressor):
74 74 self.tip.archive_repo(
75 75 self.temp_file, kind='t{}'.format(compressor), archive_dir_name='repo')
76 76 out_dir = tempfile.mkdtemp()
77 77 out_file = tarfile.open(self.temp_file, 'r|{}'.format(compressor))
78 78 out_file.extractall(out_dir)
79 79 out_file.close()
80 80
81 81 for x in range(5):
82 82 node_path = '%d/file_%d.txt' % (x, x)
83 83 with open(os.path.join(out_dir, 'repo/' + node_path)) as f:
84 84 file_content = f.read()
85 85 assert file_content == self.tip.get_node(node_path).content
86 86
87 87 shutil.rmtree(out_dir)
88 88
89 89 @pytest.mark.parametrize('compressor', ['gz', 'bz2'])
90 90 def test_archive_tar_symlink(self, compressor):
91 91 return False
92 92
93 93 @pytest.mark.parametrize('compressor', ['gz', 'bz2'])
94 94 def test_archive_tar_file_modes(self, compressor):
95 95 self.tip.archive_repo(
96 96 self.temp_file, kind='t{}'.format(compressor), archive_dir_name='repo')
97 97 out_dir = tempfile.mkdtemp()
98 98 out_file = tarfile.open(self.temp_file, 'r|{}'.format(compressor))
99 99 out_file.extractall(out_dir)
100 100 out_file.close()
101 101 dest = lambda inp: os.path.join(out_dir, 'repo/' + inp)
102 102
103 103         assert oct(os.stat(dest('not_executable')).st_mode) == '0o100644'
104 104
105 105 def test_archive_zip(self):
106 106 self.tip.archive_repo(self.temp_file, kind='zip', archive_dir_name='repo')
107 107 out = zipfile.ZipFile(self.temp_file)
108 108
109 109 for x in range(5):
110 110 node_path = '%d/file_%d.txt' % (x, x)
111 decompressed = StringIO.StringIO()
111             decompressed = io.BytesIO()
112 112 decompressed.write(out.read('repo/' + node_path))
113 113 assert decompressed.getvalue() == \
114 114 self.tip.get_node(node_path).content
115 115 decompressed.close()
116 116
117 117 def test_archive_zip_with_metadata(self):
118 118 self.tip.archive_repo(self.temp_file, kind='zip',
119 119 archive_dir_name='repo', write_metadata=True)
120 120
121 121 out = zipfile.ZipFile(self.temp_file)
122 122 metafile = out.read('repo/.archival.txt')
123 123
124 124 raw_id = self.tip.raw_id
125 125 assert 'commit_id:%s' % raw_id in metafile
126 126
127 127 for x in range(5):
128 128 node_path = '%d/file_%d.txt' % (x, x)
129 decompressed = StringIO.StringIO()
129             decompressed = io.BytesIO()
130 130 decompressed.write(out.read('repo/' + node_path))
131 131 assert decompressed.getvalue() == \
132 132 self.tip.get_node(node_path).content
133 133 decompressed.close()
134 134
135 135 def test_archive_wrong_kind(self):
136 136 with pytest.raises(ImproperArchiveTypeError):
137 137 self.tip.archive_repo(self.temp_file, kind='wrong kind')
138 138
139 139
140 140 @pytest.fixture()
141 141 def base_commit():
142 142 """
143 143 Prepare a `base.BaseCommit` just enough for `_validate_archive_prefix`.
144 144 """
145 145 commit = base.BaseCommit()
146 146 commit.repository = mock.Mock()
147 147 commit.repository.name = u'fake_repo'
148 148 commit.short_id = 'fake_id'
149 149 return commit
150 150
151 151
152 152 @pytest.mark.parametrize("prefix", [u"unicode-prefix", u"Ünïcödë"])
153 153 def test_validate_archive_prefix_enforces_bytes_as_prefix(prefix, base_commit):
154 154 with pytest.raises(ValueError):
155 155 base_commit._validate_archive_prefix(prefix)
156 156
157 157
158 158 def test_validate_archive_prefix_empty_prefix(base_commit):
159 159 # TODO: johbo: Should raise a ValueError here.
160 160 with pytest.raises(VCSError):
161 161 base_commit._validate_archive_prefix('')
162 162
163 163
164 164 def test_validate_archive_prefix_with_leading_slash(base_commit):
165 165 # TODO: johbo: Should raise a ValueError here.
166 166 with pytest.raises(VCSError):
167 167 base_commit._validate_archive_prefix('/any')
168 168
169 169
170 170 def test_validate_archive_prefix_falls_back_to_repository_name(base_commit):
171 171 prefix = base_commit._validate_archive_prefix(None)
172 172 expected_prefix = base_commit._ARCHIVE_PREFIX_TEMPLATE.format(
173 173 repo_name='fake_repo',
174 174 short_id='fake_id')
175 175 assert isinstance(prefix, str)
176 176 assert prefix == expected_prefix