libs: major refactor for python3
super-admin
r5085:4eab4aa8 default
@@ -1,610 +1,606 b''
1 1
2 2 # Copyright (C) 2010-2020 RhodeCode GmbH
3 3 #
4 4 # This program is free software: you can redistribute it and/or modify
5 5 # it under the terms of the GNU Affero General Public License, version 3
6 6 # (only), as published by the Free Software Foundation.
7 7 #
8 8 # This program is distributed in the hope that it will be useful,
9 9 # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 10 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 11 # GNU General Public License for more details.
12 12 #
13 13 # You should have received a copy of the GNU Affero General Public License
14 14 # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 15 #
16 16 # This program is dual-licensed. If you wish to learn more about the
17 17 # RhodeCode Enterprise Edition, including its added features, Support services,
18 18 # and proprietary license terms, please see https://rhodecode.com/licenses/
19 19
20 20 """
21 21 The base Controller API
22 22 Provides the BaseController class, for subclassing and use in different
23 23 controllers
24 24 """
25 25
26 26 import logging
27 27 import socket
28 import base64
28 29
29 30 import markupsafe
30 31 import ipaddress
31 32
33 import paste.httpheaders
32 34 from paste.auth.basic import AuthBasicAuthenticator
33 35 from paste.httpexceptions import HTTPUnauthorized, HTTPForbidden, get_exception
34 from paste.httpheaders import WWW_AUTHENTICATE, AUTHORIZATION
35 36
36 37 import rhodecode
37 38 from rhodecode.authentication.base import VCS_TYPE
38 39 from rhodecode.lib import auth, utils2
39 40 from rhodecode.lib import helpers as h
40 41 from rhodecode.lib.auth import AuthUser, CookieStoreWrapper
41 42 from rhodecode.lib.exceptions import UserCreationError
42 43 from rhodecode.lib.utils import (password_changed, get_enabled_hook_classes)
43 from rhodecode.lib.utils2 import (
44 str2bool, safe_unicode, AttributeDict, safe_int, sha1, aslist, safe_str)
44 from rhodecode.lib.utils2 import AttributeDict
45 from rhodecode.lib.str_utils import ascii_bytes, safe_int, safe_str
46 from rhodecode.lib.type_utils import aslist, str2bool
47 from rhodecode.lib.hash_utils import sha1
45 48 from rhodecode.model.db import Repository, User, ChangesetComment, UserBookmark
46 49 from rhodecode.model.notification import NotificationModel
47 50 from rhodecode.model.settings import VcsSettingsModel, SettingsModel
48 51
49 52 log = logging.getLogger(__name__)
50 53
51 54
52 55 def _filter_proxy(ip):
53 56 """
54 57 IP addresses passed in via headers can come as a comma separated list of
55 58 multiple IPs, appended by the various proxies in the request processing
56 59 chain. The left-most entry is the original client, and that first IP is
57 60 the only one we care about.
58 61
59 62 :param ip: ip string from headers
60 63 """
61 64 if ',' in ip:
62 65 _ips = ip.split(',')
63 66 _first_ip = _ips[0].strip()
64 67 log.debug('Got multiple IPs %s, using %s', ','.join(_ips), _first_ip)
65 68 return _first_ip
66 69 return ip
67 70
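
For reference, a quick sketch of the left-most-IP rule described above (addresses are made-up examples):

    # only the first, left-most entry of a proxy chain survives
    assert _filter_proxy('203.0.113.7, 10.0.0.2, 10.0.0.1') == '203.0.113.7'
    assert _filter_proxy('203.0.113.7') == '203.0.113.7'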
68 71
69 72 def _filter_port(ip):
70 73 """
71 74 Removes a port from an ip; there are 4 main cases to handle here:
72 75 - ipv4 eg. 127.0.0.1
73 76 - ipv6 eg. ::1
74 77 - ipv4+port eg. 127.0.0.1:8080
75 78 - ipv6+port eg. [::1]:8080
76 79
77 80 :param ip:
78 81 """
79 82 def is_ipv6(ip_addr):
80 83 if hasattr(socket, 'inet_pton'):
81 84 try:
82 85 socket.inet_pton(socket.AF_INET6, ip_addr)
83 86 except socket.error:
84 87 return False
85 88 else:
86 89 # fallback to ipaddress
87 90 try:
88 91 ipaddress.IPv6Address(safe_str(ip_addr))
89 92 except Exception:
90 93 return False
91 94 return True
92 95
93 96 if ':' not in ip: # must be ipv4 pure ip
94 97 return ip
95 98
96 99 if '[' in ip and ']' in ip: # ipv6 with port
97 100 return ip.split(']')[0][1:].lower()
98 101
99 102 # must be ipv6 or ipv4 with port
100 103 if is_ipv6(ip):
101 104 return ip
102 105 else:
103 106 ip, _port = ip.split(':')[:2] # means ipv4+port
104 107 return ip
105 108
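
The four documented cases, sketched as expected outputs (illustrative values):

    assert _filter_port('127.0.0.1') == '127.0.0.1'       # ipv4
    assert _filter_port('::1') == '::1'                   # ipv6
    assert _filter_port('127.0.0.1:8080') == '127.0.0.1'  # ipv4+port
    assert _filter_port('[::1]:8080') == '::1'            # ipv6+port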
106 109
107 110 def get_ip_addr(environ):
108 111 proxy_key = 'HTTP_X_REAL_IP'
109 112 proxy_key2 = 'HTTP_X_FORWARDED_FOR'
110 113 def_key = 'REMOTE_ADDR'
111 _filters = lambda x: _filter_port(_filter_proxy(x))
114
115 def ip_filters(ip_):
116 return _filter_port(_filter_proxy(ip_))
112 117
113 118 ip = environ.get(proxy_key)
114 119 if ip:
115 return _filters(ip)
120 return ip_filters(ip)
116 121
117 122 ip = environ.get(proxy_key2)
118 123 if ip:
119 return _filters(ip)
124 return ip_filters(ip)
120 125
121 126 ip = environ.get(def_key, '0.0.0.0')
122 return _filters(ip)
127 return ip_filters(ip)
123 128
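
Header precedence sketch, assuming a plain WSGI environ dict (values are made up):

    environ = {'HTTP_X_FORWARDED_FOR': '203.0.113.7, 10.0.0.2',
               'REMOTE_ADDR': '10.0.0.1'}
    # X-Real-IP wins over X-Forwarded-For, which wins over REMOTE_ADDR
    assert get_ip_addr(environ) == '203.0.113.7'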
124 129
125 130 def get_server_ip_addr(environ, log_errors=True):
126 131 hostname = environ.get('SERVER_NAME')
127 132 try:
128 133 return socket.gethostbyname(hostname)
129 134 except Exception as e:
130 135 if log_errors:
131 136 # in some cases this lookup is not possible; we don't want that to be
132 137 # fatal, so log it and fall back to the hostname
133 138 log.exception('Could not retrieve server ip address: %s', e)
134 139 return hostname
135 140
136 141
137 142 def get_server_port(environ):
138 143 return environ.get('SERVER_PORT')
139 144
140 145
141 def get_access_path(environ):
142 path = environ.get('PATH_INFO')
143 org_req = environ.get('pylons.original_request')
144 if org_req:
145 path = org_req.environ.get('PATH_INFO')
146 return path
147
148 146
149 147 def get_user_agent(environ):
150 148 return environ.get('HTTP_USER_AGENT')
151 149
152 150
153 151 def vcs_operation_context(
154 152 environ, repo_name, username, action, scm, check_locking=True,
155 153 is_shadow_repo=False, check_branch_perms=False, detect_force_push=False):
156 154 """
157 155 Generate the context for a vcs operation, e.g. push or pull.
158 156
159 157 This context is passed over the layers so that hooks triggered by the
160 158 vcs operation know details like the user, the user's IP address etc.
161 159
162 160 :param check_locking: Allows switching off the computation of the locking
163 161 data. This mainly serves the simplevcs middleware, which needs to be
164 162 able to disable it for certain operations.
165 163
166 164 """
167 165 # Tri-state value: False: unlock, None: nothing, True: lock
168 166 make_lock = None
169 167 locked_by = [None, None, None]
170 168 is_anonymous = username == User.DEFAULT_USER
171 169 user = User.get_by_username(username)
172 170 if not is_anonymous and check_locking:
173 171 log.debug('Checking locking on repository "%s"', repo_name)
174 172 repo = Repository.get_by_repo_name(repo_name)
175 173 make_lock, __, locked_by = repo.get_locking_state(
176 174 action, user.user_id)
177 175 user_id = user.user_id
178 176 settings_model = VcsSettingsModel(repo=repo_name)
179 177 ui_settings = settings_model.get_ui_settings()
180 178
181 179 # NOTE(marcink): This should be also in sync with
182 180 # rhodecode/apps/ssh_support/lib/backends/base.py:update_environment scm_data
183 181 store = [x for x in ui_settings if x.key == '/']
184 182 repo_store = ''
185 183 if store:
186 184 repo_store = store[0].value
187 185
188 186 scm_data = {
189 187 'ip': get_ip_addr(environ),
190 188 'username': username,
191 189 'user_id': user_id,
192 190 'action': action,
193 191 'repository': repo_name,
194 192 'scm': scm,
195 193 'config': rhodecode.CONFIG['__file__'],
196 194 'repo_store': repo_store,
197 195 'make_lock': make_lock,
198 196 'locked_by': locked_by,
199 197 'server_url': utils2.get_server_url(environ),
200 198 'user_agent': get_user_agent(environ),
201 199 'hooks': get_enabled_hook_classes(ui_settings),
202 200 'is_shadow_repo': is_shadow_repo,
203 201 'detect_force_push': detect_force_push,
204 202 'check_branch_perms': check_branch_perms,
205 203 }
206 204 return scm_data
207 205
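
A sketch of a typical call and the shape of the result (field values illustrative):

    scm_data = vcs_operation_context(
        environ, repo_name='some/repo', username='admin',
        action='push', scm='git', check_locking=False)
    # scm_data carries ip/username/action/repository, hook config and the
    # locking tri-state, and is handed over to the vcs hooks layer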
208 206
209 207 class BasicAuth(AuthBasicAuthenticator):
210 208
211 209 def __init__(self, realm, authfunc, registry, auth_http_code=None,
212 210 initial_call_detection=False, acl_repo_name=None, rc_realm=''):
211 super(BasicAuth, self).__init__(realm=realm, authfunc=authfunc)
213 212 self.realm = realm
214 213 self.rc_realm = rc_realm
215 214 self.initial_call = initial_call_detection
216 215 self.authfunc = authfunc
217 216 self.registry = registry
218 217 self.acl_repo_name = acl_repo_name
219 218 self._rc_auth_http_code = auth_http_code
220 219
221 def _get_response_from_code(self, http_code):
220 def _get_response_from_code(self, http_code, fallback):
222 221 try:
223 222 return get_exception(safe_int(http_code))
224 223 except Exception:
225 log.exception('Failed to fetch response for code %s', http_code)
226 return HTTPForbidden
224 log.exception('Failed to fetch response class for code %s, using fallback: %s', http_code, fallback)
225 return fallback
227 226
228 227 def get_rc_realm(self):
229 228 return safe_str(self.rc_realm)
230 229
231 230 def build_authentication(self):
232 head = WWW_AUTHENTICATE.tuples('Basic realm="%s"' % self.realm)
231 header = [('WWW-Authenticate', f'Basic realm="{self.realm}"')]
232
233 # NOTE: the initial_call detection seems to be not working/not needed with latest Mercurial
234 # investigate if we still need it.
233 235 if self._rc_auth_http_code and not self.initial_call:
234 236 # return alternative HTTP code if alternative http return code
235 237 # is specified in RhodeCode config, but ONLY if it's not the
236 238 # FIRST call
237 custom_response_klass = self._get_response_from_code(
238 self._rc_auth_http_code)
239 return custom_response_klass(headers=head)
240 return HTTPUnauthorized(headers=head)
239 custom_response_klass = self._get_response_from_code(self._rc_auth_http_code, fallback=HTTPUnauthorized)
240 log.debug('Using custom response class: %s', custom_response_klass)
241 return custom_response_klass(headers=header)
242 return HTTPUnauthorized(headers=header)
241 243
242 244 def authenticate(self, environ):
243 authorization = AUTHORIZATION(environ)
245 authorization = paste.httpheaders.AUTHORIZATION(environ)
244 246 if not authorization:
245 247 return self.build_authentication()
246 (authmeth, auth) = authorization.split(' ', 1)
247 if 'basic' != authmeth.lower():
248 (auth_meth, auth_creds_b64) = authorization.split(' ', 1)
249 if 'basic' != auth_meth.lower():
248 250 return self.build_authentication()
249 auth = auth.strip().decode('base64')
250 _parts = auth.split(':', 1)
251
252 credentials = safe_str(base64.b64decode(auth_creds_b64.strip()))
253 _parts = credentials.split(':', 1)
251 254 if len(_parts) == 2:
252 255 username, password = _parts
253 256 auth_data = self.authfunc(
254 257 username, password, environ, VCS_TYPE,
255 258 registry=self.registry, acl_repo_name=self.acl_repo_name)
256 259 if auth_data:
257 260 return {'username': username, 'auth_data': auth_data}
258 261 if username and password:
259 262 # we mark that we actually executed authentication once, at
260 263 # that point we can use the alternative auth code
261 264 self.initial_call = False
262 265
263 266 return self.build_authentication()
264 267
265 268 __call__ = authenticate
266 269
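
The base64 handling above replaces the Python 2-only `auth.strip().decode('base64')` idiom; a sketch of the equivalent round-trip with the stdlib:

    import base64
    header_value = base64.b64encode(b'user:secret')  # what follows 'Basic ' on the wire
    credentials = base64.b64decode(header_value).decode('utf8')  # 'user:secret'
    username, password = credentials.split(':', 1)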
267 270
268 271 def calculate_version_hash(config):
269 272 return sha1(
270 config.get('beaker.session.secret', '') +
271 rhodecode.__version__)[:8]
273 config.get(b'beaker.session.secret', b'') + ascii_bytes(rhodecode.__version__)
274 )[:8]
272 275
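
A sketch of what the hash computes, assuming the `sha1` helper wraps hashlib and returns an ascii hex digest (values illustrative):

    import hashlib
    secret, version = b'some-secret', b'4.28.0'
    version_hash = hashlib.sha1(secret + version).hexdigest()[:8]  # e.g. '1a2b3c4d'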
273 276
274 277 def get_current_lang(request):
275 # NOTE(marcink): remove after pyramid move
276 try:
277 return translation.get_lang()[0]
278 except:
279 pass
280
281 278 return getattr(request, '_LOCALE_', request.locale_name)
282 279
283 280
284 281 def attach_context_attributes(context, request, user_id=None, is_api=None):
285 282 """
286 283 Attach variables into template context called `c`.
287 284 """
288 285 config = request.registry.settings
289 286
290 287 rc_config = SettingsModel().get_all_settings(cache=True, from_request=False)
291 288 context.rc_config = rc_config
292 289 context.rhodecode_version = rhodecode.__version__
293 290 context.rhodecode_edition = config.get('rhodecode.edition')
294 291 context.rhodecode_edition_id = config.get('rhodecode.edition_id')
295 292 # unique secret + version does not leak the version but keep consistency
296 293 context.rhodecode_version_hash = calculate_version_hash(config)
297 294
298 295 # Default language set for the incoming request
299 296 context.language = get_current_lang(request)
300 297
301 298 # Visual options
302 299 context.visual = AttributeDict({})
303 300
304 301 # DB stored Visual Items
305 302 context.visual.show_public_icon = str2bool(
306 303 rc_config.get('rhodecode_show_public_icon'))
307 304 context.visual.show_private_icon = str2bool(
308 305 rc_config.get('rhodecode_show_private_icon'))
309 306 context.visual.stylify_metatags = str2bool(
310 307 rc_config.get('rhodecode_stylify_metatags'))
311 308 context.visual.dashboard_items = safe_int(
312 309 rc_config.get('rhodecode_dashboard_items', 100))
313 310 context.visual.admin_grid_items = safe_int(
314 311 rc_config.get('rhodecode_admin_grid_items', 100))
315 312 context.visual.show_revision_number = str2bool(
316 313 rc_config.get('rhodecode_show_revision_number', True))
317 314 context.visual.show_sha_length = safe_int(
318 315 rc_config.get('rhodecode_show_sha_length', 100))
319 316 context.visual.repository_fields = str2bool(
320 317 rc_config.get('rhodecode_repository_fields'))
321 318 context.visual.show_version = str2bool(
322 319 rc_config.get('rhodecode_show_version'))
323 320 context.visual.use_gravatar = str2bool(
324 321 rc_config.get('rhodecode_use_gravatar'))
325 322 context.visual.gravatar_url = rc_config.get('rhodecode_gravatar_url')
326 323 context.visual.default_renderer = rc_config.get(
327 324 'rhodecode_markup_renderer', 'rst')
328 325 context.visual.comment_types = ChangesetComment.COMMENT_TYPES
329 326 context.visual.rhodecode_support_url = \
330 327 rc_config.get('rhodecode_support_url') or h.route_url('rhodecode_support')
331 328
332 329 context.visual.affected_files_cut_off = 60
333 330
334 331 context.pre_code = rc_config.get('rhodecode_pre_code')
335 332 context.post_code = rc_config.get('rhodecode_post_code')
336 333 context.rhodecode_name = rc_config.get('rhodecode_title')
337 334 context.default_encodings = aslist(config.get('default_encoding'), sep=',')
338 335 # if we have specified default_encoding in the request, it has more
339 336 # priority
340 337 if request.GET.get('default_encoding'):
341 338 context.default_encodings.insert(0, request.GET.get('default_encoding'))
342 339 context.clone_uri_tmpl = rc_config.get('rhodecode_clone_uri_tmpl')
343 340 context.clone_uri_id_tmpl = rc_config.get('rhodecode_clone_uri_id_tmpl')
344 341 context.clone_uri_ssh_tmpl = rc_config.get('rhodecode_clone_uri_ssh_tmpl')
345 342
346 343 # INI stored
347 344 context.labs_active = str2bool(
348 345 config.get('labs_settings_active', 'false'))
349 346 context.ssh_enabled = str2bool(
350 347 config.get('ssh.generate_authorized_keyfile', 'false'))
351 348 context.ssh_key_generator_enabled = str2bool(
352 349 config.get('ssh.enable_ui_key_generator', 'true'))
353 350
354 351 context.visual.allow_repo_location_change = str2bool(
355 352 config.get('allow_repo_location_change', True))
356 353 context.visual.allow_custom_hooks_settings = str2bool(
357 354 config.get('allow_custom_hooks_settings', True))
358 355 context.debug_style = str2bool(config.get('debug_style', False))
359 356
360 357 context.rhodecode_instanceid = config.get('instance_id')
361 358
362 359 context.visual.cut_off_limit_diff = safe_int(
363 config.get('cut_off_limit_diff'))
360 config.get('cut_off_limit_diff'), default=0)
364 361 context.visual.cut_off_limit_file = safe_int(
365 config.get('cut_off_limit_file'))
362 config.get('cut_off_limit_file'), default=0)
366 363
367 364 context.license = AttributeDict({})
368 365 context.license.hide_license_info = str2bool(
369 366 config.get('license.hide_license_info', False))
370 367
371 368 # AppEnlight
372 369 context.appenlight_enabled = config.get('appenlight', False)
373 370 context.appenlight_api_public_key = config.get(
374 371 'appenlight.api_public_key', '')
375 372 context.appenlight_server_url = config.get('appenlight.server_url', '')
376 373
377 374 diffmode = {
378 375 "unified": "unified",
379 376 "sideside": "sideside"
380 377 }.get(request.GET.get('diffmode'))
381 378
382 379 if is_api is not None:
383 380 is_api = hasattr(request, 'rpc_user')
384 381 session_attrs = {
385 382 # defaults
386 383 "clone_url_format": "http",
387 384 "diffmode": "sideside",
388 385 "license_fingerprint": request.session.get('license_fingerprint')
389 386 }
390 387
391 388 if not is_api:
392 389 # don't access pyramid session for API calls
393 390 if diffmode and diffmode != request.session.get('rc_user_session_attr.diffmode'):
394 391 request.session['rc_user_session_attr.diffmode'] = diffmode
395 392
396 393 # session settings per user
397 394
398 for k, v in request.session.items():
395 for k, v in list(request.session.items()):
399 396 pref = 'rc_user_session_attr.'
400 397 if k and k.startswith(pref):
401 398 k = k[len(pref):]
402 399 session_attrs[k] = v
403 400
404 401 context.user_session_attrs = session_attrs
405 402
406 403 # JS template context
407 404 context.template_context = {
408 405 'repo_name': None,
409 406 'repo_type': None,
410 407 'repo_landing_commit': None,
411 408 'rhodecode_user': {
412 409 'username': None,
413 410 'email': None,
414 411 'notification_status': False
415 412 },
416 413 'session_attrs': session_attrs,
417 414 'visual': {
418 415 'default_renderer': None
419 416 },
420 417 'commit_data': {
421 418 'commit_id': None
422 419 },
423 420 'pull_request_data': {'pull_request_id': None},
424 421 'timeago': {
425 422 'refresh_time': 120 * 1000,
426 423 'cutoff_limit': 1000 * 60 * 60 * 24 * 7
427 424 },
428 425 'pyramid_dispatch': {
429 426
430 427 },
431 428 'extra': {'plugins': {}}
432 429 }
433 430 # END CONFIG VARS
434 431 if is_api:
435 432 csrf_token = None
436 433 else:
437 434 csrf_token = auth.get_csrf_token(session=request.session)
438 435
439 436 context.csrf_token = csrf_token
440 context.backends = rhodecode.BACKENDS.keys()
437 context.backends = list(rhodecode.BACKENDS.keys())
441 438
442 439 unread_count = 0
443 440 user_bookmark_list = []
444 441 if user_id:
445 442 unread_count = NotificationModel().get_unread_cnt_for_user(user_id)
446 443 user_bookmark_list = UserBookmark.get_bookmarks_for_user(user_id)
447 444 context.unread_notifications = unread_count
448 445 context.bookmark_items = user_bookmark_list
449 446
450 447 # web case
451 448 if hasattr(request, 'user'):
452 449 context.auth_user = request.user
453 450 context.rhodecode_user = request.user
454 451
455 452 # api case
456 453 if hasattr(request, 'rpc_user'):
457 454 context.auth_user = request.rpc_user
458 455 context.rhodecode_user = request.rpc_user
459 456
460 457 # attach the whole call context to the request
461 458 request.set_call_context(context)
462 459
463 460
464 461 def get_auth_user(request):
465 462 environ = request.environ
466 463 session = request.session
467 464
468 465 ip_addr = get_ip_addr(environ)
469 466
470 467 # make sure that we update permissions each time we call controller
471 468 _auth_token = (
472 469 # ?auth_token=XXX
473 470 request.GET.get('auth_token', '')
474 471 # ?api_key=XXX !LEGACY
475 472 or request.GET.get('api_key', '')
476 473 # or headers....
477 474 or request.headers.get('X-Rc-Auth-Token', '')
478 475 )
479 476 if not _auth_token and request.matchdict:
480 477 url_auth_token = request.matchdict.get('_auth_token')
481 478 _auth_token = url_auth_token
482 479 if _auth_token:
483 480 log.debug('Using URL extracted auth token `...%s`', _auth_token[-4:])
484 481
485 482 if _auth_token:
486 483 # when using API_KEY we assume user exists, and
487 484 # doesn't need auth based on cookies.
488 485 auth_user = AuthUser(api_key=_auth_token, ip_addr=ip_addr)
489 486 authenticated = False
490 487 else:
491 488 cookie_store = CookieStoreWrapper(session.get('rhodecode_user'))
492 489 try:
493 490 auth_user = AuthUser(user_id=cookie_store.get('user_id', None),
494 491 ip_addr=ip_addr)
495 492 except UserCreationError as e:
496 493 h.flash(e, 'error')
497 494 # container auth or other auth functions that create users
498 495 # on the fly can throw this exception signaling that there's
499 496 # issue with user creation, explanation should be provided
500 497 # in Exception itself. We then create a simple blank
501 498 # AuthUser
502 499 auth_user = AuthUser(ip_addr=ip_addr)
503 500
504 501 # in case someone changes a password for user it triggers session
505 502 # flush and forces a re-login
506 503 if password_changed(auth_user, session):
507 504 session.invalidate()
508 505 cookie_store = CookieStoreWrapper(session.get('rhodecode_user'))
509 506 auth_user = AuthUser(ip_addr=ip_addr)
510 507
511 508 authenticated = cookie_store.get('is_authenticated')
512 509
513 510 if not auth_user.is_authenticated and auth_user.is_user_object:
514 511 # user is not authenticated and not empty
515 512 auth_user.set_authenticated(authenticated)
516 513
517 514 return auth_user, _auth_token
518 515
519 516
520 517 def h_filter(s):
521 518 """
522 519 Custom filter for Mako templates. Mako by default uses `markupsafe.escape`;
523 520 we wrap this with additional functionality that converts None to empty
524 521 strings
525 522 """
526 523 if s is None:
527 524 return markupsafe.Markup()
528 525 return markupsafe.escape(s)
529 526
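
Behaviour sketch: None becomes empty markup instead of the literal string 'None':

    assert str(h_filter(None)) == ''
    assert str(h_filter('<b>')) == '&lt;b&gt;'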
530 527
531 528 def add_events_routes(config):
532 529 """
533 530 Adds routing that can be used in events. Because some events are triggered
534 531 outside of the pyramid context, we need to bootstrap the request with some
535 532 routing registered
536 533 """
537 534
538 535 from rhodecode.apps._base import ADMIN_PREFIX
539 536
540 537 config.add_route(name='home', pattern='/')
541 538 config.add_route(name='main_page_repos_data', pattern='/_home_repos')
542 539 config.add_route(name='main_page_repo_groups_data', pattern='/_home_repo_groups')
543 540
544 541 config.add_route(name='login', pattern=ADMIN_PREFIX + '/login')
545 542 config.add_route(name='logout', pattern=ADMIN_PREFIX + '/logout')
546 543 config.add_route(name='repo_summary', pattern='/{repo_name}')
547 544 config.add_route(name='repo_summary_explicit', pattern='/{repo_name}/summary')
548 545 config.add_route(name='repo_group_home', pattern='/{repo_group_name}')
549 546
550 547 config.add_route(name='pullrequest_show',
551 548 pattern='/{repo_name}/pull-request/{pull_request_id}')
552 549 config.add_route(name='pull_requests_global',
553 550 pattern='/pull-request/{pull_request_id}')
554 551
555 552 config.add_route(name='repo_commit',
556 553 pattern='/{repo_name}/changeset/{commit_id}')
557 554 config.add_route(name='repo_files',
558 555 pattern='/{repo_name}/files/{commit_id}/{f_path}')
559 556
560 557 config.add_route(name='hovercard_user',
561 558 pattern='/_hovercard/user/{user_id}')
562 559
563 560 config.add_route(name='hovercard_user_group',
564 561 pattern='/_hovercard/user_group/{user_group_id}')
565 562
566 563 config.add_route(name='hovercard_pull_request',
567 564 pattern='/_hovercard/pull_request/{pull_request_id}')
568 565
569 566 config.add_route(name='hovercard_repo_commit',
570 567 pattern='/_hovercard/commit/{repo_name}/{commit_id}')
571 568
572 569
573 570 def bootstrap_config(request, registry_name='RcTestRegistry'):
574 571 import pyramid.testing
575 572 registry = pyramid.testing.Registry(registry_name)
576 573
577 574 config = pyramid.testing.setUp(registry=registry, request=request)
578 575
579 576 # allow pyramid lookup in testing
580 577 config.include('pyramid_mako')
581 578 config.include('rhodecode.lib.rc_beaker')
582 579 config.include('rhodecode.lib.rc_cache')
583
580 config.include('rhodecode.lib.rc_cache.archive_cache')
584 581 add_events_routes(config)
585 582
586 583 return config
587 584
588 585
589 586 def bootstrap_request(**kwargs):
590 587 """
591 588 Returns a thin version of the Request object, used in non-web contexts like testing/celery
592 589 """
593 590
594 591 import pyramid.testing
595 592 from rhodecode.lib.request import ThinRequest as _ThinRequest
596 593
597 594 class ThinRequest(_ThinRequest):
598 595 application_url = kwargs.pop('application_url', 'http://example.com')
599 596 host = kwargs.pop('host', 'example.com:80')
600 597 domain = kwargs.pop('domain', 'example.com')
601 598
602 599 class ThinSession(pyramid.testing.DummySession):
603 600 def save(*arg, **kw):
604 601 pass
605 602
606 603 request = ThinRequest(**kwargs)
607 604 request.session = ThinSession()
608 605
609 606 return request
610
@@ -1,248 +1,251 b''
1 1
2 2 # Copyright (C) 2010-2020 RhodeCode GmbH
3 3 #
4 4 # This program is free software: you can redistribute it and/or modify
5 5 # it under the terms of the GNU Affero General Public License, version 3
6 6 # (only), as published by the Free Software Foundation.
7 7 #
8 8 # This program is distributed in the hope that it will be useful,
9 9 # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 10 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 11 # GNU General Public License for more details.
12 12 #
13 13 # You should have received a copy of the GNU Affero General Public License
14 14 # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 15 #
16 16 # This program is dual-licensed. If you wish to learn more about the
17 17 # RhodeCode Enterprise Edition, including its added features, Support services,
18 18 # and proprietary license terms, please see https://rhodecode.com/licenses/
19 19
20 20 """caching_query.py
21 21
22 22 Represent functions and classes
23 23 which allow the usage of Dogpile caching with SQLAlchemy.
24 24 Introduces a query option called FromCache.
25 25
26 26 .. versionchanged:: 1.4 the caching approach has been altered to work
27 27 based on a session event.
28 28
29 29
30 30 The three new concepts introduced here are:
31 31
32 32 * ORMCache - an extension for an ORM :class:`.Session`
33 33 retrieves results in/from dogpile.cache.
34 34 * FromCache - a query option that establishes caching
35 35 parameters on a Query
36 36 * RelationshipCache - a variant of FromCache which is specific
37 37 to a query invoked during a lazy load.
38 38
39 39 The rest of what's here are standard SQLAlchemy and
40 40 dogpile.cache constructs.
41 41
42 42 """
43 43 from dogpile.cache.api import NO_VALUE
44 44
45 45 from sqlalchemy import event
46 46 from sqlalchemy.orm import loading
47 47 from sqlalchemy.orm.interfaces import UserDefinedOption
48 48
49 49
50 50 DEFAULT_REGION = "sql_cache_short"
51 51
52 52
53 53 class ORMCache:
54 54
55 55 """An add-on for an ORM :class:`.Session` optionally loads full results
56 56 from a dogpile cache region.
57 57
58 58 cache = ORMCache(regions={})
59 59 cache.listen_on_session(Session)
60 60
61 61 """
62 62
63 63 def __init__(self, regions):
64 64 self.cache_regions = regions or self._get_region()
65 65 self._statement_cache = {}
66 66
67 67 @classmethod
68 68 def _get_region(cls):
69 69 from rhodecode.lib.rc_cache import region_meta
70 70 return region_meta.dogpile_cache_regions
71 71
72 72 def listen_on_session(self, session_factory):
73 73 event.listen(session_factory, "do_orm_execute", self._do_orm_execute)
74 74
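
A minimal wiring sketch for the cache, following the class docstring (a configured dogpile region dict is assumed):

    from sqlalchemy.orm import sessionmaker

    Session = sessionmaker()
    cache = ORMCache(regions={})   # empty dict falls back to rhodecode's regions
    cache.listen_on_session(Session)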
75 75 def _do_orm_execute(self, orm_context):
76
77 76 for opt in orm_context.user_defined_options:
78 77 if isinstance(opt, RelationshipCache):
79 78 opt = opt._process_orm_context(orm_context)
80 79 if opt is None:
81 80 continue
82 81
83 82 if isinstance(opt, FromCache):
84 83 dogpile_region = self.cache_regions[opt.region]
85 84
85 if dogpile_region.expiration_time <= 0:
86 # don't cache 0 time expiration cache
87 continue
88
86 89 if opt.cache_key:
87 90 our_cache_key = f'SQL_CACHE_{opt.cache_key}'
88 91 else:
89 92 our_cache_key = opt._generate_cache_key(
90 93 orm_context.statement, orm_context.parameters, self
91 94 )
92 95
93 96 if opt.ignore_expiration:
94 97 cached_value = dogpile_region.get(
95 98 our_cache_key,
96 99 expiration_time=opt.expiration_time,
97 100 ignore_expiration=opt.ignore_expiration,
98 101 )
99 102 else:
100 103
101 104 def createfunc():
102 105 return orm_context.invoke_statement().freeze()
103 106
104 107 cached_value = dogpile_region.get_or_create(
105 108 our_cache_key,
106 109 createfunc,
107 110 expiration_time=opt.expiration_time,
108 111 )
109 112
110 113 if cached_value is NO_VALUE:
111 114 # keyerror? this is bigger than a keyerror...
112 115 raise KeyError()
113 116
114 117 orm_result = loading.merge_frozen_result(
115 118 orm_context.session,
116 119 orm_context.statement,
117 120 cached_value,
118 121 load=False,
119 122 )
120 123 return orm_result()
121 124
122 125 else:
123 126 return None
124 127
125 128 def invalidate(self, statement, parameters, opt):
126 129 """Invalidate the cache value represented by a statement."""
127 130
128 131 statement = statement.__clause_element__()
129 132
130 133 dogpile_region = self.cache_regions[opt.region]
131 134
132 135 cache_key = opt._generate_cache_key(statement, parameters, self)
133 136
134 137 dogpile_region.delete(cache_key)
135 138
136 139
137 140 class FromCache(UserDefinedOption):
138 141 """Specifies that a Query should load results from a cache."""
139 142
140 143 propagate_to_loaders = False
141 144
142 145 def __init__(
143 146 self,
144 147 region=DEFAULT_REGION,
145 148 cache_key=None,
146 149 expiration_time=None,
147 150 ignore_expiration=False,
148 151 ):
149 152 """Construct a new FromCache.
150 153
151 154 :param region: the cache region. Should be a
152 155 region configured in the dictionary of dogpile
153 156 regions.
154 157
155 158 :param cache_key: optional. A string cache key
156 159 that will serve as the key to the query. Use this
157 160 if your query has a huge amount of parameters (such
158 161 as when using in_()) which correspond more simply to
159 162 some other identifier.
160 163
161 164 """
162 165 self.region = region
163 166 self.cache_key = cache_key
164 167 self.expiration_time = expiration_time
165 168 self.ignore_expiration = ignore_expiration
166 169
167 170 # this is not needed as of SQLAlchemy 1.4.28;
168 171 # UserDefinedOption classes no longer participate in the SQL
169 172 # compilation cache key
170 173 def _gen_cache_key(self, anon_map, bindparams):
171 174 return None
172 175
173 176 def _generate_cache_key(self, statement, parameters, orm_cache):
174 177 """generate a cache key with which to key the results of a statement.
175 178
176 179 This leverages the use of the SQL compilation cache key which is
177 180 repurposed as a SQL results key.
178 181
179 182 """
180 183 statement_cache_key = statement._generate_cache_key()
181 184
182 185 key = statement_cache_key.to_offline_string(
183 186 orm_cache._statement_cache, statement, parameters
184 187 ) + repr(self.cache_key)
185 188 # print("here's our key...%s" % key)
186 189 return key
187 190
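
A usage sketch for this option, assuming a mapped `User` class, an existing `session` wired via `listen_on_session`, and a region named `sql_cache_short` (SQLAlchemy 1.4 style):

    from sqlalchemy import select

    stmt = select(User).options(FromCache('sql_cache_short', cache_key='user_admin'))
    users = session.execute(stmt).scalars().all()  # repeat calls hit dogpile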
188 191
189 192 class RelationshipCache(FromCache):
190 193 """Specifies that a Query as called within a "lazy load"
191 194 should load results from a cache."""
192 195
193 196 propagate_to_loaders = True
194 197
195 198 def __init__(
196 199 self,
197 200 attribute,
198 201 region=DEFAULT_REGION,
199 202 cache_key=None,
200 203 expiration_time=None,
201 204 ignore_expiration=False,
202 205 ):
203 206 """Construct a new RelationshipCache.
204 207
205 208 :param attribute: A Class.attribute which
206 209 indicates a particular class relationship() whose
207 210 lazy loader should be pulled from the cache.
208 211
209 212 :param region: name of the cache region.
210 213
211 214 :param cache_key: optional. A string cache key
212 215 that will serve as the key to the query, bypassing
213 216 the usual means of forming a key from the Query itself.
214 217
215 218 """
216 219 self.region = region
217 220 self.cache_key = cache_key
218 221 self.expiration_time = expiration_time
219 222 self.ignore_expiration = ignore_expiration
220 223 self._relationship_options = {
221 224 (attribute.property.parent.class_, attribute.property.key): self
222 225 }
223 226
224 227 def _process_orm_context(self, orm_context):
225 228 current_path = orm_context.loader_strategy_path
226 229
227 230 if current_path:
228 231 mapper, prop = current_path[-2:]
229 232 key = prop.key
230 233
231 234 for cls in mapper.class_.__mro__:
232 235 if (cls, key) in self._relationship_options:
233 236 relationship_option = self._relationship_options[
234 237 (cls, key)
235 238 ]
236 239 return relationship_option
237 240
238 241 def and_(self, option):
239 242 """Chain another RelationshipCache option to this one.
240 243
241 244 While many RelationshipCache objects can be specified on a single
242 245 Query separately, chaining them together allows for a more efficient
243 246 lookup during load.
244 247
245 248 """
246 249 self._relationship_options.update(option._relationship_options)
247 250 return self
248 251
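
And a lazy-load sketch; `User.permissions` is a hypothetical relationship here, and `propagate_to_loaders=True` lets the option ride along into the lazy loader:

    user_rel = RelationshipCache(User.permissions, 'sql_cache_short')
    user = session.execute(select(User).options(user_rel)).scalars().first()
    perms = user.permissions  # lazy load is served through the dogpile region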
@@ -1,372 +1,372 b''
1 1
2 2
3 3 # Copyright (C) 2016-2020 RhodeCode GmbH
4 4 #
5 5 # This program is free software: you can redistribute it and/or modify
6 6 # it under the terms of the GNU Affero General Public License, version 3
7 7 # (only), as published by the Free Software Foundation.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU Affero General Public License
15 15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 16 #
17 17 # This program is dual-licensed. If you wish to learn more about the
18 18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20 20
21 21 import os
22 22 import itsdangerous
23 23 import logging
24 24 import requests
25 25 import datetime
26 26
27 27 from dogpile.util.readwrite_lock import ReadWriteMutex
28 from pyramid.threadlocal import get_current_registry
29 28
30 29 import rhodecode.lib.helpers as h
31 30 from rhodecode.lib.auth import HasRepoPermissionAny
32 31 from rhodecode.lib.ext_json import json
33 32 from rhodecode.model.db import User
34 33 from rhodecode.lib.str_utils import ascii_str
35 34 from rhodecode.lib.hash_utils import sha1_safe
36 35
37 36 log = logging.getLogger(__name__)
38 37
39 38 LOCK = ReadWriteMutex()
40 39
41 40 USER_STATE_PUBLIC_KEYS = [
42 41 'id', 'username', 'first_name', 'last_name',
43 42 'icon_link', 'display_name', 'display_link']
44 43
45 44
46 45 class ChannelstreamException(Exception):
47 46 pass
48 47
49 48
50 49 class ChannelstreamConnectionException(ChannelstreamException):
51 50 pass
52 51
53 52
54 53 class ChannelstreamPermissionException(ChannelstreamException):
55 54 pass
56 55
57 56
58 57 def get_channelstream_server_url(config, endpoint):
59 58 return 'http://{}{}'.format(config['server'], endpoint)
60 59
61 60
62 61 def channelstream_request(config, payload, endpoint, raise_exc=True):
63 62 signer = itsdangerous.TimestampSigner(config['secret'])
64 63 sig_for_server = signer.sign(endpoint)
65 64 secret_headers = {'x-channelstream-secret': sig_for_server,
66 65 'x-channelstream-endpoint': endpoint,
67 66 'Content-Type': 'application/json'}
68 67 req_url = get_channelstream_server_url(config, endpoint)
69 68
70 69 log.debug('Sending a channelstream request to endpoint: `%s`', req_url)
71 70 response = None
72 71 try:
73 72 response = requests.post(req_url, data=json.dumps(payload),
74 73 headers=secret_headers).json()
75 74 except requests.ConnectionError:
76 75 log.exception('ConnectionError occurred for endpoint %s', req_url)
77 76 if raise_exc:
78 77 raise ChannelstreamConnectionException(req_url)
79 78 except Exception:
80 79 log.exception('Exception related to Channelstream happened')
81 80 if raise_exc:
82 81 raise ChannelstreamConnectionException()
83 82 log.debug('Got channelstream response: %s', response)
84 83 return response
85 84
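
The signing scheme above, sketched from the client side (secret and endpoint are placeholders):

    import itsdangerous

    signer = itsdangerous.TimestampSigner('channelstream-secret')
    sig = signer.sign('/message')   # b'/message.<timestamp>.<signature>'
    # the server unsigns with the same secret (and a max_age) to reject
    # stale or tampered requests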
86 85
87 86 def get_user_data(user_id):
88 87 user = User.get(user_id)
89 88 return {
90 89 'id': user.user_id,
91 90 'username': user.username,
92 91 'first_name': user.first_name,
93 92 'last_name': user.last_name,
94 93 'icon_link': h.gravatar_url(user.email, 60),
95 94 'display_name': h.person(user, 'username_or_name_or_email'),
96 95 'display_link': h.link_to_user(user),
97 96 'notifications': user.user_data.get('notification_status', True)
98 97 }
99 98
100 99
101 100 def broadcast_validator(channel_name):
102 101 """ checks if user can access the broadcast channel """
103 102 if channel_name == 'broadcast':
104 103 return True
105 104
106 105
107 106 def repo_validator(channel_name):
108 107 """ checks if user can access the broadcast channel """
109 108 channel_prefix = '/repo$'
110 109 if channel_name.startswith(channel_prefix):
111 110 elements = channel_name[len(channel_prefix):].split('$')
112 111 repo_name = elements[0]
113 112 can_access = HasRepoPermissionAny(
114 113 'repository.read',
115 114 'repository.write',
116 115 'repository.admin')(repo_name)
117 116 log.debug(
118 117 'permission check for %s channel resulted in %s',
119 118 repo_name, can_access)
120 119 if can_access:
121 120 return True
122 121 return False
123 122
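
Channel names follow a `/repo$<repo_name>$<suffix>` scheme; a parsing sketch with made-up values:

    channel_name = '/repo$some/repo$/pr/42'
    elements = channel_name[len('/repo$'):].split('$')
    repo_name = elements[0]   # 'some/repo', checked against read/write/admin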
124 123
125 124 def check_channel_permissions(channels, plugin_validators, should_raise=True):
126 125 valid_channels = []
127 126
128 127 validators = [broadcast_validator, repo_validator]
129 128 if plugin_validators:
130 129 validators.extend(plugin_validators)
131 130 for channel_name in channels:
132 131 is_valid = False
133 132 for validator in validators:
134 133 if validator(channel_name):
135 134 is_valid = True
136 135 break
137 136 if is_valid:
138 137 valid_channels.append(channel_name)
139 138 else:
140 139 if should_raise:
141 140 raise ChannelstreamPermissionException()
142 141 return valid_channels
143 142
144 143
145 144 def get_channels_info(self, channels):
146 145 payload = {'channels': channels}
147 146 # gather persistence info
148 147 return channelstream_request(self._config(), payload, '/info')
149 148
150 149
151 150 def parse_channels_info(info_result, include_channel_info=None):
152 151 """
153 152 Returns data that contains only secure information that can be
154 153 presented to clients
155 154 """
156 155 include_channel_info = include_channel_info or []
157 156
158 157 user_state_dict = {}
159 158 for userinfo in info_result['users']:
160 159 user_state_dict[userinfo['user']] = {
161 160 k: v for k, v in list(userinfo['state'].items())
162 161 if k in USER_STATE_PUBLIC_KEYS
163 162 }
164 163
165 164 channels_info = {}
166 165
167 166 for c_name, c_info in list(info_result['channels'].items()):
168 167 if c_name not in include_channel_info:
169 168 continue
170 169 connected_list = []
171 170 for username in c_info['users']:
172 171 connected_list.append({
173 172 'user': username,
174 173 'state': user_state_dict[username]
175 174 })
176 175 channels_info[c_name] = {'users': connected_list,
177 176 'history': c_info['history']}
178 177
179 178 return channels_info
180 179
181 180
182 181 def log_filepath(history_location, channel_name):
183 182
184 183 channel_hash = ascii_str(sha1_safe(channel_name))
185 184 filename = f'{channel_hash}.log'
186 185 filepath = os.path.join(history_location, filename)
187 186 return filepath
188 187
189 188
190 189 def read_history(history_location, channel_name):
191 190 filepath = log_filepath(history_location, channel_name)
192 191 if not os.path.exists(filepath):
193 192 return []
194 193 history_lines_limit = -100
195 194 history = []
196 195 with open(filepath, 'rb') as f:
197 196 for line in f.readlines()[history_lines_limit:]:
198 197 try:
199 198 history.append(json.loads(line))
200 199 except Exception:
201 200 log.exception('Failed to load history')
202 201 return history
203 202
204 203
205 204 def update_history_from_logs(config, channels, payload):
206 205 history_location = config.get('history.location')
207 206 for channel in channels:
208 207 history = read_history(history_location, channel)
209 208 payload['channels_info'][channel]['history'] = history
210 209
211 210
212 211 def write_history(config, message):
213 212 """ writes a message to a base64encoded filename """
214 213 history_location = config.get('history.location')
215 214 if not os.path.exists(history_location):
216 215 return
217 216 try:
218 217 LOCK.acquire_write_lock()
219 218 filepath = log_filepath(history_location, message['channel'])
220 219 json_message = json.dumps(message)
221 220 with open(filepath, 'ab') as f:
222 221 # the file is opened in binary mode, so encode the json str first
223 222 f.write(json_message.encode('utf8') + b'\n')
224 223 finally:
225 224 LOCK.release_write_lock()
226 225
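
History files are JSON-lines keyed by a sha1 of the channel name; a round-trip sketch assuming `sha1_safe` wraps hashlib (paths are made up):

    import hashlib, json, os
    channel, history_location = '/repo$some/repo$/pr/42', '/tmp/history'
    filename = hashlib.sha1(channel.encode('utf8')).hexdigest() + '.log'
    filepath = os.path.join(history_location, filename)
    # every line of the file is one json.dumps(message) record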
227 226
228 227 def get_connection_validators(registry):
229 228 validators = []
230 229 for k, config in list(registry.rhodecode_plugins.items()):
231 230 validator = config.get('channelstream', {}).get('connect_validator')
232 231 if validator:
233 232 validators.append(validator)
234 233 return validators
235 234
236 235
237 236 def get_channelstream_config(registry=None):
238 237 if not registry:
238 from pyramid.threadlocal import get_current_registry
239 239 registry = get_current_registry()
240 240
241 241 rhodecode_plugins = getattr(registry, 'rhodecode_plugins', {})
242 242 channelstream_config = rhodecode_plugins.get('channelstream', {})
243 243 return channelstream_config
244 244
245 245
246 246 def post_message(channel, message, username, registry=None):
247 247 channelstream_config = get_channelstream_config(registry)
248 248 if not channelstream_config.get('enabled'):
249 249 return
250 250
251 251 message_obj = message
252 252 if isinstance(message, str):
253 253 message_obj = {
254 254 'message': message,
255 255 'level': 'success',
256 256 'topic': '/notifications'
257 257 }
258 258
259 259 log.debug('Channelstream: sending notification to channel %s', channel)
260 260 payload = {
261 261 'type': 'message',
262 262 'timestamp': datetime.datetime.utcnow(),
263 263 'user': 'system',
264 264 'exclude_users': [username],
265 265 'channel': channel,
266 266 'message': message_obj
267 267 }
268 268
269 269 try:
270 270 return channelstream_request(
271 271 channelstream_config, [payload], '/message',
272 272 raise_exc=False)
273 273 except ChannelstreamException:
274 274 log.exception('Failed to send channelstream data')
275 275 raise
276 276
277 277
278 278 def _reload_link(label):
279 279 return (
280 280 '<a onclick="window.location.reload()">'
281 281 '<strong>{}</strong>'
282 282 '</a>'.format(label)
283 283 )
284 284
285 285
286 286 def pr_channel(pull_request):
287 287 repo_name = pull_request.target_repo.repo_name
288 288 pull_request_id = pull_request.pull_request_id
289 289 channel = '/repo${}$/pr/{}'.format(repo_name, pull_request_id)
290 290 log.debug('Getting pull-request channelstream broadcast channel: %s', channel)
291 291 return channel
292 292
293 293
294 294 def comment_channel(repo_name, commit_obj=None, pull_request_obj=None):
295 295 channel = None
296 296 if commit_obj:
297 297 channel = '/repo${}$/commit/{}'.format(
298 298 repo_name, commit_obj.raw_id
299 299 )
300 300 elif pull_request_obj:
301 301 channel = '/repo${}$/pr/{}'.format(
302 302 repo_name, pull_request_obj.pull_request_id
303 303 )
304 304 log.debug('Getting comment channelstream broadcast channel: %s', channel)
305 305
306 306 return channel
307 307
308 308
309 309 def pr_update_channelstream_push(request, pr_broadcast_channel, user, msg, **kwargs):
310 310 """
311 311 Channel push on pull request update
312 312 """
313 313 if not pr_broadcast_channel:
314 314 return
315 315
316 316 _ = request.translate
317 317
318 318 message = '{} {}'.format(
319 319 msg,
320 320 _reload_link(_(' Reload page to load changes')))
321 321
322 322 message_obj = {
323 323 'message': message,
324 324 'level': 'success',
325 325 'topic': '/notifications'
326 326 }
327 327
328 328 post_message(
329 329 pr_broadcast_channel, message_obj, user.username,
330 330 registry=request.registry)
331 331
332 332
333 333 def comment_channelstream_push(request, comment_broadcast_channel, user, msg, **kwargs):
334 334 """
335 335 Channelstream push on comment action, on commit, or pull-request
336 336 """
337 337 if not comment_broadcast_channel:
338 338 return
339 339
340 340 _ = request.translate
341 341
342 342 comment_data = kwargs.pop('comment_data', {})
343 343 user_data = kwargs.pop('user_data', {})
344 344 comment_id = list(comment_data.keys())[0] if comment_data else ''
345 345
346 346 message = '<strong>{}</strong> {} #{}'.format(
347 347 user.username,
348 348 msg,
349 349 comment_id,
350 350 )
351 351
352 352 message_obj = {
353 353 'message': message,
354 354 'level': 'success',
355 355 'topic': '/notifications'
356 356 }
357 357
358 358 post_message(
359 359 comment_broadcast_channel, message_obj, user.username,
360 360 registry=request.registry)
361 361
362 362 message_obj = {
363 363 'message': None,
364 364 'user': user.username,
365 365 'comment_id': comment_id,
366 366 'comment_data': comment_data,
367 367 'user_data': user_data,
368 368 'topic': '/comment'
369 369 }
370 370 post_message(
371 371 comment_broadcast_channel, message_obj, user.username,
372 372 registry=request.registry)
@@ -1,797 +1,819 b''
1 1
2 2
3 3 # Copyright (C) 2011-2020 RhodeCode GmbH
4 4 #
5 5 # This program is free software: you can redistribute it and/or modify
6 6 # it under the terms of the GNU Affero General Public License, version 3
7 7 # (only), as published by the Free Software Foundation.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU Affero General Public License
15 15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 16 #
17 17 # This program is dual-licensed. If you wish to learn more about the
18 18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20 20
21 21 import logging
22 22 import difflib
23 from itertools import groupby
23 import itertools
24 24
25 25 from pygments import lex
26 26 from pygments.formatters.html import _get_ttype_class as pygment_token_class
27 27 from pygments.lexers.special import TextLexer, Token
28 28 from pygments.lexers import get_lexer_by_name
29 29
30 30 from rhodecode.lib.helpers import (
31 31 get_lexer_for_filenode, html_escape, get_custom_lexer)
32 from rhodecode.lib.utils2 import AttributeDict, StrictAttributeDict, safe_unicode
32 from rhodecode.lib.str_utils import safe_str
33 from rhodecode.lib.utils2 import AttributeDict, StrictAttributeDict
33 34 from rhodecode.lib.vcs.nodes import FileNode
34 from rhodecode.lib.vcs.exceptions import VCSError, NodeDoesNotExistError
35 from rhodecode.lib.vcs.exceptions import NodeDoesNotExistError
35 36 from rhodecode.lib.diff_match_patch import diff_match_patch
36 37 from rhodecode.lib.diffs import LimitedDiffContainer, DEL_FILENODE, BIN_FILENODE
37 38
38 39
39 40 plain_text_lexer = get_lexer_by_name(
40 41 'text', stripall=False, stripnl=False, ensurenl=False)
41 42
42 43
43 44 log = logging.getLogger(__name__)
44 45
45 46
46 47 def filenode_as_lines_tokens(filenode, lexer=None):
47 48 org_lexer = lexer
48 49 lexer = lexer or get_lexer_for_filenode(filenode)
49 log.debug('Generating file node pygment tokens for %s, %s, org_lexer:%s',
50 log.debug('Generating file node pygment tokens for %s, file=`%s`, org_lexer:%s',
50 51 lexer, filenode, org_lexer)
51 content = filenode.content
52 content = filenode.str_content
52 53 tokens = tokenize_string(content, lexer)
53 54 lines = split_token_stream(tokens, content)
54 55 rv = list(lines)
55 56 return rv
56 57
57 58
58 59 def tokenize_string(content, lexer):
59 60 """
60 61 Use pygments to tokenize some content based on a lexer
61 62 ensuring all original newlines and whitespace are preserved
62 63 """
63 64
64 65 lexer.stripall = False
65 66 lexer.stripnl = False
66 67 lexer.ensurenl = False
67 68
69 # pygments needs to operate on str
70 str_content = safe_str(content)
71
68 72 if isinstance(lexer, TextLexer):
69 lexed = [(Token.Text, content)]
73 # we convert content to str here, because pygments would do the same
74 # conversion while tokenizing when no lexer is found for an unknown file type
75 lexed = [(Token.Text, str_content)]
70 76 else:
71 lexed = lex(content, lexer)
77 lexed = lex(str_content, lexer)
72 78
73 79 for token_type, token_text in lexed:
74 80 yield pygment_token_class(token_type), token_text
75 81
76 82
77 83 def split_token_stream(tokens, content):
78 84 """
79 85 Take a list of (TokenType, text) tuples and split them by a string
80 86
81 87 split_token_stream([(TEXT, 'some\ntext'), (TEXT, 'more\n')])
82 88 [(TEXT, 'some'), (TEXT, 'text'),
83 89 (TEXT, 'more'), (TEXT, 'text')]
84 90 """
85 91
86 92 token_buffer = []
93
87 94 for token_class, token_text in tokens:
95
96 # token_text, should be str
88 97 parts = token_text.split('\n')
89 98 for part in parts[:-1]:
90 99 token_buffer.append((token_class, part))
91 100 yield token_buffer
92 101 token_buffer = []
93 102
94 103 token_buffer.append((token_class, parts[-1]))
95 104
96 105 if token_buffer:
97 106 yield token_buffer
98 107 elif content:
99 108 # this is a special case, we have the content, but tokenization didn't produce
100 # any results. THis can happen if know file extensions like .css have some bogus
109 # any results. This can happen if known file extensions like .css have some bogus
101 110 # unicode content without any newline characters
102 111 yield [(pygment_token_class(Token.Text), content)]
103 112
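
A sketch of the tokenize/split pipeline above, assuming pygments is installed:

    from pygments.lexers import get_lexer_by_name

    content = 'x = 1\ny = 2\n'
    tokens = tokenize_string(content, get_lexer_by_name('python'))
    lines = list(split_token_stream(tokens, content))
    # each entry in `lines` is one source line as [(css_class, text), ...]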
104 113
105 114 def filenode_as_annotated_lines_tokens(filenode):
106 115 """
107 116 Take a file node and return a list of annotations => lines; if no annotation
108 117 is found, it will be None.
109 118
110 119 eg:
111 120
112 121 [
113 122 (annotation1, [
114 123 (1, line1_tokens_list),
115 124 (2, line2_tokens_list),
116 125 ]),
117 126 (annotation2, [
118 127 (3, line1_tokens_list),
119 128 ]),
120 129 (None, [
121 130 (4, line1_tokens_list),
122 131 ]),
123 132 (annotation1, [
124 133 (5, line1_tokens_list),
125 134 (6, line2_tokens_list),
126 135 ])
127 136 ]
128 137 """
129 138
130 139 commit_cache = {} # cache commit_getter lookups
131 140
132 141 def _get_annotation(commit_id, commit_getter):
133 142 if commit_id not in commit_cache:
134 143 commit_cache[commit_id] = commit_getter()
135 144 return commit_cache[commit_id]
136 145
137 146 annotation_lookup = {
138 147 line_no: _get_annotation(commit_id, commit_getter)
139 148 for line_no, commit_id, commit_getter, line_content
140 149 in filenode.annotate
141 150 }
142 151
143 152 annotations_lines = ((annotation_lookup.get(line_no), line_no, tokens)
144 153 for line_no, tokens
145 154 in enumerate(filenode_as_lines_tokens(filenode), 1))
146 155
147 grouped_annotations_lines = groupby(annotations_lines, lambda x: x[0])
156 grouped_annotations_lines = itertools.groupby(annotations_lines, lambda x: x[0])
148 157
149 158 for annotation, group in grouped_annotations_lines:
150 159 yield (
151 160 annotation, [(line_no, tokens)
152 161 for (_, line_no, tokens) in group]
153 162 )
154 163
155 164
156 165 def render_tokenstream(tokenstream):
157 166 result = []
158 167 for token_class, token_ops_texts in rollup_tokenstream(tokenstream):
159 168
160 169 if token_class:
161 result.append('<span class="%s">' % token_class)
170 result.append(f'<span class="{token_class}">')
162 171 else:
163 172 result.append('<span>')
164 173
165 174 for op_tag, token_text in token_ops_texts:
166 175
167 176 if op_tag:
168 result.append('<%s>' % op_tag)
177 result.append(f'<{op_tag}>')
169 178
170 179 # NOTE(marcink): in some cases of mixed encodings we might run into
171 180 # trouble inside html_escape; in that case we force token_text to str,
172 181 # which ensures "correct" data at the cost of how it renders
173 182 try:
174 183 escaped_text = html_escape(token_text)
175 184 except TypeError:
176 escaped_text = html_escape(safe_unicode(token_text))
185 escaped_text = html_escape(safe_str(token_text))
177 186
178 187 # TODO: dan: investigate showing hidden characters like space/nl/tab
179 188 # escaped_text = escaped_text.replace(' ', '<sp> </sp>')
180 189 # escaped_text = escaped_text.replace('\n', '<nl>\n</nl>')
181 190 # escaped_text = escaped_text.replace('\t', '<tab>\t</tab>')
182 191
183 192 result.append(escaped_text)
184 193
185 194 if op_tag:
186 result.append('</%s>' % op_tag)
195 result.append(f'</{op_tag}>')
187 196
188 197 result.append('</span>')
189 198
190 199 html = ''.join(result)
191 200 return html
192 201
193 202
194 203 def rollup_tokenstream(tokenstream):
195 204 """
196 205 Group a token stream of the format:
197 206
198 207 ('class', 'op', 'text')
199 208 or
200 209 ('class', 'text')
201 210
202 211 into
203 212
204 213 [('class1',
205 214 [('op1', 'text'),
206 215 ('op2', 'text')]),
207 216 ('class2',
208 217 [('op3', 'text')])]
209 218
210 219 This is used to get the minimal tags necessary when
211 220 rendering to html eg for a token stream ie.
212 221
213 222 <span class="A"><ins>he</ins>llo</span>
214 223 vs
215 224 <span class="A"><ins>he</ins></span><span class="A">llo</span>
216 225
217 226 If a 2-tuple is passed in, the output op will be an empty string.
218 227
219 228 eg:
220 229
221 230 >>> rollup_tokenstream([('classA', '', 'h'),
222 231 ('classA', 'del', 'ell'),
223 232 ('classA', '', 'o'),
224 233 ('classB', '', ' '),
225 234 ('classA', '', 'the'),
226 235 ('classA', '', 're'),
227 236 ])
228 237
229 238 [('classA', [('', 'h'), ('del', 'ell'), ('', 'o')]),
230 239 ('classB', [('', ' ')]),
231 240 ('classA', [('', 'there')])]
232 241
233 242 """
234 243 if tokenstream and len(tokenstream[0]) == 2:
235 244 tokenstream = ((t[0], '', t[1]) for t in tokenstream)
236 245
237 246 result = []
238 for token_class, op_list in groupby(tokenstream, lambda t: t[0]):
247 for token_class, op_list in itertools.groupby(tokenstream, lambda t: t[0]):
239 248 ops = []
240 for token_op, token_text_list in groupby(op_list, lambda o: o[1]):
249 for token_op, token_text_list in itertools.groupby(op_list, lambda o: o[1]):
241 250 text_buffer = []
242 251 for t_class, t_op, t_text in token_text_list:
243 252 text_buffer.append(t_text)
253
244 254 ops.append((token_op, ''.join(text_buffer)))
245 255 result.append((token_class, ops))
246 256 return result
247 257
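
A minimal render sketch: consecutive runs with the same class collapse into one <span>:

    html = render_tokenstream([('k', '', 'def'), ('', '', ' '), ('nf', 'ins', 'foo')])
    # '<span class="k">def</span><span> </span><span class="nf"><ins>foo</ins></span>'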
248 258
249 259 def tokens_diff(old_tokens, new_tokens, use_diff_match_patch=True):
250 260 """
251 261 Converts a list of (token_class, token_text) tuples to a list of
252 262 (token_class, token_op, token_text) tuples where token_op is one of
253 263 ('ins', 'del', '')
254 264
255 265 :param old_tokens: list of (token_class, token_text) tuples of old line
256 266 :param new_tokens: list of (token_class, token_text) tuples of new line
257 267 :param use_diff_match_patch: boolean, will use google's diff match patch
258 268 library, which has options to 'smooth' out the character-by-character
259 269 differences, producing nicer ins/del blocks
260 270 """
261 271
262 272 old_tokens_result = []
263 273 new_tokens_result = []
264 274
265 similarity = difflib.SequenceMatcher(None,
275 def int_convert(val):
276 if isinstance(val, int):
277 return str(val)
278 return val
279
280 similarity = difflib.SequenceMatcher(
281 None,
266 282 ''.join(token_text for token_class, token_text in old_tokens),
267 283 ''.join(token_text for token_class, token_text in new_tokens)
268 284 ).ratio()
269 285
270 if similarity < 0.6: # return, the blocks are too different
286 if similarity < 0.6: # return, the blocks are too different
271 287 for token_class, token_text in old_tokens:
272 288 old_tokens_result.append((token_class, '', token_text))
273 289 for token_class, token_text in new_tokens:
274 290 new_tokens_result.append((token_class, '', token_text))
275 291 return old_tokens_result, new_tokens_result, similarity
276 292
277 token_sequence_matcher = difflib.SequenceMatcher(None,
293 token_sequence_matcher = difflib.SequenceMatcher(
294 None,
278 295 [x[1] for x in old_tokens],
279 296 [x[1] for x in new_tokens])
280 297
281 298 for tag, o1, o2, n1, n2 in token_sequence_matcher.get_opcodes():
282 # check the differences by token block types first to give a more
299 # check the differences by token block types first to give a
283 300 # nicer "block" level replacement vs character diffs
284 301
285 302 if tag == 'equal':
286 303 for token_class, token_text in old_tokens[o1:o2]:
287 304 old_tokens_result.append((token_class, '', token_text))
288 305 for token_class, token_text in new_tokens[n1:n2]:
289 306 new_tokens_result.append((token_class, '', token_text))
290 307 elif tag == 'delete':
291 308 for token_class, token_text in old_tokens[o1:o2]:
292 old_tokens_result.append((token_class, 'del', token_text))
309 old_tokens_result.append((token_class, 'del', int_convert(token_text)))
293 310 elif tag == 'insert':
294 311 for token_class, token_text in new_tokens[n1:n2]:
295 new_tokens_result.append((token_class, 'ins', token_text))
312 new_tokens_result.append((token_class, 'ins', int_convert(token_text)))
296 313 elif tag == 'replace':
297 314 # if same type token blocks must be replaced, do a diff on the
298 315 # characters in the token blocks to show individual changes
299 316
300 317 old_char_tokens = []
301 318 new_char_tokens = []
302 319 for token_class, token_text in old_tokens[o1:o2]:
303 for char in token_text:
320 for char in map(lambda i: i, token_text):
304 321 old_char_tokens.append((token_class, char))
305 322
306 323 for token_class, token_text in new_tokens[n1:n2]:
307 for char in token_text:
324 for char in map(lambda i: i, token_text):
308 325 new_char_tokens.append((token_class, char))
309 326
310 327 old_string = ''.join([token_text for
311 token_class, token_text in old_char_tokens])
328 token_class, token_text in old_char_tokens])
312 329 new_string = ''.join([token_text for
313 token_class, token_text in new_char_tokens])
330 token_class, token_text in new_char_tokens])
314 331
315 332 char_sequence = difflib.SequenceMatcher(
316 333 None, old_string, new_string)
317 334 copcodes = char_sequence.get_opcodes()
318 335 obuffer, nbuffer = [], []
319 336
320 337 if use_diff_match_patch:
321 338 dmp = diff_match_patch()
322 339 dmp.Diff_EditCost = 11 # TODO: dan: extract this to a setting
323 340 reps = dmp.diff_main(old_string, new_string)
324 341 dmp.diff_cleanupEfficiency(reps)
325 342
326 343 a, b = 0, 0
327 344 for op, rep in reps:
328 345 l = len(rep)
329 346 if op == 0:
330 347 for i, c in enumerate(rep):
331 348 obuffer.append((old_char_tokens[a+i][0], '', c))
332 349 nbuffer.append((new_char_tokens[b+i][0], '', c))
333 350 a += l
334 351 b += l
335 352 elif op == -1:
336 353 for i, c in enumerate(rep):
337 obuffer.append((old_char_tokens[a+i][0], 'del', c))
354 obuffer.append((old_char_tokens[a+i][0], 'del', int_convert(c)))
338 355 a += l
339 356 elif op == 1:
340 357 for i, c in enumerate(rep):
341 nbuffer.append((new_char_tokens[b+i][0], 'ins', c))
358 nbuffer.append((new_char_tokens[b+i][0], 'ins', int_convert(c)))
342 359 b += l
343 360 else:
344 361 for ctag, co1, co2, cn1, cn2 in copcodes:
345 362 if ctag == 'equal':
346 363 for token_class, token_text in old_char_tokens[co1:co2]:
347 364 obuffer.append((token_class, '', token_text))
348 365 for token_class, token_text in new_char_tokens[cn1:cn2]:
349 366 nbuffer.append((token_class, '', token_text))
350 367 elif ctag == 'delete':
351 368 for token_class, token_text in old_char_tokens[co1:co2]:
352 obuffer.append((token_class, 'del', token_text))
369 obuffer.append((token_class, 'del', int_convert(token_text)))
353 370 elif ctag == 'insert':
354 371 for token_class, token_text in new_char_tokens[cn1:cn2]:
355 nbuffer.append((token_class, 'ins', token_text))
372 nbuffer.append((token_class, 'ins', int_convert(token_text)))
356 373 elif ctag == 'replace':
357 374 for token_class, token_text in old_char_tokens[co1:co2]:
358 obuffer.append((token_class, 'del', token_text))
375 obuffer.append((token_class, 'del', int_convert(token_text)))
359 376 for token_class, token_text in new_char_tokens[cn1:cn2]:
360 nbuffer.append((token_class, 'ins', token_text))
377 nbuffer.append((token_class, 'ins', int_convert(token_text)))
361 378
362 379 old_tokens_result.extend(obuffer)
363 380 new_tokens_result.extend(nbuffer)
364 381
365 382 return old_tokens_result, new_tokens_result, similarity
366 383
367 384
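For reference, a small illustration of the difflib opcodes that drive the branches above (the tokens are illustrative, not part of the changeset):

import difflib

old = ['def', ' ', 'foo', '(', ')', ':']
new = ['def', ' ', 'bar', '(', ')', ':']

# get_opcodes() yields (tag, o1, o2, n1, n2) exactly as consumed by
# tokens_diff; 'replace' is the branch that falls back to char-level diffs.
for tag, o1, o2, n1, n2 in difflib.SequenceMatcher(None, old, new).get_opcodes():
    print(tag, old[o1:o2], new[n1:n2])
# equal   ['def', ' ']     ['def', ' ']
# replace ['foo']          ['bar']
# equal   ['(', ')', ':']  ['(', ')', ':']
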
368 385 def diffset_node_getter(commit):
369 def get_node(fname):
386 def get_diff_node(file_name):
387
370 388 try:
371 return commit.get_node(fname)
389 return commit.get_node(file_name, pre_load=['size', 'flags', 'data'])
372 390 except NodeDoesNotExistError:
373 391 return None
374 392
375 return get_node
393 return get_diff_node
376 394
377 395
378 396 class DiffSet(object):
379 397 """
380 398 An object for parsing the diff result from diffs.DiffProcessor and
381 399 adding highlighting, side by side/unified renderings and line diffs
382 400 """
383 401
384 402 HL_REAL = 'REAL' # highlights using original file, slow
385 403 HL_FAST = 'FAST' # highlights using just the line, fast but not correct
386 404 # in the case of multiline code
387 405 HL_NONE = 'NONE' # no highlighting, fastest
388 406
389 407 def __init__(self, highlight_mode=HL_REAL, repo_name=None,
390 408 source_repo_name=None,
391 409 source_node_getter=lambda filename: None,
392 410 target_repo_name=None,
393 411 target_node_getter=lambda filename: None,
394 412 source_nodes=None, target_nodes=None,
395 413 # files over this size will use fast highlighting
396 414 max_file_size_limit=150 * 1024,
397 415 ):
398 416
399 417 self.highlight_mode = highlight_mode
400 418 self.highlighted_filenodes = {
401 419 'before': {},
402 420 'after': {}
403 421 }
404 422 self.source_node_getter = source_node_getter
405 423 self.target_node_getter = target_node_getter
406 424 self.source_nodes = source_nodes or {}
407 425 self.target_nodes = target_nodes or {}
408 426 self.repo_name = repo_name
409 427 self.target_repo_name = target_repo_name or repo_name
410 428 self.source_repo_name = source_repo_name or repo_name
411 429 self.max_file_size_limit = max_file_size_limit
412 430
413 431 def render_patchset(self, patchset, source_ref=None, target_ref=None):
414 432 diffset = AttributeDict(dict(
415 433 lines_added=0,
416 434 lines_deleted=0,
417 435 changed_files=0,
418 436 files=[],
419 437 file_stats={},
420 438 limited_diff=isinstance(patchset, LimitedDiffContainer),
421 439 repo_name=self.repo_name,
422 440 target_repo_name=self.target_repo_name,
423 441 source_repo_name=self.source_repo_name,
424 442 source_ref=source_ref,
425 443 target_ref=target_ref,
426 444 ))
427 445 for patch in patchset:
428 446 diffset.file_stats[patch['filename']] = patch['stats']
429 447 filediff = self.render_patch(patch)
430 448 filediff.diffset = StrictAttributeDict(dict(
431 449 source_ref=diffset.source_ref,
432 450 target_ref=diffset.target_ref,
433 451 repo_name=diffset.repo_name,
434 452 source_repo_name=diffset.source_repo_name,
435 453 target_repo_name=diffset.target_repo_name,
436 454 ))
437 455 diffset.files.append(filediff)
438 456 diffset.changed_files += 1
439 457 if not patch['stats']['binary']:
440 458 diffset.lines_added += patch['stats']['added']
441 459 diffset.lines_deleted += patch['stats']['deleted']
442 460
443 461 return diffset
444 462
445 463 _lexer_cache = {}
446 464
447 465 def _get_lexer_for_filename(self, filename, filenode=None):
448 466 # cached because we might need to call it twice for source/target
449 467 if filename not in self._lexer_cache:
450 468 if filenode:
451 469 lexer = filenode.lexer
452 470 extension = filenode.extension
453 471 else:
454 472 lexer = FileNode.get_lexer(filename=filename)
455 473 extension = filename.split('.')[-1]
456 474
457 475 lexer = get_custom_lexer(extension) or lexer
458 476 self._lexer_cache[filename] = lexer
459 477 return self._lexer_cache[filename]
460 478
461 479 def render_patch(self, patch):
462 480 log.debug('rendering diff for %r', patch['filename'])
463 481
464 482 source_filename = patch['original_filename']
465 483 target_filename = patch['filename']
466 484
467 485 source_lexer = plain_text_lexer
468 486 target_lexer = plain_text_lexer
469 487
470 488 if not patch['stats']['binary']:
471 489 node_hl_mode = self.HL_NONE if patch['chunks'] == [] else None
472 490 hl_mode = node_hl_mode or self.highlight_mode
473 491
474 492 if hl_mode == self.HL_REAL:
475 493 if (source_filename and patch['operation'] in ('D', 'M')
476 494 and source_filename not in self.source_nodes):
477 495 self.source_nodes[source_filename] = (
478 496 self.source_node_getter(source_filename))
479 497
480 498 if (target_filename and patch['operation'] in ('A', 'M')
481 499 and target_filename not in self.target_nodes):
482 500 self.target_nodes[target_filename] = (
483 501 self.target_node_getter(target_filename))
484 502
485 503 elif hl_mode == self.HL_FAST:
486 504 source_lexer = self._get_lexer_for_filename(source_filename)
487 505 target_lexer = self._get_lexer_for_filename(target_filename)
488 506
489 507 source_file = self.source_nodes.get(source_filename, source_filename)
490 508 target_file = self.target_nodes.get(target_filename, target_filename)
491 509 raw_id_uid = ''
492 510 if self.source_nodes.get(source_filename):
493 511 raw_id_uid = self.source_nodes[source_filename].commit.raw_id
494 512
495 513 if not raw_id_uid and self.target_nodes.get(target_filename):
496 514 # in case this is a new file we only have it in target
497 515 raw_id_uid = self.target_nodes[target_filename].commit.raw_id
498 516
499 517 source_filenode, target_filenode = None, None
500 518
501 519 # TODO: dan: FileNode.lexer works on the content of the file - which
502 520 # can be slow - issue #4289 explains a lexer clean up - which once
503 521 # done can allow caching a lexer for a filenode to avoid the file lookup
504 522 if isinstance(source_file, FileNode):
505 523 source_filenode = source_file
506 524 #source_lexer = source_file.lexer
507 525 source_lexer = self._get_lexer_for_filename(source_filename)
508 526 source_file.lexer = source_lexer
509 527
510 528 if isinstance(target_file, FileNode):
511 529 target_filenode = target_file
512 530 #target_lexer = target_file.lexer
513 531 target_lexer = self._get_lexer_for_filename(target_filename)
514 532 target_file.lexer = target_lexer
515 533
516 534 source_file_path, target_file_path = None, None
517 535
518 536 if source_filename != '/dev/null':
519 537 source_file_path = source_filename
520 538 if target_filename != '/dev/null':
521 539 target_file_path = target_filename
522 540
523 541 source_file_type = source_lexer.name
524 542 target_file_type = target_lexer.name
525 543
526 544 filediff = AttributeDict({
527 545 'source_file_path': source_file_path,
528 546 'target_file_path': target_file_path,
529 547 'source_filenode': source_filenode,
530 548 'target_filenode': target_filenode,
531 549 'source_file_type': target_file_type,
532 550 'target_file_type': source_file_type,
533 551 'patch': {'filename': patch['filename'], 'stats': patch['stats']},
534 552 'operation': patch['operation'],
535 553 'source_mode': patch['stats']['old_mode'],
536 554 'target_mode': patch['stats']['new_mode'],
537 555 'limited_diff': patch['is_limited_diff'],
538 556 'hunks': [],
539 557 'hunk_ops': None,
540 558 'diffset': self,
541 559 'raw_id': raw_id_uid,
542 560 })
543 561
544 562 file_chunks = patch['chunks'][1:]
545 563 for i, hunk in enumerate(file_chunks, 1):
546 564 hunkbit = self.parse_hunk(hunk, source_file, target_file)
547 565 hunkbit.source_file_path = source_file_path
548 566 hunkbit.target_file_path = target_file_path
549 567 hunkbit.index = i
550 568 filediff.hunks.append(hunkbit)
551 569
552 570 # Simulate hunk on OPS type line which doesn't really contain any diff
553 571 # this allows commenting on those
554 572 if not file_chunks:
555 573 actions = []
556 for op_id, op_text in filediff.patch['stats']['ops'].items():
574 for op_id, op_text in list(filediff.patch['stats']['ops'].items()):
557 575 if op_id == DEL_FILENODE:
558 576 actions.append('file was removed')
559 577 elif op_id == BIN_FILENODE:
560 578 actions.append('binary diff hidden')
561 579 else:
562 actions.append(safe_unicode(op_text))
580 actions.append(safe_str(op_text))
563 581 action_line = 'NO CONTENT: ' + \
564 582 (', '.join(actions) or 'UNDEFINED_ACTION')
565 583
566 584 hunk_ops = {'source_length': 0, 'source_start': 0,
567 585 'lines': [
568 586 {'new_lineno': 0, 'old_lineno': 1,
569 587 'action': 'unmod-no-hl', 'line': action_line}
570 588 ],
571 589 'section_header': '', 'target_start': 1, 'target_length': 1}
572 590
573 591 hunkbit = self.parse_hunk(hunk_ops, source_file, target_file)
574 592 hunkbit.source_file_path = source_file_path
575 593 hunkbit.target_file_path = target_file_path
576 594 filediff.hunk_ops = hunkbit
577 595 return filediff
578 596
579 597 def parse_hunk(self, hunk, source_file, target_file):
580 598 result = AttributeDict(dict(
581 599 source_start=hunk['source_start'],
582 600 source_length=hunk['source_length'],
583 601 target_start=hunk['target_start'],
584 602 target_length=hunk['target_length'],
585 603 section_header=hunk['section_header'],
586 604 lines=[],
587 605 ))
588 606 before, after = [], []
589 607
590 608 for line in hunk['lines']:
609
591 610 if line['action'] in ['unmod', 'unmod-no-hl']:
592 611 no_hl = line['action'] == 'unmod-no-hl'
593 result.lines.extend(
594 self.parse_lines(before, after, source_file, target_file, no_hl=no_hl))
612 parsed_lines = self.parse_lines(before, after, source_file, target_file, no_hl=no_hl)
613 result.lines.extend(parsed_lines)
595 614 after.append(line)
596 615 before.append(line)
597 616 elif line['action'] == 'add':
598 617 after.append(line)
599 618 elif line['action'] == 'del':
600 619 before.append(line)
601 620 elif line['action'] == 'old-no-nl':
602 621 before.append(line)
622 #line['line'] = safe_str(line['line'])
603 623 elif line['action'] == 'new-no-nl':
624 #line['line'] = safe_str(line['line'])
604 625 after.append(line)
605 626
606 627 all_actions = [x['action'] for x in after] + [x['action'] for x in before]
607 628 no_hl = set(all_actions) == {'unmod-no-hl'}
608 result.lines.extend(
609 self.parse_lines(before, after, source_file, target_file, no_hl=no_hl))
610 # NOTE(marcink): we must keep list() call here so we can cache the result...
629 parsed_no_hl_lines = self.parse_lines(before, after, source_file, target_file, no_hl=no_hl)
630 result.lines.extend(parsed_no_hl_lines)
631
632 # NOTE(marcink): we must keep list() call here, so we can cache the result...
611 633 result.unified = list(self.as_unified(result.lines))
612 634 result.sideside = result.lines
613 635
614 636 return result
615 637
616 638 def parse_lines(self, before_lines, after_lines, source_file, target_file,
617 639 no_hl=False):
618 640 # TODO: dan: investigate doing the diff comparison and fast highlighting
619 641 # on the entire before and after buffered block lines rather than by
620 642 # line, this means we can get better 'fast' highlighting if the context
621 643 # allows it - eg.
622 644 # line 4: """
623 645 # line 5: this gets highlighted as a string
624 646 # line 6: """
625 647
626 648 lines = []
627 649
628 650 before_newline = AttributeDict()
629 651 after_newline = AttributeDict()
630 652 if before_lines and before_lines[-1]['action'] == 'old-no-nl':
631 653 before_newline_line = before_lines.pop(-1)
632 654 before_newline.content = '\n {}'.format(
633 655 render_tokenstream(
634 [(x[0], '', x[1])
656 [(x[0], '', safe_str(x[1]))
635 657 for x in [('nonl', before_newline_line['line'])]]))
636 658
637 659 if after_lines and after_lines[-1]['action'] == 'new-no-nl':
638 660 after_newline_line = after_lines.pop(-1)
639 661 after_newline.content = '\n {}'.format(
640 662 render_tokenstream(
641 [(x[0], '', x[1])
663 [(x[0], '', safe_str(x[1]))
642 664 for x in [('nonl', after_newline_line['line'])]]))
643 665
644 666 while before_lines or after_lines:
645 667 before, after = None, None
646 668 before_tokens, after_tokens = None, None
647 669
648 670 if before_lines:
649 671 before = before_lines.pop(0)
650 672 if after_lines:
651 673 after = after_lines.pop(0)
652 674
653 675 original = AttributeDict()
654 676 modified = AttributeDict()
655 677
656 678 if before:
657 679 if before['action'] == 'old-no-nl':
658 before_tokens = [('nonl', before['line'])]
680 before_tokens = [('nonl', safe_str(before['line']))]
659 681 else:
660 682 before_tokens = self.get_line_tokens(
661 683 line_text=before['line'], line_number=before['old_lineno'],
662 684 input_file=source_file, no_hl=no_hl, source='before')
663 685 original.lineno = before['old_lineno']
664 686 original.content = before['line']
665 687 original.action = self.action_to_op(before['action'])
666 688
667 689 original.get_comment_args = (
668 690 source_file, 'o', before['old_lineno'])
669 691
670 692 if after:
671 693 if after['action'] == 'new-no-nl':
672 after_tokens = [('nonl', after['line'])]
694 after_tokens = [('nonl', safe_str(after['line']))]
673 695 else:
674 696 after_tokens = self.get_line_tokens(
675 697 line_text=after['line'], line_number=after['new_lineno'],
676 698 input_file=target_file, no_hl=no_hl, source='after')
677 699 modified.lineno = after['new_lineno']
678 700 modified.content = after['line']
679 701 modified.action = self.action_to_op(after['action'])
680 702
681 703 modified.get_comment_args = (target_file, 'n', after['new_lineno'])
682 704
683 705 # diff the lines
684 706 if before_tokens and after_tokens:
685 707 o_tokens, m_tokens, similarity = tokens_diff(
686 708 before_tokens, after_tokens)
687 709 original.content = render_tokenstream(o_tokens)
688 710 modified.content = render_tokenstream(m_tokens)
689 711 elif before_tokens:
690 712 original.content = render_tokenstream(
691 713 [(x[0], '', x[1]) for x in before_tokens])
692 714 elif after_tokens:
693 715 modified.content = render_tokenstream(
694 716 [(x[0], '', x[1]) for x in after_tokens])
695 717
696 718 if not before_lines and before_newline:
697 719 original.content += before_newline.content
698 720 before_newline = None
699 721 if not after_lines and after_newline:
700 722 modified.content += after_newline.content
701 723 after_newline = None
702 724
703 725 lines.append(AttributeDict({
704 726 'original': original,
705 727 'modified': modified,
706 728 }))
707 729
708 730 return lines
709 731
710 732 def get_line_tokens(self, line_text, line_number, input_file=None, no_hl=False, source=''):
711 733 filenode = None
712 734 filename = None
713 735
714 736 if isinstance(input_file, str):
715 737 filename = input_file
716 738 elif isinstance(input_file, FileNode):
717 739 filenode = input_file
718 filename = input_file.unicode_path
740 filename = input_file.str_path
719 741
720 742 hl_mode = self.HL_NONE if no_hl else self.highlight_mode
721 743 if hl_mode == self.HL_REAL and filenode:
722 744 lexer = self._get_lexer_for_filename(filename)
723 file_size_allowed = input_file.size < self.max_file_size_limit
745 file_size_allowed = filenode.size < self.max_file_size_limit
724 746 if line_number and file_size_allowed:
725 return self.get_tokenized_filenode_line(input_file, line_number, lexer, source)
747 return self.get_tokenized_filenode_line(filenode, line_number, lexer, source)
726 748
727 749 if hl_mode in (self.HL_REAL, self.HL_FAST) and filename:
728 750 lexer = self._get_lexer_for_filename(filename)
729 751 return list(tokenize_string(line_text, lexer))
730 752
731 753 return list(tokenize_string(line_text, plain_text_lexer))
732 754
733 755 def get_tokenized_filenode_line(self, filenode, line_number, lexer=None, source=''):
756 name_hash = hash(filenode)
734 757
735 def tokenize(_filenode):
736 self.highlighted_filenodes[source][filenode] = filenode_as_lines_tokens(filenode, lexer)
758 hl_node_code = self.highlighted_filenodes[source]
737 759
738 if filenode not in self.highlighted_filenodes[source]:
739 tokenize(filenode)
760 if name_hash not in hl_node_code:
761 hl_node_code[name_hash] = filenode_as_lines_tokens(filenode, lexer)
740 762
741 763 try:
742 return self.highlighted_filenodes[source][filenode][line_number - 1]
764 return hl_node_code[name_hash][line_number - 1]
743 765 except Exception:
744 log.exception('diff rendering error')
766 log.exception('diff rendering error on L:%s and file=%s', line_number - 1, filenode.name)
745 767 return [('', 'L{}: rhodecode diff rendering error'.format(line_number))]
746 768
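The per-side cache above keys highlighted token lines by hash(filenode); a compact sketch of the same pattern follows (the tokenize callback here is hypothetical, standing in for filenode_as_lines_tokens):

# Hypothetical sketch: compute token lines once per (side, filenode) and
# index them by line number, mirroring get_tokenized_filenode_line above.
highlighted = {'before': {}, 'after': {}}

def tokenized_line(side, filenode, line_number, tokenize):
    key = hash(filenode)
    if key not in highlighted[side]:
        highlighted[side][key] = tokenize(filenode)  # list of token lines
    return highlighted[side][key][line_number - 1]
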
747 769 def action_to_op(self, action):
748 770 return {
749 771 'add': '+',
750 772 'del': '-',
751 773 'unmod': ' ',
752 774 'unmod-no-hl': ' ',
753 775 'old-no-nl': ' ',
754 776 'new-no-nl': ' ',
755 777 }.get(action, action)
756 778
757 779 def as_unified(self, lines):
758 780 """
759 781 Return a generator that yields the lines of a diff in unified order
760 782 """
761 783 def generator():
762 784 buf = []
763 785 for line in lines:
764 786
765 787 if buf and not line.original or line.original.action == ' ':
766 788 for b in buf:
767 789 yield b
768 790 buf = []
769 791
770 792 if line.original:
771 793 if line.original.action == ' ':
772 794 yield (line.original.lineno, line.modified.lineno,
773 795 line.original.action, line.original.content,
774 796 line.original.get_comment_args)
775 797 continue
776 798
777 799 if line.original.action == '-':
778 800 yield (line.original.lineno, None,
779 801 line.original.action, line.original.content,
780 802 line.original.get_comment_args)
781 803
782 804 if line.modified.action == '+':
783 805 buf.append((
784 806 None, line.modified.lineno,
785 807 line.modified.action, line.modified.content,
786 808 line.modified.get_comment_args))
787 809 continue
788 810
789 811 if line.modified:
790 812 yield (None, line.modified.lineno,
791 813 line.modified.action, line.modified.content,
792 814 line.modified.get_comment_args)
793 815
794 816 for b in buf:
795 817 yield b
796 818
797 819 return generator()
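A simplified sketch of the buffering performed by the generator above: added lines are held back until the '-' run they belong to has been emitted. Tuple shapes are simplified here for illustration:

def unified_order(lines):
    # lines: (old_no, new_no, action, text) -- '+' rows are buffered and
    # flushed before the next context row, matching the method above.
    buf = []
    for old_no, new_no, action, text in lines:
        if action == ' ':
            yield from buf
            buf = []
            yield old_no, new_no, ' ', text
        elif action == '-':
            yield old_no, None, '-', text
        elif action == '+':
            buf.append((None, new_no, '+', text))
    yield from buf

list(unified_order([(1, 1, ' ', 'a'), (2, None, '-', 'b'),
                    (None, 2, '+', 'B'), (3, 3, ' ', 'c')]))
# -> [(1, 1, ' ', 'a'), (2, None, '-', 'b'), (None, 2, '+', 'B'), (3, 3, ' ', 'c')]
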
@@ -1,679 +1,688 b''
1 1
2 2 # Copyright (C) 2010-2020 RhodeCode GmbH
3 3 #
4 4 # This program is free software: you can redistribute it and/or modify
5 5 # it under the terms of the GNU Affero General Public License, version 3
6 6 # (only), as published by the Free Software Foundation.
7 7 #
8 8 # This program is distributed in the hope that it will be useful,
9 9 # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 10 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 11 # GNU General Public License for more details.
12 12 #
13 13 # You should have received a copy of the GNU Affero General Public License
14 14 # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 15 #
16 16 # This program is dual-licensed. If you wish to learn more about the
17 17 # RhodeCode Enterprise Edition, including its added features, Support services,
18 18 # and proprietary license terms, please see https://rhodecode.com/licenses/
19 19
20 20 """
21 21 Database creation and setup module for RhodeCode Enterprise. Used for creating
22 22 the database as well as for migration operations
23 23 """
24 24
25 25 import os
26 26 import sys
27 27 import time
28 28 import uuid
29 29 import logging
30 30 import getpass
31 31 from os.path import dirname as dn, join as jn
32 32
33 33 from sqlalchemy.engine import create_engine
34 34
35 35 from rhodecode import __dbversion__
36 36 from rhodecode.model import init_model
37 37 from rhodecode.model.user import UserModel
38 38 from rhodecode.model.db import (
39 39 User, Permission, RhodeCodeUi, RhodeCodeSetting, UserToPerm,
40 40 DbMigrateVersion, RepoGroup, UserRepoGroupToPerm, CacheKey, Repository)
41 41 from rhodecode.model.meta import Session, Base
42 42 from rhodecode.model.permission import PermissionModel
43 43 from rhodecode.model.repo import RepoModel
44 44 from rhodecode.model.repo_group import RepoGroupModel
45 45 from rhodecode.model.settings import SettingsModel
46 46
47 47
48 48 log = logging.getLogger(__name__)
49 49
50 50
51 51 def notify(msg):
52 52 """
53 53 Notification helper for migration messages
54 54 """
55 55 ml = len(msg) + (4 * 2)
56 print(('\n%s\n*** %s ***\n%s' % ('*' * ml, msg, '*' * ml)).upper())
56 print((('\n%s\n*** %s ***\n%s' % ('*' * ml, msg, '*' * ml)).upper()))
57 57
58 58
59 59 class DbManage(object):
60 60
61 61 def __init__(self, log_sql, dbconf, root, tests=False,
62 SESSION=None, cli_args=None):
62 SESSION=None, cli_args=None, enc_key=b''):
63
63 64 self.dbname = dbconf.split('/')[-1]
64 65 self.tests = tests
65 66 self.root = root
66 67 self.dburi = dbconf
67 68 self.log_sql = log_sql
68 69 self.cli_args = cli_args or {}
70 self.sa = None
71 self.engine = None
72 self.enc_key = enc_key
73 # sets .sa .engine
69 74 self.init_db(SESSION=SESSION)
75
70 76 self.ask_ok = self.get_ask_ok_func(self.cli_args.get('force_ask'))
71 77
72 78 def db_exists(self):
73 79 if not self.sa:
74 80 self.init_db()
75 81 try:
76 82 self.sa.query(RhodeCodeUi)\
77 83 .filter(RhodeCodeUi.ui_key == '/')\
78 84 .scalar()
79 85 return True
80 86 except Exception:
81 87 return False
82 88 finally:
83 89 self.sa.rollback()
84 90
85 91 def get_ask_ok_func(self, param):
86 92 if param not in [None]:
87 93 # return a function lambda that has a default set to param
88 94 return lambda *args, **kwargs: param
89 95 else:
90 96 from rhodecode.lib.utils import ask_ok
91 97 return ask_ok
92 98
93 99 def init_db(self, SESSION=None):
100
94 101 if SESSION:
95 102 self.sa = SESSION
103 self.engine = SESSION.bind
96 104 else:
97 105 # init new sessions
98 106 engine = create_engine(self.dburi, echo=self.log_sql)
99 init_model(engine)
107 init_model(engine, encryption_key=self.enc_key)
100 108 self.sa = Session()
109 self.engine = engine
101 110
102 111 def create_tables(self, override=False):
103 112 """
104 113 Create an auth database
105 114 """
106 115
107 116 log.info("Existing database with the same name is going to be destroyed.")
108 117 log.info("Setup command will run DROP ALL command on that database.")
118 engine = self.engine
119
109 120 if self.tests:
110 121 destroy = True
111 122 else:
112 123 destroy = self.ask_ok('Are you sure that you want to destroy the old database? [y/n]')
113 124 if not destroy:
114 125 log.info('db tables bootstrap: Nothing done.')
115 126 sys.exit(0)
116 127 if destroy:
117 Base.metadata.drop_all()
128 Base.metadata.drop_all(bind=engine)
118 129
119 130 checkfirst = not override
120 Base.metadata.create_all(checkfirst=checkfirst)
131 Base.metadata.create_all(bind=engine, checkfirst=checkfirst)
121 132 log.info('Created tables for %s', self.dbname)
122 133
123 134 def set_db_version(self):
124 135 ver = DbMigrateVersion()
125 136 ver.version = __dbversion__
126 137 ver.repository_id = 'rhodecode_db_migrations'
127 138 ver.repository_path = 'versions'
128 139 self.sa.add(ver)
129 140 log.info('db version set to: %s', __dbversion__)
130 141
131 142 def run_post_migration_tasks(self):
132 143 """
133 144 Run various tasks after the migration steps have completed
134 145 """
135 146 # delete cache keys on each upgrade
136 147 total = CacheKey.query().count()
137 148 log.info("Deleting (%s) cache keys now...", total)
138 149 CacheKey.delete_all_cache()
139 150
140 151 def upgrade(self, version=None):
141 152 """
142 153 Upgrades the given database schema to the given revision,
143 154 following all steps needed to perform the upgrade
144 155
145 156 """
146 157
147 158 from rhodecode.lib.dbmigrate.migrate.versioning import api
148 from rhodecode.lib.dbmigrate.migrate.exceptions import \
149 DatabaseNotControlledError
159 from rhodecode.lib.dbmigrate.migrate.exceptions import DatabaseNotControlledError
150 160
151 161 if 'sqlite' in self.dburi:
152 162 print(
153 163 '********************** WARNING **********************\n'
154 164 'Make sure your version of sqlite is at least 3.7.X. \n'
155 165 'Earlier versions are known to fail on some migrations\n'
156 166 '*****************************************************\n')
157 167
158 168 upgrade = self.ask_ok(
159 169 'You are about to perform a database upgrade. Make '
160 170 'sure you have backed up your database. '
161 171 'Continue ? [y/n]')
162 172 if not upgrade:
163 173 log.info('No upgrade performed')
164 174 sys.exit(0)
165 175
166 176 repository_path = jn(dn(dn(dn(os.path.realpath(__file__)))),
167 177 'rhodecode/lib/dbmigrate')
168 178 db_uri = self.dburi
169 179
170 180 if version:
171 181 DbMigrateVersion.set_version(version)
172 182
173 183 try:
174 184 curr_version = api.db_version(db_uri, repository_path)
175 msg = ('Found current database db_uri under version '
176 'control with version {}'.format(curr_version))
185 msg = (f'Found current database {db_uri} under version '
186 f'control with version {curr_version}')
177 187
178 188 except (RuntimeError, DatabaseNotControlledError):
179 189 curr_version = 1
180 msg = ('Current database is not under version control. Setting '
181 'as version %s' % curr_version)
190 msg = f'Current database is not under version control. ' \
191 f'Setting as version {curr_version}'
182 192 api.version_control(db_uri, repository_path, curr_version)
183 193
184 194 notify(msg)
185 195
186
187 196 if curr_version == __dbversion__:
188 197 log.info('This database is already at the newest version')
189 198 sys.exit(0)
190 199
191 upgrade_steps = range(curr_version + 1, __dbversion__ + 1)
192 notify('attempting to upgrade database from '
193 'version %s to version %s' % (curr_version, __dbversion__))
200 upgrade_steps = list(range(curr_version + 1, __dbversion__ + 1))
201 notify(f'attempting to upgrade database from '
202 f'version {curr_version} to version {__dbversion__}')
194 203
195 204 # CALL THE PROPER ORDER OF STEPS TO PERFORM FULL UPGRADE
196 205 _step = None
197 206 for step in upgrade_steps:
198 notify('performing upgrade step %s' % step)
207 notify(f'performing upgrade step {step}')
199 208 time.sleep(0.5)
200 209
201 210 api.upgrade(db_uri, repository_path, step)
202 211 self.sa.rollback()
203 notify('schema upgrade for step %s completed' % (step,))
212 notify(f'schema upgrade for step {step} completed')
204 213
205 214 _step = step
206 215
207 216 self.run_post_migration_tasks()
208 notify('upgrade to version %s successful' % _step)
217 notify(f'upgrade to version {_step} successful')
209 218
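The step sequence computed above is a simple inclusive range; for example, upgrading a database at version 107 when the code is at version 110:

curr_version, target_version = 107, 110  # example values (__dbversion__ in the code)
upgrade_steps = list(range(curr_version + 1, target_version + 1))
print(upgrade_steps)  # [108, 109, 110] -- each step runs api.upgrade() once
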
210 219 def fix_repo_paths(self):
211 220 """
212 221 Fixes an old RhodeCode version path into new one without a '*'
213 222 """
214 223
215 224 paths = self.sa.query(RhodeCodeUi)\
216 225 .filter(RhodeCodeUi.ui_key == '/')\
217 226 .scalar()
218 227
219 228 paths.ui_value = paths.ui_value.replace('*', '')
220 229
221 230 try:
222 231 self.sa.add(paths)
223 232 self.sa.commit()
224 233 except Exception:
225 234 self.sa.rollback()
226 235 raise
227 236
228 237 def fix_default_user(self):
229 238 """
230 239 Fixes an old default user with some 'nicer' default values,
231 240 used mostly for anonymous access
232 241 """
233 242 def_user = self.sa.query(User)\
234 .filter(User.username == User.DEFAULT_USER)\
235 .one()
243 .filter(User.username == User.DEFAULT_USER)\
244 .one()
236 245
237 246 def_user.name = 'Anonymous'
238 247 def_user.lastname = 'User'
239 248 def_user.email = User.DEFAULT_USER_EMAIL
240 249
241 250 try:
242 251 self.sa.add(def_user)
243 252 self.sa.commit()
244 253 except Exception:
245 254 self.sa.rollback()
246 255 raise
247 256
248 257 def fix_settings(self):
249 258 """
250 259 Fixes rhodecode settings and adds ga_code key for google analytics
251 260 """
252 261
253 262 hgsettings3 = RhodeCodeSetting('ga_code', '')
254 263
255 264 try:
256 265 self.sa.add(hgsettings3)
257 266 self.sa.commit()
258 267 except Exception:
259 268 self.sa.rollback()
260 269 raise
261 270
262 271 def create_admin_and_prompt(self):
263 272
264 273 # defaults
265 274 defaults = self.cli_args
266 275 username = defaults.get('username')
267 276 password = defaults.get('password')
268 277 email = defaults.get('email')
269 278
270 279 if username is None:
271 280 username = input('Specify admin username:')
272 281 if password is None:
273 282 password = self._get_admin_password()
274 283 if not password:
275 284 # second try
276 285 password = self._get_admin_password()
277 286 if not password:
278 287 sys.exit()
279 288 if email is None:
280 289 email = input('Specify admin email:')
281 290 api_key = self.cli_args.get('api_key')
282 291 self.create_user(username, password, email, True,
283 292 strict_creation_check=False,
284 293 api_key=api_key)
285 294
286 295 def _get_admin_password(self):
287 296 password = getpass.getpass('Specify admin password '
288 297 '(min 6 chars):')
289 298 confirm = getpass.getpass('Confirm password:')
290 299
291 300 if password != confirm:
292 301 log.error('passwords mismatch')
293 302 return False
294 303 if len(password) < 6:
295 304 log.error('password is too short - use at least 6 characters')
296 305 return False
297 306
298 307 return password
299 308
300 309 def create_test_admin_and_users(self):
301 310 log.info('creating admin and regular test users')
302 311 from rhodecode.tests import TEST_USER_ADMIN_LOGIN, \
303 312 TEST_USER_ADMIN_PASS, TEST_USER_ADMIN_EMAIL, \
304 313 TEST_USER_REGULAR_LOGIN, TEST_USER_REGULAR_PASS, \
305 314 TEST_USER_REGULAR_EMAIL, TEST_USER_REGULAR2_LOGIN, \
306 315 TEST_USER_REGULAR2_PASS, TEST_USER_REGULAR2_EMAIL
307 316
308 317 self.create_user(TEST_USER_ADMIN_LOGIN, TEST_USER_ADMIN_PASS,
309 318 TEST_USER_ADMIN_EMAIL, True, api_key=True)
310 319
311 320 self.create_user(TEST_USER_REGULAR_LOGIN, TEST_USER_REGULAR_PASS,
312 321 TEST_USER_REGULAR_EMAIL, False, api_key=True)
313 322
314 323 self.create_user(TEST_USER_REGULAR2_LOGIN, TEST_USER_REGULAR2_PASS,
315 324 TEST_USER_REGULAR2_EMAIL, False, api_key=True)
316 325
317 326 def create_ui_settings(self, repo_store_path):
318 327 """
319 328 Creates ui settings, fills out hooks
320 329 and disables dotencode
321 330 """
322 331 settings_model = SettingsModel(sa=self.sa)
323 332 from rhodecode.lib.vcs.backends.hg import largefiles_store
324 333 from rhodecode.lib.vcs.backends.git import lfs_store
325 334
326 335 # Build HOOKS
327 336 hooks = [
328 337 (RhodeCodeUi.HOOK_REPO_SIZE, 'python:vcsserver.hooks.repo_size'),
329 338
330 339 # HG
331 340 (RhodeCodeUi.HOOK_PRE_PULL, 'python:vcsserver.hooks.pre_pull'),
332 341 (RhodeCodeUi.HOOK_PULL, 'python:vcsserver.hooks.log_pull_action'),
333 342 (RhodeCodeUi.HOOK_PRE_PUSH, 'python:vcsserver.hooks.pre_push'),
334 343 (RhodeCodeUi.HOOK_PRETX_PUSH, 'python:vcsserver.hooks.pre_push'),
335 344 (RhodeCodeUi.HOOK_PUSH, 'python:vcsserver.hooks.log_push_action'),
336 345 (RhodeCodeUi.HOOK_PUSH_KEY, 'python:vcsserver.hooks.key_push'),
337 346
338 347 ]
339 348
340 349 for key, value in hooks:
341 350 hook_obj = settings_model.get_ui_by_key(key)
342 351 hooks2 = hook_obj if hook_obj else RhodeCodeUi()
343 352 hooks2.ui_section = 'hooks'
344 353 hooks2.ui_key = key
345 354 hooks2.ui_value = value
346 355 self.sa.add(hooks2)
347 356
348 357 # enable largefiles
349 358 largefiles = RhodeCodeUi()
350 359 largefiles.ui_section = 'extensions'
351 360 largefiles.ui_key = 'largefiles'
352 361 largefiles.ui_value = ''
353 362 self.sa.add(largefiles)
354 363
355 364 # set default largefiles cache dir, defaults to
356 365 # /repo_store_location/.cache/largefiles
357 366 largefiles = RhodeCodeUi()
358 367 largefiles.ui_section = 'largefiles'
359 368 largefiles.ui_key = 'usercache'
360 369 largefiles.ui_value = largefiles_store(repo_store_path)
361 370
362 371 self.sa.add(largefiles)
363 372
364 373 # set default lfs cache dir, defaults to
365 374 # /repo_store_location/.cache/lfs_store
366 375 lfsstore = RhodeCodeUi()
367 376 lfsstore.ui_section = 'vcs_git_lfs'
368 377 lfsstore.ui_key = 'store_location'
369 378 lfsstore.ui_value = lfs_store(repo_store_path)
370 379
371 380 self.sa.add(lfsstore)
372 381
373 382 # enable hgsubversion disabled by default
374 383 hgsubversion = RhodeCodeUi()
375 384 hgsubversion.ui_section = 'extensions'
376 385 hgsubversion.ui_key = 'hgsubversion'
377 386 hgsubversion.ui_value = ''
378 387 hgsubversion.ui_active = False
379 388 self.sa.add(hgsubversion)
380 389
381 390 # enable hgevolve disabled by default
382 391 hgevolve = RhodeCodeUi()
383 392 hgevolve.ui_section = 'extensions'
384 393 hgevolve.ui_key = 'evolve'
385 394 hgevolve.ui_value = ''
386 395 hgevolve.ui_active = False
387 396 self.sa.add(hgevolve)
388 397
389 398 hgevolve = RhodeCodeUi()
390 399 hgevolve.ui_section = 'experimental'
391 400 hgevolve.ui_key = 'evolution'
392 401 hgevolve.ui_value = ''
393 402 hgevolve.ui_active = False
394 403 self.sa.add(hgevolve)
395 404
396 405 hgevolve = RhodeCodeUi()
397 406 hgevolve.ui_section = 'experimental'
398 407 hgevolve.ui_key = 'evolution.exchange'
399 408 hgevolve.ui_value = ''
400 409 hgevolve.ui_active = False
401 410 self.sa.add(hgevolve)
402 411
403 412 hgevolve = RhodeCodeUi()
404 413 hgevolve.ui_section = 'extensions'
405 414 hgevolve.ui_key = 'topic'
406 415 hgevolve.ui_value = ''
407 416 hgevolve.ui_active = False
408 417 self.sa.add(hgevolve)
409 418
410 419 # enable hggit disabled by default
411 420 hggit = RhodeCodeUi()
412 421 hggit.ui_section = 'extensions'
413 422 hggit.ui_key = 'hggit'
414 423 hggit.ui_value = ''
415 424 hggit.ui_active = False
416 425 self.sa.add(hggit)
417 426
418 427 # set svn branch defaults
419 428 branches = ["/branches/*", "/trunk"]
420 429 tags = ["/tags/*"]
421 430
422 431 for branch in branches:
423 432 settings_model.create_ui_section_value(
424 433 RhodeCodeUi.SVN_BRANCH_ID, branch)
425 434
426 435 for tag in tags:
427 436 settings_model.create_ui_section_value(RhodeCodeUi.SVN_TAG_ID, tag)
428 437
429 438 def create_auth_plugin_options(self, skip_existing=False):
430 439 """
431 440 Create default auth plugin settings, and make it active
432 441
433 442 :param skip_existing:
434 443 """
435 444 defaults = [
436 445 ('auth_plugins',
437 446 'egg:rhodecode-enterprise-ce#token,egg:rhodecode-enterprise-ce#rhodecode',
438 447 'list'),
439 448
440 449 ('auth_authtoken_enabled',
441 450 'True',
442 451 'bool'),
443 452
444 453 ('auth_rhodecode_enabled',
445 454 'True',
446 455 'bool'),
447 456 ]
448 457 for k, v, t in defaults:
449 458 if (skip_existing and
450 459 SettingsModel().get_setting_by_name(k) is not None):
451 460 log.debug('Skipping option %s', k)
452 461 continue
453 462 setting = RhodeCodeSetting(k, v, t)
454 463 self.sa.add(setting)
455 464
456 465 def create_default_options(self, skip_existing=False):
457 466 """Creates default settings"""
458 467
459 468 for k, v, t in [
460 469 ('default_repo_enable_locking', False, 'bool'),
461 470 ('default_repo_enable_downloads', False, 'bool'),
462 471 ('default_repo_enable_statistics', False, 'bool'),
463 472 ('default_repo_private', False, 'bool'),
464 473 ('default_repo_type', 'hg', 'unicode')]:
465 474
466 475 if (skip_existing and
467 476 SettingsModel().get_setting_by_name(k) is not None):
468 477 log.debug('Skipping option %s', k)
469 478 continue
470 479 setting = RhodeCodeSetting(k, v, t)
471 480 self.sa.add(setting)
472 481
473 482 def fixup_groups(self):
474 483 def_usr = User.get_default_user()
475 484 for g in RepoGroup.query().all():
476 485 g.group_name = g.get_new_name(g.name)
477 486 self.sa.add(g)
478 487 # get default perm
479 488 default = UserRepoGroupToPerm.query()\
480 489 .filter(UserRepoGroupToPerm.group == g)\
481 490 .filter(UserRepoGroupToPerm.user == def_usr)\
482 491 .scalar()
483 492
484 493 if default is None:
485 494 log.debug('missing default permission for group %s adding', g)
486 495 perm_obj = RepoGroupModel()._create_default_perms(g)
487 496 self.sa.add(perm_obj)
488 497
489 498 def reset_permissions(self, username):
490 499 """
491 500 Resets permissions to the default state, useful when old systems had
492 501 bad permissions that we must clean up
493 502
494 503 :param username:
495 504 """
496 505 default_user = User.get_by_username(username)
497 506 if not default_user:
498 507 return
499 508
500 509 u2p = UserToPerm.query()\
501 510 .filter(UserToPerm.user == default_user).all()
502 511 fixed = False
503 512 if len(u2p) != len(Permission.DEFAULT_USER_PERMISSIONS):
504 513 for p in u2p:
505 514 Session().delete(p)
506 515 fixed = True
507 516 self.populate_default_permissions()
508 517 return fixed
509 518
510 519 def config_prompt(self, test_repo_path='', retries=3):
511 520 defaults = self.cli_args
512 521 _path = defaults.get('repos_location')
513 522 if retries == 3:
514 523 log.info('Setting up repositories config')
515 524
516 525 if _path is not None:
517 526 path = _path
518 527 elif not self.tests and not test_repo_path:
519 528 path = input(
520 529 'Enter a valid absolute path to store repositories. '
521 530 'All repositories in that path will be added automatically:'
522 531 )
523 532 else:
524 533 path = test_repo_path
525 534 path_ok = True
526 535
527 536 # check proper dir
528 537 if not os.path.isdir(path):
529 538 path_ok = False
530 539 log.error('Given path %s is not a valid directory', path)
531 540
532 541 elif not os.path.isabs(path):
533 542 path_ok = False
534 543 log.error('Given path %s is not an absolute path', path)
535 544
536 545 # check if path is at least readable.
537 546 if not os.access(path, os.R_OK):
538 547 path_ok = False
539 548 log.error('Given path %s is not readable', path)
540 549
541 550 # check write access, warn user about non writeable paths
542 551 elif not os.access(path, os.W_OK) and path_ok:
543 552 log.warning('No write permission to given path %s', path)
544 553
545 q = ('Given path %s is not writeable, do you want to '
546 'continue with read only mode ? [y/n]' % (path,))
554 q = (f'Given path {path} is not writable, do you want to '
555 'continue with read-only mode? [y/n]')
547 556 if not self.ask_ok(q):
548 557 log.error('Canceled by user')
549 558 sys.exit(-1)
550 559
551 560 if retries == 0:
552 561 sys.exit('max retries reached')
553 562 if not path_ok:
554 563 retries -= 1
555 564 return self.config_prompt(test_repo_path, retries)
556 565
557 566 real_path = os.path.normpath(os.path.realpath(path))
558 567
559 568 if real_path != os.path.normpath(path):
560 q = ('Path looks like a symlink, RhodeCode Enterprise will store '
561 'given path as %s ? [y/n]') % (real_path,)
569 q = ('Path looks like a symlink, RhodeCode Enterprise will store '
570 f'given path as {real_path}? [y/n]')
562 571 if not self.ask_ok(q):
563 572 log.error('Canceled by user')
564 573 sys.exit(-1)
565 574
566 575 return real_path
567 576
568 577 def create_settings(self, path):
569 578
570 579 self.create_ui_settings(path)
571 580
572 581 ui_config = [
573 582 ('web', 'push_ssl', 'False'),
574 583 ('web', 'allow_archive', 'gz zip bz2'),
575 584 ('web', 'allow_push', '*'),
576 585 ('web', 'baseurl', '/'),
577 586 ('paths', '/', path),
578 587 ('phases', 'publish', 'True')
579 588 ]
580 589 for section, key, value in ui_config:
581 590 ui_conf = RhodeCodeUi()
582 591 setattr(ui_conf, 'ui_section', section)
583 592 setattr(ui_conf, 'ui_key', key)
584 593 setattr(ui_conf, 'ui_value', value)
585 594 self.sa.add(ui_conf)
586 595
587 596 # rhodecode app settings
588 597 settings = [
589 598 ('realm', 'RhodeCode', 'unicode'),
590 599 ('title', '', 'unicode'),
591 600 ('pre_code', '', 'unicode'),
592 601 ('post_code', '', 'unicode'),
593 602
594 603 # Visual
595 604 ('show_public_icon', True, 'bool'),
596 605 ('show_private_icon', True, 'bool'),
597 606 ('stylify_metatags', True, 'bool'),
598 607 ('dashboard_items', 100, 'int'),
599 608 ('admin_grid_items', 25, 'int'),
600 609
601 610 ('markup_renderer', 'markdown', 'unicode'),
602 611
603 612 ('repository_fields', True, 'bool'),
604 613 ('show_version', True, 'bool'),
605 614 ('show_revision_number', True, 'bool'),
606 615 ('show_sha_length', 12, 'int'),
607 616
608 617 ('use_gravatar', False, 'bool'),
609 618 ('gravatar_url', User.DEFAULT_GRAVATAR_URL, 'unicode'),
610 619
611 620 ('clone_uri_tmpl', Repository.DEFAULT_CLONE_URI, 'unicode'),
612 621 ('clone_uri_id_tmpl', Repository.DEFAULT_CLONE_URI_ID, 'unicode'),
613 622 ('clone_uri_ssh_tmpl', Repository.DEFAULT_CLONE_URI_SSH, 'unicode'),
614 623 ('support_url', '', 'unicode'),
615 624 ('update_url', RhodeCodeSetting.DEFAULT_UPDATE_URL, 'unicode'),
616 625
617 626 # VCS Settings
618 627 ('pr_merge_enabled', True, 'bool'),
619 628 ('use_outdated_comments', True, 'bool'),
620 629 ('diff_cache', True, 'bool'),
621 630 ]
622 631
623 632 for key, val, type_ in settings:
624 633 sett = RhodeCodeSetting(key, val, type_)
625 634 self.sa.add(sett)
626 635
627 636 self.create_auth_plugin_options()
628 637 self.create_default_options()
629 638
630 639 log.info('created ui config')
631 640
632 641 def create_user(self, username, password, email='', admin=False,
633 642 strict_creation_check=True, api_key=None):
634 643 log.info('creating user `%s`', username)
635 644 user = UserModel().create_or_update(
636 645 username, password, email, firstname='RhodeCode', lastname='Admin',
637 646 active=True, admin=admin, extern_type="rhodecode",
638 647 strict_creation_check=strict_creation_check)
639 648
640 649 if api_key:
641 650 log.info('setting a new default auth token for user `%s`', username)
642 651 UserModel().add_auth_token(
643 652 user=user, lifetime_minutes=-1,
644 653 role=UserModel.auth_token_role.ROLE_ALL,
645 654 description='BUILTIN TOKEN')
646 655
647 656 def create_default_user(self):
648 657 log.info('creating default user')
649 658 # create default user for handling default permissions.
650 659 user = UserModel().create_or_update(username=User.DEFAULT_USER,
651 660 password=str(uuid.uuid1())[:20],
652 661 email=User.DEFAULT_USER_EMAIL,
653 662 firstname='Anonymous',
654 663 lastname='User',
655 664 strict_creation_check=False)
656 665 # based on configuration options activate/de-activate this user which
657 666 # controls anonymous access
658 667 if self.cli_args.get('public_access') is False:
659 668 log.info('Public access disabled')
660 669 user.active = False
661 670 Session().add(user)
662 671 Session().commit()
663 672
664 673 def create_permissions(self):
665 674 """
666 675 Creates all permissions defined in the system
667 676 """
668 677 # module.(access|create|change|delete)_[name]
669 678 # module.(none|read|write|admin)
670 679 log.info('creating permissions')
671 680 PermissionModel(self.sa).create_permissions()
672 681
673 682 def populate_default_permissions(self):
674 683 """
675 684 Populate default permissions. It will create only the default
676 685 permissions that are missing, and not alter already defined ones
677 686 """
678 687 log.info('creating default user permissions')
679 688 PermissionModel(self.sa).create_default_user_permissions(user=User.DEFAULT_USER)
@@ -1,182 +1,202 b''
1 1
2 2 # Copyright (C) 2010-2020 RhodeCode GmbH
3 3 #
4 4 # This program is free software: you can redistribute it and/or modify
5 5 # it under the terms of the GNU Affero General Public License, version 3
6 6 # (only), as published by the Free Software Foundation.
7 7 #
8 8 # This program is distributed in the hope that it will be useful,
9 9 # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 10 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 11 # GNU General Public License for more details.
12 12 #
13 13 # You should have received a copy of the GNU Affero General Public License
14 14 # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 15 #
16 16 # This program is dual-licensed. If you wish to learn more about the
17 17 # RhodeCode Enterprise Edition, including its added features, Support services,
18 18 # and proprietary license terms, please see https://rhodecode.com/licenses/
19 19
20 20 """
21 21 Set of custom exceptions used in RhodeCode
22 22 """
23 23
24 24 from webob.exc import HTTPClientError
25 25 from pyramid.httpexceptions import HTTPBadGateway
26 26
27 27
28 28 class LdapUsernameError(Exception):
29 29 pass
30 30
31 31
32 32 class LdapPasswordError(Exception):
33 33 pass
34 34
35 35
36 36 class LdapConnectionError(Exception):
37 37 pass
38 38
39 39
40 40 class LdapImportError(Exception):
41 41 pass
42 42
43 43
44 44 class DefaultUserException(Exception):
45 45 pass
46 46
47 47
48 48 class UserOwnsReposException(Exception):
49 49 pass
50 50
51 51
52 52 class UserOwnsRepoGroupsException(Exception):
53 53 pass
54 54
55 55
56 56 class UserOwnsUserGroupsException(Exception):
57 57 pass
58 58
59 59
60 60 class UserOwnsPullRequestsException(Exception):
61 61 pass
62 62
63 63
64 64 class UserOwnsArtifactsException(Exception):
65 65 pass
66 66
67 67
68 68 class UserGroupAssignedException(Exception):
69 69 pass
70 70
71 71
72 72 class StatusChangeOnClosedPullRequestError(Exception):
73 73 pass
74 74
75 75
76 76 class AttachedForksError(Exception):
77 77 pass
78 78
79 79
80 80 class AttachedPullRequestsError(Exception):
81 81 pass
82 82
83 83
84 84 class RepoGroupAssignmentError(Exception):
85 85 pass
86 86
87 87
88 88 class NonRelativePathError(Exception):
89 89 pass
90 90
91 91
92 92 class HTTPRequirementError(HTTPClientError):
93 93 title = explanation = 'Repository Requirement Missing'
94 94 reason = None
95 95
96 96 def __init__(self, message, *args, **kwargs):
97 97 self.title = self.explanation = message
98 98 super(HTTPRequirementError, self).__init__(*args, **kwargs)
99 99 self.args = (message, )
100 100
101 101
102 102 class HTTPLockedRC(HTTPClientError):
103 103 """
104 104 Special Exception For locked Repos in RhodeCode, the return code can
105 105 be overwritten by _code keyword argument passed into constructors
106 106 """
107 107 code = 423
108 108 title = explanation = 'Repository Locked'
109 109 reason = None
110 110
111 111 def __init__(self, message, *args, **kwargs):
112 from rhodecode import CONFIG
113 from rhodecode.lib.utils2 import safe_int
114 _code = CONFIG.get('lock_ret_code')
115 self.code = safe_int(_code, self.code)
112 import rhodecode
113
114 self.code = rhodecode.ConfigGet().get_int('lock_ret_code', missing=self.code)
115
116 116 self.title = self.explanation = message
117 117 super(HTTPLockedRC, self).__init__(*args, **kwargs)
118 118 self.args = (message, )
119 119
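For context, a small sketch of constructing the locked response under the new ConfigGet lookup (the repository name and user are illustrative):

# Illustrative: the response code falls back to the class default (423)
# unless lock_ret_code is set in the application config.
exc = HTTPLockedRC('Repository `foo/bar` locked by user admin')
print(exc.code, exc.title)  # 423 Repository `foo/bar` locked by user admin
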
120 120
121 121 class HTTPBranchProtected(HTTPClientError):
122 122 """
123 123 Special exception indicating that a branch is protected in RhodeCode; the
124 124 return code can be overwritten by the _code keyword argument passed into the constructor
125 125 """
126 126 code = 403
127 127 title = explanation = 'Branch Protected'
128 128 reason = None
129 129
130 130 def __init__(self, message, *args, **kwargs):
131 131 self.title = self.explanation = message
132 132 super(HTTPBranchProtected, self).__init__(*args, **kwargs)
133 133 self.args = (message, )
134 134
135 135
136 136 class IMCCommitError(Exception):
137 137 pass
138 138
139 139
140 140 class UserCreationError(Exception):
141 141 pass
142 142
143 143
144 144 class NotAllowedToCreateUserError(Exception):
145 145 pass
146 146
147 147
148 148 class RepositoryCreationError(Exception):
149 149 pass
150 150
151 151
152 152 class VCSServerUnavailable(HTTPBadGateway):
153 153 """ HTTP Exception class for VCS Server errors """
154 154 code = 502
155 155 title = 'VCS Server Error'
156 156 causes = [
157 157 'VCS Server is not running',
158 158 'Incorrect vcs.server=host:port',
159 159 'Incorrect vcs.server.protocol',
160 160 ]
161 161
162 162 def __init__(self, message=''):
163 163 self.explanation = 'Could not connect to VCS Server'
164 164 if message:
165 165 self.explanation += ': ' + message
166 166 super(VCSServerUnavailable, self).__init__()
167 167
168 168
169 169 class ArtifactMetadataDuplicate(ValueError):
170 170
171 171 def __init__(self, *args, **kwargs):
172 172 self.err_section = kwargs.pop('err_section', None)
173 173 self.err_key = kwargs.pop('err_key', None)
174 174 super(ArtifactMetadataDuplicate, self).__init__(*args, **kwargs)
175 175
176 176
177 177 class ArtifactMetadataBadValueType(ValueError):
178 178 pass
179 179
180 180
181 181 class CommentVersionMismatch(ValueError):
182 182 pass
183
184
185 class SignatureVerificationError(ValueError):
186 pass
187
188
189 def signature_verification_error(msg):
190 details = """
191 Encryption signature verification failed.
192 Please check your value of secret key, and/or encrypted value stored.
193 Secret key stored inside .ini file:
194 `rhodecode.encrypted_values.secret` or defaults to
195 `beaker.session.secret`
196
197 Probably the stored values were encrypted using a different secret than the one currently set in the .ini file
198 """
199
200 final_msg = f'{msg}\n{details}'
201 return SignatureVerificationError(final_msg)
202
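A minimal sketch of how the helper above might be used; the decrypt callback here is hypothetical, only the error wrapping mirrors the real helper:

def decrypt_or_raise(decrypt, value):
    # `decrypt` is a hypothetical callback; only the error wrapping is real.
    try:
        return decrypt(value)
    except ValueError as err:
        raise signature_verification_error(f'failed to decrypt value: {err}')
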
@@ -1,444 +1,443 b''
1 1 # Copyright (c) Django Software Foundation and individual contributors.
2 2 # All rights reserved.
3 3 #
4 4 # Redistribution and use in source and binary forms, with or without modification,
5 5 # are permitted provided that the following conditions are met:
6 6 #
7 7 # 1. Redistributions of source code must retain the above copyright notice,
8 8 # this list of conditions and the following disclaimer.
9 9 #
10 10 # 2. Redistributions in binary form must reproduce the above copyright
11 11 # notice, this list of conditions and the following disclaimer in the
12 12 # documentation and/or other materials provided with the distribution.
13 13 #
14 14 # 3. Neither the name of Django nor the names of its contributors may be used
15 15 # to endorse or promote products derived from this software without
16 16 # specific prior written permission.
17 17 #
18 18 # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
19 19 # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
20 20 # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21 21 # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
22 22 # ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
23 23 # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
24 24 # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
25 25 # ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
26 26 # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
27 27 # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 28
29 29 """
30 30 For definitions of the different versions of RSS, see:
31 31 http://web.archive.org/web/20110718035220/http://diveintomark.org/archives/2004/02/04/incompatible-rss
32 32 """
33 33
34 34
35 35 import datetime
36 36 import io
37 37
38 import pytz
39 38 from six.moves.urllib import parse as urlparse
40 39
41 40 from rhodecode.lib.feedgenerator import datetime_safe
42 41 from rhodecode.lib.feedgenerator.utils import SimplerXMLGenerator, iri_to_uri, force_text
43 42
44 43
45 44 #### The following code comes from ``django.utils.feedgenerator`` ####
46 45
47 46
48 47 def rfc2822_date(date):
49 48 # We can't use strftime() because it produces locale-dependent results, so
50 49 # we have to map english month and day names manually
51 50 months = ('Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun', 'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec',)
52 51 days = ('Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun')
53 52 # Support datetime objects older than 1900
54 53 date = datetime_safe.new_datetime(date)
55 54 # We do this ourselves to be timezone aware, email.Utils is not tz aware.
56 55 dow = days[date.weekday()]
57 56 month = months[date.month - 1]
58 57 time_str = date.strftime('%s, %%d %s %%Y %%H:%%M:%%S ' % (dow, month))
59 58
60 59 offset = date.utcoffset()
61 60 # Historically, this function assumes that naive datetimes are in UTC.
62 61 if offset is None:
63 62 return time_str + '-0000'
64 63 else:
65 64 timezone = (offset.days * 24 * 60) + (offset.seconds // 60)
66 65 hour, minute = divmod(timezone, 60)
67 66 return time_str + '%+03d%02d' % (hour, minute)
68 67
69 68
70 69 def rfc3339_date(date):
71 70 # Support datetime objects older than 1900
72 71 date = datetime_safe.new_datetime(date)
73 72 time_str = date.strftime('%Y-%m-%dT%H:%M:%S')
74 73
75 74 offset = date.utcoffset()
76 75 # Historically, this function assumes that naive datetimes are in UTC.
77 76 if offset is None:
78 77 return time_str + 'Z'
79 78 else:
80 79 timezone = (offset.days * 24 * 60) + (offset.seconds // 60)
81 80 hour, minute = divmod(timezone, 60)
82 81 return time_str + '%+03d:%02d' % (hour, minute)
83 82
84 83
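Example output of the two date formatters above for a naive datetime (treated as UTC):

import datetime

dt = datetime.datetime(2020, 5, 17, 12, 30, 0)  # naive -> assumed UTC
print(rfc2822_date(dt))  # Sun, 17 May 2020 12:30:00 -0000
print(rfc3339_date(dt))  # 2020-05-17T12:30:00Z
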
85 84 def get_tag_uri(url, date):
86 85 """
87 86 Creates a TagURI.
88 87
89 88 See http://web.archive.org/web/20110514113830/http://diveintomark.org/archives/2004/05/28/howto-atom-id
90 89 """
91 90 bits = urlparse.urlparse(url)
92 91 d = ''
93 92 if date is not None:
94 93 d = ',%s' % datetime_safe.new_datetime(date).strftime('%Y-%m-%d')
95 94 return 'tag:%s%s:%s/%s' % (bits.hostname, d, bits.path, bits.fragment)
96 95
97 96
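And an example tag URI produced by the helper above; the URL is illustrative, and this assumes the urlparse call resolves via the module alias as written above:

import datetime

get_tag_uri('https://code.example.com/repo/changelog#tip',
            datetime.datetime(2020, 5, 17))
# -> 'tag:code.example.com,2020-05-17:/repo/changelog/tip'
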
98 97 class SyndicationFeed(object):
99 98 """Base class for all syndication feeds. Subclasses should provide write()"""
100 99
101 100 def __init__(self, title, link, description, language=None, author_email=None,
102 101 author_name=None, author_link=None, subtitle=None, categories=None,
103 102 feed_url=None, feed_copyright=None, feed_guid=None, ttl=None, **kwargs):
104 103 def to_unicode(s):
105 104 return force_text(s, strings_only=True)
106 105 if categories:
107 106 categories = [force_text(c) for c in categories]
108 107 if ttl is not None:
109 108 # Force ints to unicode
110 109 ttl = force_text(ttl)
111 110 self.feed = {
112 111 'title': to_unicode(title),
113 112 'link': iri_to_uri(link),
114 113 'description': to_unicode(description),
115 114 'language': to_unicode(language),
116 115 'author_email': to_unicode(author_email),
117 116 'author_name': to_unicode(author_name),
118 117 'author_link': iri_to_uri(author_link),
119 118 'subtitle': to_unicode(subtitle),
120 119 'categories': categories or (),
121 120 'feed_url': iri_to_uri(feed_url),
122 121 'feed_copyright': to_unicode(feed_copyright),
123 122 'id': feed_guid or link,
124 123 'ttl': ttl,
125 124 }
126 125 self.feed.update(kwargs)
127 126 self.items = []
128 127
129 128 def add_item(self, title, link, description, author_email=None,
130 129 author_name=None, author_link=None, pubdate=None, comments=None,
131 130 unique_id=None, unique_id_is_permalink=None, enclosure=None,
132 131 categories=(), item_copyright=None, ttl=None, updateddate=None,
133 132 enclosures=None, **kwargs):
134 133 """
135 134 Adds an item to the feed. All args are expected to be Python Unicode
136 135 objects except pubdate and updateddate, which are datetime.datetime
137 136 objects, and enclosures, which is an iterable of instances of the
138 137 Enclosure class.
139 138 """
140 139 def to_unicode(s):
141 140 return force_text(s, strings_only=True)
142 141 if categories:
143 142 categories = [to_unicode(c) for c in categories]
144 143 if ttl is not None:
145 144 # Force ints to unicode
146 145 ttl = force_text(ttl)
147 146 if enclosure is None:
148 147 enclosures = [] if enclosures is None else enclosures
149 148
150 149 item = {
151 150 'title': to_unicode(title),
152 151 'link': iri_to_uri(link),
153 152 'description': to_unicode(description),
154 153 'author_email': to_unicode(author_email),
155 154 'author_name': to_unicode(author_name),
156 155 'author_link': iri_to_uri(author_link),
157 156 'pubdate': pubdate,
158 157 'updateddate': updateddate,
159 158 'comments': to_unicode(comments),
160 159 'unique_id': to_unicode(unique_id),
161 160 'unique_id_is_permalink': unique_id_is_permalink,
162 161 'enclosures': enclosures,
163 162 'categories': categories or (),
164 163 'item_copyright': to_unicode(item_copyright),
165 164 'ttl': ttl,
166 165 }
167 166 item.update(kwargs)
168 167 self.items.append(item)
169 168
170 169 def num_items(self):
171 170 return len(self.items)
172 171
173 172 def root_attributes(self):
174 173 """
175 174 Return extra attributes to place on the root (i.e. feed/channel) element.
176 175 Called from write().
177 176 """
178 177 return {}
179 178
180 179 def add_root_elements(self, handler):
181 180 """
182 181 Add elements in the root (i.e. feed/channel) element. Called
183 182 from write().
184 183 """
185 184 pass
186 185
187 186 def item_attributes(self, item):
188 187 """
189 188 Return extra attributes to place on each item (i.e. item/entry) element.
190 189 """
191 190 return {}
192 191
193 192 def add_item_elements(self, handler, item):
194 193 """
195 194 Add elements on each item (i.e. item/entry) element.
196 195 """
197 196 pass
198 197
199 198 def write(self, outfile, encoding):
200 199 """
201 200 Outputs the feed in the given encoding to outfile, which is a file-like
202 201 object. Subclasses should override this.
203 202 """
204 203 raise NotImplementedError('subclasses of SyndicationFeed must provide a write() method')
205 204
206 205 def writeString(self, encoding):
207 206 """
208 207 Returns the feed in the given encoding as a string.
209 208 """
210 209 s = io.StringIO()
211 210 self.write(s, encoding)
212 211 return s.getvalue()
213 212
214 213 def latest_post_date(self):
215 214 """
216 215 Returns the latest item's pubdate or updateddate. If no items
 217 216         have either of these attributes, this returns the current UTC date/time.
218 217 """
219 218 latest_date = None
220 219 date_keys = ('updateddate', 'pubdate')
221 220
222 221 for item in self.items:
223 222 for date_key in date_keys:
224 223 item_date = item.get(date_key)
225 224 if item_date:
226 225 if latest_date is None or item_date > latest_date:
227 226 latest_date = item_date
228 227
229 228 # datetime.now(tz=utc) is slower, as documented in django.utils.timezone.now
230 return latest_date or datetime.datetime.utcnow().replace(tzinfo=pytz.utc)
229 return latest_date or datetime.datetime.utcnow().replace(tzinfo=datetime.timezone.utc)
231 230
232 231
233 232 class Enclosure(object):
234 233 """Represents an RSS enclosure"""
235 234 def __init__(self, url, length, mime_type):
236 235 """All args are expected to be Python Unicode objects"""
237 236 self.length, self.mime_type = length, mime_type
238 237 self.url = iri_to_uri(url)
239 238
240 239
241 240 class RssFeed(SyndicationFeed):
242 241 content_type = 'application/rss+xml; charset=utf-8'
243 242
244 243 def write(self, outfile, encoding):
245 244 handler = SimplerXMLGenerator(outfile, encoding)
246 245 handler.startDocument()
247 246 handler.startElement("rss", self.rss_attributes())
248 247 handler.startElement("channel", self.root_attributes())
249 248 self.add_root_elements(handler)
250 249 self.write_items(handler)
251 250 self.endChannelElement(handler)
252 251 handler.endElement("rss")
253 252
254 253 def rss_attributes(self):
255 254 return {"version": self._version,
256 255 "xmlns:atom": "http://www.w3.org/2005/Atom"}
257 256
258 257 def write_items(self, handler):
259 258 for item in self.items:
260 259 handler.startElement('item', self.item_attributes(item))
261 260 self.add_item_elements(handler, item)
262 261 handler.endElement("item")
263 262
264 263 def add_root_elements(self, handler):
265 264 handler.addQuickElement("title", self.feed['title'])
266 265 handler.addQuickElement("link", self.feed['link'])
267 266 handler.addQuickElement("description", self.feed['description'])
268 267 if self.feed['feed_url'] is not None:
269 268 handler.addQuickElement("atom:link", None, {"rel": "self", "href": self.feed['feed_url']})
270 269 if self.feed['language'] is not None:
271 270 handler.addQuickElement("language", self.feed['language'])
272 271 for cat in self.feed['categories']:
273 272 handler.addQuickElement("category", cat)
274 273 if self.feed['feed_copyright'] is not None:
275 274 handler.addQuickElement("copyright", self.feed['feed_copyright'])
276 275 handler.addQuickElement("lastBuildDate", rfc2822_date(self.latest_post_date()))
277 276 if self.feed['ttl'] is not None:
278 277 handler.addQuickElement("ttl", self.feed['ttl'])
279 278
280 279 def endChannelElement(self, handler):
281 280 handler.endElement("channel")
282 281
283 282
284 283 class RssUserland091Feed(RssFeed):
285 284 _version = "0.91"
286 285
287 286 def add_item_elements(self, handler, item):
288 287 handler.addQuickElement("title", item['title'])
289 288 handler.addQuickElement("link", item['link'])
290 289 if item['description'] is not None:
291 290 handler.addQuickElement("description", item['description'])
292 291
293 292
294 293 class Rss201rev2Feed(RssFeed):
295 294 # Spec: http://blogs.law.harvard.edu/tech/rss
296 295 _version = "2.0"
297 296
298 297 def add_item_elements(self, handler, item):
299 298 handler.addQuickElement("title", item['title'])
300 299 handler.addQuickElement("link", item['link'])
301 300 if item['description'] is not None:
302 301 handler.addQuickElement("description", item['description'])
303 302
304 303 # Author information.
305 304 if item["author_name"] and item["author_email"]:
306 305 handler.addQuickElement("author", "%s (%s)" % (item['author_email'], item['author_name']))
307 306 elif item["author_email"]:
308 307 handler.addQuickElement("author", item["author_email"])
309 308 elif item["author_name"]:
310 309 handler.addQuickElement(
311 310 "dc:creator", item["author_name"], {"xmlns:dc": "http://purl.org/dc/elements/1.1/"}
312 311 )
313 312
314 313 if item['pubdate'] is not None:
315 314 handler.addQuickElement("pubDate", rfc2822_date(item['pubdate']))
316 315 if item['comments'] is not None:
317 316 handler.addQuickElement("comments", item['comments'])
318 317 if item['unique_id'] is not None:
319 318 guid_attrs = {}
320 319 if isinstance(item.get('unique_id_is_permalink'), bool):
321 320 guid_attrs['isPermaLink'] = str(item['unique_id_is_permalink']).lower()
322 321 handler.addQuickElement("guid", item['unique_id'], guid_attrs)
323 322 if item['ttl'] is not None:
324 323 handler.addQuickElement("ttl", item['ttl'])
325 324
326 325 # Enclosure.
327 326 if item['enclosures']:
328 327 enclosures = list(item['enclosures'])
329 328 if len(enclosures) > 1:
330 329 raise ValueError(
331 330 "RSS feed items may only have one enclosure, see "
332 331 "http://www.rssboard.org/rss-profile#element-channel-item-enclosure"
333 332 )
334 333 enclosure = enclosures[0]
335 334 handler.addQuickElement('enclosure', '', {
336 335 'url': enclosure.url,
337 336 'length': enclosure.length,
338 337 'type': enclosure.mime_type,
339 338 })
340 339
341 340 # Categories.
342 341 for cat in item['categories']:
343 342 handler.addQuickElement("category", cat)
344 343
345 344
346 345 class Atom1Feed(SyndicationFeed):
347 346 # Spec: https://tools.ietf.org/html/rfc4287
348 347 content_type = 'application/atom+xml; charset=utf-8'
349 348 ns = "http://www.w3.org/2005/Atom"
350 349
351 350 def write(self, outfile, encoding):
352 351 handler = SimplerXMLGenerator(outfile, encoding)
353 352 handler.startDocument()
354 353 handler.startElement('feed', self.root_attributes())
355 354 self.add_root_elements(handler)
356 355 self.write_items(handler)
357 356 handler.endElement("feed")
358 357
359 358 def root_attributes(self):
360 359 if self.feed['language'] is not None:
361 360 return {"xmlns": self.ns, "xml:lang": self.feed['language']}
362 361 else:
363 362 return {"xmlns": self.ns}
364 363
365 364 def add_root_elements(self, handler):
366 365 handler.addQuickElement("title", self.feed['title'])
367 366 handler.addQuickElement("link", "", {"rel": "alternate", "href": self.feed['link']})
368 367 if self.feed['feed_url'] is not None:
369 368 handler.addQuickElement("link", "", {"rel": "self", "href": self.feed['feed_url']})
370 369 handler.addQuickElement("id", self.feed['id'])
371 370 handler.addQuickElement("updated", rfc3339_date(self.latest_post_date()))
372 371 if self.feed['author_name'] is not None:
373 372 handler.startElement("author", {})
374 373 handler.addQuickElement("name", self.feed['author_name'])
375 374 if self.feed['author_email'] is not None:
376 375 handler.addQuickElement("email", self.feed['author_email'])
377 376 if self.feed['author_link'] is not None:
378 377 handler.addQuickElement("uri", self.feed['author_link'])
379 378 handler.endElement("author")
380 379 if self.feed['subtitle'] is not None:
381 380 handler.addQuickElement("subtitle", self.feed['subtitle'])
382 381 for cat in self.feed['categories']:
383 382 handler.addQuickElement("category", "", {"term": cat})
384 383 if self.feed['feed_copyright'] is not None:
385 384 handler.addQuickElement("rights", self.feed['feed_copyright'])
386 385
387 386 def write_items(self, handler):
388 387 for item in self.items:
389 388 handler.startElement("entry", self.item_attributes(item))
390 389 self.add_item_elements(handler, item)
391 390 handler.endElement("entry")
392 391
393 392 def add_item_elements(self, handler, item):
394 393 handler.addQuickElement("title", item['title'])
395 394 handler.addQuickElement("link", "", {"href": item['link'], "rel": "alternate"})
396 395
397 396 if item['pubdate'] is not None:
398 397 handler.addQuickElement('published', rfc3339_date(item['pubdate']))
399 398
400 399 if item['updateddate'] is not None:
401 400 handler.addQuickElement('updated', rfc3339_date(item['updateddate']))
402 401
403 402 # Author information.
404 403 if item['author_name'] is not None:
405 404 handler.startElement("author", {})
406 405 handler.addQuickElement("name", item['author_name'])
407 406 if item['author_email'] is not None:
408 407 handler.addQuickElement("email", item['author_email'])
409 408 if item['author_link'] is not None:
410 409 handler.addQuickElement("uri", item['author_link'])
411 410 handler.endElement("author")
412 411
413 412 # Unique ID.
414 413 if item['unique_id'] is not None:
415 414 unique_id = item['unique_id']
416 415 else:
417 416 unique_id = get_tag_uri(item['link'], item['pubdate'])
418 417 handler.addQuickElement("id", unique_id)
419 418
420 419 # Summary.
421 420 if item['description'] is not None:
422 421 handler.addQuickElement("summary", item['description'], {"type": "html"})
423 422
424 423 # Enclosures.
425 424 for enclosure in item['enclosures']:
426 425 handler.addQuickElement('link', '', {
427 426 'rel': 'enclosure',
428 427 'href': enclosure.url,
429 428 'length': enclosure.length,
430 429 'type': enclosure.mime_type,
431 430 })
432 431
433 432 # Categories.
434 433 for cat in item['categories']:
435 434 handler.addQuickElement("category", "", {"term": cat})
436 435
437 436 # Rights.
438 437 if item['item_copyright'] is not None:
439 438 handler.addQuickElement("rights", item['item_copyright'])
440 439
441 440
442 441 # This isolates the decision of what the system default is, so calling code can
443 442 # do "feedgenerator.DefaultFeed" instead of "feedgenerator.Rss201rev2Feed".
444 443 DefaultFeed = Rss201rev2Feed No newline at end of file
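
Taken together, producing a feed looks roughly like this (a sketch with
illustrative titles, links and dates; not code from this module)::

    import datetime

    feed = Rss201rev2Feed(
        title='Repository feed',
        link='https://example.com/repo',
        description='Latest commits',
        language='en')
    feed.add_item(
        title='Initial commit',
        link='https://example.com/repo/changeset/abc123',
        description='<p>first commit</p>',
        pubdate=datetime.datetime(2020, 5, 1,
                                  tzinfo=datetime.timezone.utc))
    xml = feed.writeString('utf-8')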
@@ -1,155 +1,155 b''
1 1
2 2
3 3 # Copyright (C) 2012-2020 RhodeCode GmbH
4 4 #
5 5 # This program is free software: you can redistribute it and/or modify
6 6 # it under the terms of the GNU Affero General Public License, version 3
7 7 # (only), as published by the Free Software Foundation.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU Affero General Public License
15 15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 16 #
17 17 # This program is dual-licensed. If you wish to learn more about the
18 18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20 20
21 21 """
22 22 Index schema for RhodeCode
23 23 """
24 24
25 25 import importlib
26 26 import logging
27 27
28 28 from rhodecode.lib.index.search_utils import normalize_text_for_matching
29 29
30 30 log = logging.getLogger(__name__)
31 31
32 32 # leave defaults for backward compat
33 33 default_searcher = 'rhodecode.lib.index.whoosh'
34 34 default_location = '%(here)s/data/index'
35 35
36 36 ES_VERSION_2 = '2'
37 37 ES_VERSION_6 = '6'
 38 38 # for legacy reasons we keep ES 2 compatibility as the default
39 39 DEFAULT_ES_VERSION = ES_VERSION_2
40 40
41 41 try:
42 42 from rhodecode_tools.lib.fts_index.elasticsearch_engine_6 import ES_CONFIG # pragma: no cover
43 43 except ImportError:
44 44 log.warning('rhodecode_tools not available, use of full text search is limited')
45 45 pass
46 46
47 47
48 48 class BaseSearcher(object):
49 49 query_lang_doc = ''
50 50 es_version = None
51 51 name = None
52 52 DIRECTION_ASC = 'asc'
53 53 DIRECTION_DESC = 'desc'
54 54
55 55 def __init__(self):
56 56 pass
57 57
58 58 def cleanup(self):
59 59 pass
60 60
61 61 def search(self, query, document_type, search_user,
62 62 repo_name=None, repo_group_name=None,
63 63 raise_on_exc=True):
 64 64         raise NotImplementedError()
65 65
66 66 @staticmethod
67 67 def query_to_mark(query, default_field=None):
68 68 """
 69 69         Formats the query into a mark token for jquery.mark.js highlighting.
 70 70         ES backends may optionally use a different format.
71 71
72 72 :param default_field:
73 73 :param query:
74 74 """
75 75 return ' '.join(normalize_text_for_matching(query).split())
76 76
77 77 @property
78 78 def is_es_6(self):
79 79 return self.es_version == ES_VERSION_6
80 80
81 81 def get_handlers(self):
82 82 return {}
83 83
84 84 @staticmethod
85 85 def extract_search_tags(query):
86 86 return []
87 87
88 88 @staticmethod
89 89 def escape_specials(val):
90 90 """
91 91 Handle and escape reserved chars for search
92 92 """
93 93 return val
94 94
95 95 def sort_def(self, search_type, direction, sort_field):
96 96 """
 97 97         Defines sorting for search. This function should decide whether, for a
 98 98         given search_type, sorting can be done with sort_field.
 99 99
 100 100         It should also translate common sort fields into backend-specific ones, e.g. for Elasticsearch.
101 101 """
102 102 raise NotImplementedError()
103 103
104 104 @staticmethod
105 105 def get_sort(search_type, search_val):
106 106 """
107 107 Method used to parse the GET search sort value to a field and direction.
 108 108         e.g. asc:lines == asc, lines
109 109
 110 110         There's also legacy support for newfirst/oldfirst, which defines commit
 111 111         sorting only.
112 112 """
113 113
114 114 direction = BaseSearcher.DIRECTION_ASC
115 115 sort_field = None
116 116
117 117 if not search_val:
118 118 return direction, sort_field
119 119
120 120 if search_val.startswith('asc:'):
121 121 sort_field = search_val[4:]
122 122 direction = BaseSearcher.DIRECTION_ASC
123 123 elif search_val.startswith('desc:'):
124 124 sort_field = search_val[5:]
125 125 direction = BaseSearcher.DIRECTION_DESC
126 126 elif search_val == 'newfirst' and search_type == 'commit':
127 127 sort_field = 'date'
128 128 direction = BaseSearcher.DIRECTION_DESC
129 129 elif search_val == 'oldfirst' and search_type == 'commit':
130 130 sort_field = 'date'
131 131 direction = BaseSearcher.DIRECTION_ASC
132 132
133 133 return direction, sort_field
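
For instance, the helper parses values like these::

    BaseSearcher.get_sort('commit', 'desc:date')  # ('desc', 'date')
    BaseSearcher.get_sort('commit', 'newfirst')   # ('desc', 'date')
    BaseSearcher.get_sort('content', '')          # ('asc', None)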
134 134
135 135
136 136 def search_config(config, prefix='search.'):
137 137 _config = {}
138 138 for key in config.keys():
139 139 if key.startswith(prefix):
140 140 _config[key[len(prefix):]] = config[key]
141 141 return _config
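
A small illustration of the prefix stripping, with hypothetical keys::

    search_config({'search.module': 'rhodecode.lib.index.whoosh',
                   'search.location': '/tmp/index',
                   'use_gravatar': 'true'})
    # {'module': 'rhodecode.lib.index.whoosh', 'location': '/tmp/index'}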
142 142
143 143
144 144 def searcher_from_config(config, prefix='search.'):
145 145 _config = search_config(config, prefix)
146 146
147 147 if 'location' not in _config:
148 148 _config['location'] = default_location
149 149 if 'es_version' not in _config:
150 # use old legacy ES version set to 2
 150     # default to the legacy ES version 2
151 151 _config['es_version'] = '2'
152 152
153 153 imported = importlib.import_module(_config.get('module', default_searcher))
154 154 searcher = imported.Searcher(config=_config)
155 155 return searcher
@@ -1,278 +1,173 b''
1 1
2 2 # Copyright (C) 2010-2020 RhodeCode GmbH
3 3 #
4 4 # This program is free software: you can redistribute it and/or modify
5 5 # it under the terms of the GNU Affero General Public License, version 3
6 6 # (only), as published by the Free Software Foundation.
7 7 #
8 8 # This program is distributed in the hope that it will be useful,
9 9 # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 10 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 11 # GNU General Public License for more details.
12 12 #
13 13 # You should have received a copy of the GNU Affero General Public License
14 14 # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 15 #
16 16 # This program is dual-licensed. If you wish to learn more about the
17 17 # RhodeCode Enterprise Edition, including its added features, Support services,
18 18 # and proprietary license terms, please see https://rhodecode.com/licenses/
19 19
20 import collections
21
22 20 import sqlalchemy
23 21 from sqlalchemy import UnicodeText
24 from sqlalchemy.ext.mutable import Mutable
22 from sqlalchemy.ext.mutable import Mutable, \
23 MutableList as MutationList, \
24 MutableDict as MutationDict
25 25
26 from rhodecode.lib.ext_json import json
27 from rhodecode.lib.utils2 import safe_unicode
26 from rhodecode.lib import ext_json
28 27
29 28
30 29 class JsonRaw(str):
31 30 """
 32 31     Allows interacting with a JSON-typed field using a raw string.
33 32
34 33 For example::
35 34 db_instance = JsonTable()
36 35 db_instance.enabled = True
37 36 db_instance.json_data = JsonRaw('{"a": 4}')
38 37
 39 38     This bypasses serialization and validation, allowing raw
 40 39     values to be stored.
41 40 """
42 41 pass
43 42
44 43
45 # Set this to the standard dict if Order is not required
46 DictClass = collections.OrderedDict
47
48
49 44 class JSONEncodedObj(sqlalchemy.types.TypeDecorator):
50 45 """
51 46 Represents an immutable structure as a json-encoded string.
52 47
53 48 If default is, for example, a dict, then a NULL value in the
54 49 database will be exposed as an empty dict.
55 50 """
56 51
57 52 impl = UnicodeText
58 53 safe = True
59 enforce_unicode = True
54 enforce_str = True
60 55
61 56 def __init__(self, *args, **kwargs):
62 57 self.default = kwargs.pop('default', None)
63 58 self.safe = kwargs.pop('safe_json', self.safe)
64 self.enforce_unicode = kwargs.pop('enforce_unicode', self.enforce_unicode)
59 self.enforce_str = kwargs.pop('enforce_str', self.enforce_str)
65 60 self.dialect_map = kwargs.pop('dialect_map', {})
66 61 super(JSONEncodedObj, self).__init__(*args, **kwargs)
67 62
68 63 def load_dialect_impl(self, dialect):
69 64 if dialect.name in self.dialect_map:
70 65 return dialect.type_descriptor(self.dialect_map[dialect.name])
71 66 return dialect.type_descriptor(self.impl)
72 67
73 68 def process_bind_param(self, value, dialect):
74 69 if isinstance(value, JsonRaw):
 75 70             pass  # raw JSON strings are stored as-is, bypassing serialization
76 71 elif value is not None:
77 value = json.dumps(value)
78 if self.enforce_unicode:
79 value = safe_unicode(value)
72 if self.enforce_str:
73 value = ext_json.str_json(value)
74 else:
75 value = ext_json.json.dumps(value)
80 76 return value
81 77
82 78 def process_result_value(self, value, dialect):
83 79 if self.default is not None and (not value or value == '""'):
84 80 return self.default()
85 81
86 82 if value is not None:
87 83 try:
88 value = json.loads(value, object_pairs_hook=DictClass)
89 except Exception as e:
84 value = ext_json.json.loads(value)
85 except Exception:
90 86 if self.safe and self.default is not None:
91 87 return self.default()
92 88 else:
93 89 raise
94 90 return value
95 91
96 92
97 93 class MutationObj(Mutable):
94
98 95 @classmethod
99 96 def coerce(cls, key, value):
100 97 if isinstance(value, dict) and not isinstance(value, MutationDict):
101 98 return MutationDict.coerce(key, value)
102 99 if isinstance(value, list) and not isinstance(value, MutationList):
103 100 return MutationList.coerce(key, value)
104 101 return value
105 102
106 103 def de_coerce(self):
107 104 return self
108 105
109 106 @classmethod
110 107 def _listen_on_attribute(cls, attribute, coerce, parent_cls):
111 108 key = attribute.key
112 109 if parent_cls is not attribute.class_:
113 110 return
114 111
115 112 # rely on "propagate" here
116 113 parent_cls = attribute.class_
117 114
118 115 def load(state, *args):
119 116 val = state.dict.get(key, None)
120 117 if coerce:
121 118 val = cls.coerce(key, val)
122 119 state.dict[key] = val
123 120 if isinstance(val, cls):
124 121 val._parents[state.obj()] = key
125 122
126 123 def set(target, value, oldvalue, initiator):
127 124 if not isinstance(value, cls):
128 125 value = cls.coerce(key, value)
129 126 if isinstance(value, cls):
130 127 value._parents[target.obj()] = key
131 128 if isinstance(oldvalue, cls):
132 129 oldvalue._parents.pop(target.obj(), None)
133 130 return value
134 131
135 132 def pickle(state, state_dict):
136 133 val = state.dict.get(key, None)
137 134 if isinstance(val, cls):
138 135 if 'ext.mutable.values' not in state_dict:
139 136 state_dict['ext.mutable.values'] = []
140 137 state_dict['ext.mutable.values'].append(val)
141 138
142 139 def unpickle(state, state_dict):
143 140 if 'ext.mutable.values' in state_dict:
144 141 for val in state_dict['ext.mutable.values']:
145 142 val._parents[state.obj()] = key
146 143
147 144 sqlalchemy.event.listen(parent_cls, 'load', load, raw=True,
148 145 propagate=True)
149 146 sqlalchemy.event.listen(parent_cls, 'refresh', load, raw=True,
150 147 propagate=True)
151 148 sqlalchemy.event.listen(parent_cls, 'pickle', pickle, raw=True,
152 149 propagate=True)
153 150 sqlalchemy.event.listen(attribute, 'set', set, raw=True, retval=True,
154 151 propagate=True)
155 152 sqlalchemy.event.listen(parent_cls, 'unpickle', unpickle, raw=True,
156 153 propagate=True)
157 154
158 155
159 class MutationDict(MutationObj, DictClass):
160 @classmethod
161 def coerce(cls, key, value):
162 """Convert plain dictionary to MutationDict"""
163 self = MutationDict(
164 (k, MutationObj.coerce(key, v)) for (k, v) in value.items())
165 self._key = key
166 return self
167
168 def de_coerce(self):
169 return dict(self)
170
171 def __setitem__(self, key, value):
172 # Due to the way OrderedDict works, this is called during __init__.
173 # At this time we don't have a key set, but what is more, the value
174 # being set has already been coerced. So special case this and skip.
175 if hasattr(self, '_key'):
176 value = MutationObj.coerce(self._key, value)
177 DictClass.__setitem__(self, key, value)
178 self.changed()
179
180 def __delitem__(self, key):
181 DictClass.__delitem__(self, key)
182 self.changed()
183
184 def __setstate__(self, state):
185 self.__dict__ = state
186
187 def __reduce_ex__(self, proto):
188 # support pickling of MutationDicts
189 d = dict(self)
190 return (self.__class__, (d,))
191
192
193 class MutationList(MutationObj, list):
194 @classmethod
195 def coerce(cls, key, value):
196 """Convert plain list to MutationList"""
197 self = MutationList((MutationObj.coerce(key, v) for v in value))
198 self._key = key
199 return self
200
201 def de_coerce(self):
202 return list(self)
203
204 def __setitem__(self, idx, value):
205 list.__setitem__(self, idx, MutationObj.coerce(self._key, value))
206 self.changed()
207
208 def __setslice__(self, start, stop, values):
209 list.__setslice__(self, start, stop,
210 (MutationObj.coerce(self._key, v) for v in values))
211 self.changed()
212
213 def __delitem__(self, idx):
214 list.__delitem__(self, idx)
215 self.changed()
216
217 def __delslice__(self, start, stop):
218 list.__delslice__(self, start, stop)
219 self.changed()
220
221 def append(self, value):
222 list.append(self, MutationObj.coerce(self._key, value))
223 self.changed()
224
225 def insert(self, idx, value):
226 list.insert(self, idx, MutationObj.coerce(self._key, value))
227 self.changed()
228
229 def extend(self, values):
230 list.extend(self, (MutationObj.coerce(self._key, v) for v in values))
231 self.changed()
232
233 def pop(self, *args, **kw):
234 value = list.pop(self, *args, **kw)
235 self.changed()
236 return value
237
238 def remove(self, value):
239 list.remove(self, value)
240 self.changed()
241
242
243 156 def JsonType(impl=None, **kwargs):
244 157 """
 245 158     Helper for using a mutation obj; it allows easy use of .with_variant.
 246 159     For example::
247 160
248 161 settings = Column('settings_json',
249 162 MutationObj.as_mutable(
250 163 JsonType(dialect_map=dict(mysql=UnicodeText(16384))))
251 164 """
252 165
253 166 if impl == 'list':
254 167 return JSONEncodedObj(default=list, **kwargs)
255 168 elif impl == 'dict':
256 return JSONEncodedObj(default=DictClass, **kwargs)
169 return JSONEncodedObj(default=dict, **kwargs)
257 170 else:
258 171 return JSONEncodedObj(**kwargs)
259 172
260 173
261 JSON = MutationObj.as_mutable(JsonType())
262 """
263 A type to encode/decode JSON on the fly
264
265 sqltype is the string type for the underlying DB column::
266
267 Column(JSON) (defaults to UnicodeText)
268 """
269
270 JSONDict = MutationObj.as_mutable(JsonType('dict'))
271 """
272 A type to encode/decode JSON dictionaries on the fly
273 """
274
275 JSONList = MutationObj.as_mutable(JsonType('list'))
276 """
 277 A type to encode/decode JSON lists on the fly
278 """