@@ -1,610 +1,606 b'' | |||
|
1 | 1 | |
|
2 | 2 | # Copyright (C) 2010-2020 RhodeCode GmbH |
|
3 | 3 | # |
|
4 | 4 | # This program is free software: you can redistribute it and/or modify |
|
5 | 5 | # it under the terms of the GNU Affero General Public License, version 3 |
|
6 | 6 | # (only), as published by the Free Software Foundation. |
|
7 | 7 | # |
|
8 | 8 | # This program is distributed in the hope that it will be useful, |
|
9 | 9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of |
|
10 | 10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the |
|
11 | 11 | # GNU General Public License for more details. |
|
12 | 12 | # |
|
13 | 13 | # You should have received a copy of the GNU Affero General Public License |
|
14 | 14 | # along with this program. If not, see <http://www.gnu.org/licenses/>. |
|
15 | 15 | # |
|
16 | 16 | # This program is dual-licensed. If you wish to learn more about the |
|
17 | 17 | # RhodeCode Enterprise Edition, including its added features, Support services, |
|
18 | 18 | # and proprietary license terms, please see https://rhodecode.com/licenses/ |
|
19 | 19 | |
|
20 | 20 | """ |
|
21 | 21 | The base Controller API |
|
22 | 22 | Provides the BaseController class for subclassing, and its usage in |

23 | 23 | different controllers. |
|
24 | 24 | """ |
|
25 | 25 | |
|
26 | 26 | import logging |
|
27 | 27 | import socket |
|
28 | import base64 | |
|
28 | 29 | |
|
29 | 30 | import markupsafe |
|
30 | 31 | import ipaddress |
|
31 | 32 | |
|
33 | import paste.httpheaders | |
|
32 | 34 | from paste.auth.basic import AuthBasicAuthenticator |
|
33 | 35 | from paste.httpexceptions import HTTPUnauthorized, HTTPForbidden, get_exception |
|
34 | from paste.httpheaders import WWW_AUTHENTICATE, AUTHORIZATION | |
|
35 | 36 | |
|
36 | 37 | import rhodecode |
|
37 | 38 | from rhodecode.authentication.base import VCS_TYPE |
|
38 | 39 | from rhodecode.lib import auth, utils2 |
|
39 | 40 | from rhodecode.lib import helpers as h |
|
40 | 41 | from rhodecode.lib.auth import AuthUser, CookieStoreWrapper |
|
41 | 42 | from rhodecode.lib.exceptions import UserCreationError |
|
42 | 43 | from rhodecode.lib.utils import (password_changed, get_enabled_hook_classes) |
|
43 | from rhodecode.lib.utils2 import ( | |
|
|
44 | str2bool, safe_unicode, AttributeDict, safe_int, sha1, aslist, safe_str) | |
|
44 | from rhodecode.lib.utils2 import AttributeDict | |
|
45 | from rhodecode.lib.str_utils import ascii_bytes, safe_int, safe_str | |
|
46 | from rhodecode.lib.type_utils import aslist, str2bool | |
|
47 | from rhodecode.lib.hash_utils import sha1 | |
|
45 | 48 | from rhodecode.model.db import Repository, User, ChangesetComment, UserBookmark |
|
46 | 49 | from rhodecode.model.notification import NotificationModel |
|
47 | 50 | from rhodecode.model.settings import VcsSettingsModel, SettingsModel |
|
48 | 51 | |
|
49 | 52 | log = logging.getLogger(__name__) |
|
50 | 53 | |
|
51 | 54 | |
|
52 | 55 | def _filter_proxy(ip): |
|
53 | 56 | """ |
|
54 | 57 | IP addresses passed in headers can arrive as a comma-separated list of |

55 | 58 | IPs. Those comma-separated IPs are appended by the various proxies in the |

56 | 59 | chain of request processing, the left-most entry being the original client. |

57 | 60 | We only care about that first IP, the one from the original client. |
|
58 | 61 | |
|
59 | 62 | :param ip: ip string from headers |
|
60 | 63 | """ |
|
61 | 64 | if ',' in ip: |
|
62 | 65 | _ips = ip.split(',') |
|
63 | 66 | _first_ip = _ips[0].strip() |
|
64 | 67 | log.debug('Got multiple IPs %s, using %s', ','.join(_ips), _first_ip) |
|
65 | 68 | return _first_ip |
|
66 | 69 | return ip |
|
67 | 70 | |
|
68 | 71 | |
|
69 | 72 | def _filter_port(ip): |
|
70 | 73 | """ |
|
71 | 74 | Removes a port from an IP; there are 4 main cases to handle here: |
|
72 | 75 | - ipv4 eg. 127.0.0.1 |
|
73 | 76 | - ipv6 eg. ::1 |
|
74 | 77 | - ipv4+port eg. 127.0.0.1:8080 |
|
75 | 78 | - ipv6+port eg. [::1]:8080 |
|
76 | 79 | |
|
77 | 80 | :param ip: |
|
78 | 81 | """ |
|
79 | 82 | def is_ipv6(ip_addr): |
|
80 | 83 | if hasattr(socket, 'inet_pton'): |
|
81 | 84 | try: |
|
82 | 85 | socket.inet_pton(socket.AF_INET6, ip_addr) |
|
83 | 86 | except socket.error: |
|
84 | 87 | return False |
|
85 | 88 | else: |
|
86 | 89 | # fallback to ipaddress |
|
87 | 90 | try: |
|
88 | 91 | ipaddress.IPv6Address(safe_str(ip_addr)) |
|
89 | 92 | except Exception: |
|
90 | 93 | return False |
|
91 | 94 | return True |
|
92 | 95 | |
|
93 | 96 | if ':' not in ip: # must be ipv4 pure ip |
|
94 | 97 | return ip |
|
95 | 98 | |
|
96 | 99 | if '[' in ip and ']' in ip: # ipv6 with port |
|
97 | 100 | return ip.split(']')[0][1:].lower() |
|
98 | 101 | |
|
99 | 102 | # must be ipv6 or ipv4 with port |
|
100 | 103 | if is_ipv6(ip): |
|
101 | 104 | return ip |
|
102 | 105 | else: |
|
103 | 106 | ip, _port = ip.split(':')[:2] # means ipv4+port |
|
104 | 107 | return ip |
|
105 | 108 | |
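As a quick illustration of the two filters above (hypothetical inputs, not part of the diff):

    # _filter_port strips the port in all four documented cases
    assert _filter_port('127.0.0.1') == '127.0.0.1'       # ipv4
    assert _filter_port('::1') == '::1'                   # ipv6
    assert _filter_port('127.0.0.1:8080') == '127.0.0.1'  # ipv4+port
    assert _filter_port('[::1]:8080') == '::1'            # ipv6+port
    # _filter_proxy keeps only the left-most (original client) address
    assert _filter_proxy('203.0.113.7, 10.0.0.1') == '203.0.113.7'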
|
106 | 109 | |
|
107 | 110 | def get_ip_addr(environ): |
|
108 | 111 | proxy_key = 'HTTP_X_REAL_IP' |
|
109 | 112 | proxy_key2 = 'HTTP_X_FORWARDED_FOR' |
|
110 | 113 | def_key = 'REMOTE_ADDR' |
|
111 | _filters = lambda x: _filter_port(_filter_proxy(x)) | |
|
114 | ||
|
115 | def ip_filters(ip_): | |
|
116 | return _filter_port(_filter_proxy(ip_)) | |
|
112 | 117 | |
|
113 | 118 | ip = environ.get(proxy_key) |
|
114 | 119 | if ip: |
|
115 | return _filters(ip) | |
|
120 | return ip_filters(ip) | |
|
116 | 121 | |
|
117 | 122 | ip = environ.get(proxy_key2) |
|
118 | 123 | if ip: |
|
119 | return _filters(ip) | |
|
124 | return ip_filters(ip) | |
|
120 | 125 | |
|
121 | 126 | ip = environ.get(def_key, '0.0.0.0') |
|
122 | return _filters(ip) | |
|
127 | return ip_filters(ip) | |
|
123 | 128 | |
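A sketch of the resulting header precedence (the environ values are made up): X-Real-IP wins over X-Forwarded-For, which wins over REMOTE_ADDR.

    environ = {
        'HTTP_X_FORWARDED_FOR': '198.51.100.2, 10.0.0.1',
        'REMOTE_ADDR': '10.0.0.1',
    }
    # no X-Real-IP set, so the forwarded-for chain is used and filtered
    assert get_ip_addr(environ) == '198.51.100.2'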
|
124 | 129 | |
|
125 | 130 | def get_server_ip_addr(environ, log_errors=True): |
|
126 | 131 | hostname = environ.get('SERVER_NAME') |
|
127 | 132 | try: |
|
128 | 133 | return socket.gethostbyname(hostname) |
|
129 | 134 | except Exception as e: |
|
130 | 135 | if log_errors: |
|
131 | 136 | # in some cases this lookup is not possible, and we don't want to |

132 | 137 | # turn it into an error; log it and fall back to the hostname |
|
133 | 138 | log.exception('Could not retrieve server ip address: %s', e) |
|
134 | 139 | return hostname |
|
135 | 140 | |
|
136 | 141 | |
|
137 | 142 | def get_server_port(environ): |
|
138 | 143 | return environ.get('SERVER_PORT') |
|
139 | 144 | |
|
140 | 145 | |
|
141 | def get_access_path(environ): | |
|
142 | path = environ.get('PATH_INFO') | |
|
143 | org_req = environ.get('pylons.original_request') | |
|
144 | if org_req: | |
|
145 | path = org_req.environ.get('PATH_INFO') | |
|
146 | return path | |
|
147 | ||
|
148 | 146 | |
|
149 | 147 | def get_user_agent(environ): |
|
150 | 148 | return environ.get('HTTP_USER_AGENT') |
|
151 | 149 | |
|
152 | 150 | |
|
153 | 151 | def vcs_operation_context( |
|
154 | 152 | environ, repo_name, username, action, scm, check_locking=True, |
|
155 | 153 | is_shadow_repo=False, check_branch_perms=False, detect_force_push=False): |
|
156 | 154 | """ |
|
157 | 155 | Generate the context for a vcs operation, e.g. push or pull. |
|
158 | 156 | |
|
159 | 157 | This context is passed over the layers so that hooks triggered by the |
|
160 | 158 | vcs operation know details like the user, the user's IP address etc. |
|
161 | 159 | |
|
162 | 160 | :param check_locking: Allows switching off the computation of the locking |
|
163 | 161 | data. This serves mainly the need of the simplevcs middleware to be |
|
164 | 162 | able to disable this for certain operations. |
|
165 | 163 | |
|
166 | 164 | """ |
|
167 | 165 | # Tri-state value: False: unlock, None: nothing, True: lock |
|
168 | 166 | make_lock = None |
|
169 | 167 | locked_by = [None, None, None] |
|
170 | 168 | is_anonymous = username == User.DEFAULT_USER |
|
171 | 169 | user = User.get_by_username(username) |
|
172 | 170 | if not is_anonymous and check_locking: |
|
173 | 171 | log.debug('Checking locking on repository "%s"', repo_name) |
|
174 | 172 | repo = Repository.get_by_repo_name(repo_name) |
|
175 | 173 | make_lock, __, locked_by = repo.get_locking_state( |
|
176 | 174 | action, user.user_id) |
|
177 | 175 | user_id = user.user_id |
|
178 | 176 | settings_model = VcsSettingsModel(repo=repo_name) |
|
179 | 177 | ui_settings = settings_model.get_ui_settings() |
|
180 | 178 | |
|
181 | 179 | # NOTE(marcink): This should be also in sync with |
|
182 | 180 | # rhodecode/apps/ssh_support/lib/backends/base.py:update_environment scm_data |
|
183 | 181 | store = [x for x in ui_settings if x.key == '/'] |
|
184 | 182 | repo_store = '' |
|
185 | 183 | if store: |
|
186 | 184 | repo_store = store[0].value |
|
187 | 185 | |
|
188 | 186 | scm_data = { |
|
189 | 187 | 'ip': get_ip_addr(environ), |
|
190 | 188 | 'username': username, |
|
191 | 189 | 'user_id': user_id, |
|
192 | 190 | 'action': action, |
|
193 | 191 | 'repository': repo_name, |
|
194 | 192 | 'scm': scm, |
|
195 | 193 | 'config': rhodecode.CONFIG['__file__'], |
|
196 | 194 | 'repo_store': repo_store, |
|
197 | 195 | 'make_lock': make_lock, |
|
198 | 196 | 'locked_by': locked_by, |
|
199 | 197 | 'server_url': utils2.get_server_url(environ), |
|
200 | 198 | 'user_agent': get_user_agent(environ), |
|
201 | 199 | 'hooks': get_enabled_hook_classes(ui_settings), |
|
202 | 200 | 'is_shadow_repo': is_shadow_repo, |
|
203 | 201 | 'detect_force_push': detect_force_push, |
|
204 | 202 | 'check_branch_perms': check_branch_perms, |
|
205 | 203 | } |
|
206 | 204 | return scm_data |
|
207 | 205 | |
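A rough usage sketch, assuming a configured database session, an existing repository, and illustrative names:

    scm_data = vcs_operation_context(
        environ, repo_name='my-repo', username='admin', action='push',
        scm='git', check_locking=False)
    # hooks triggered later read e.g. scm_data['make_lock'] / scm_data['locked_by']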
|
208 | 206 | |
|
209 | 207 | class BasicAuth(AuthBasicAuthenticator): |
|
210 | 208 | |
|
211 | 209 | def __init__(self, realm, authfunc, registry, auth_http_code=None, |
|
212 | 210 | initial_call_detection=False, acl_repo_name=None, rc_realm=''): |
|
211 | super(BasicAuth, self).__init__(realm=realm, authfunc=authfunc) | |
|
213 | 212 | self.realm = realm |
|
214 | 213 | self.rc_realm = rc_realm |
|
215 | 214 | self.initial_call = initial_call_detection |
|
216 | 215 | self.authfunc = authfunc |
|
217 | 216 | self.registry = registry |
|
218 | 217 | self.acl_repo_name = acl_repo_name |
|
219 | 218 | self._rc_auth_http_code = auth_http_code |
|
220 | 219 | |
|
221 | def _get_response_from_code(self, http_code): | |
|
220 | def _get_response_from_code(self, http_code, fallback): | |
|
222 | 221 | try: |
|
223 | 222 | return get_exception(safe_int(http_code)) |
|
224 | 223 | except Exception: |
|
225 | log.exception('Failed to fetch response for code %s', http_code) | |
|
226 | return HTTPForbidden | |
|
|
224 | log.exception('Failed to fetch response class for code %s, using fallback: %s', http_code, fallback) | |
|
225 | return fallback | |
|
227 | 226 | |
|
228 | 227 | def get_rc_realm(self): |
|
229 | 228 | return safe_str(self.rc_realm) |
|
230 | 229 | |
|
231 | 230 | def build_authentication(self): |
|
232 | head = WWW_AUTHENTICATE.tuples('Basic realm="%s"' % self.realm) | |
|
|
231 | header = [('WWW-Authenticate', f'Basic realm="{self.realm}"')] | |
|
232 | ||
|
233 | # NOTE: the initial_call detection seems to be not working/not needed with latest Mercurial | |
|
234 | # investigate if we still need it. | |
|
233 | 235 | if self._rc_auth_http_code and not self.initial_call: |
|
234 | 236 | # return alternative HTTP code if alternative http return code |
|
235 | 237 | # is specified in RhodeCode config, but ONLY if it's not the |
|
236 | 238 | # FIRST call |
|
237 | custom_response_klass = self._get_response_from_code( | |
|
238 | self._rc_auth_http_code) | |
|
239 | return custom_response_klass(headers=head) | |
|
240 | return HTTPUnauthorized(headers=head) | |
|
239 | custom_response_klass = self._get_response_from_code(self._rc_auth_http_code, fallback=HTTPUnauthorized) | |
|
240 | log.debug('Using custom response class: %s', custom_response_klass) | |
|
241 | return custom_response_klass(headers=header) | |
|
242 | return HTTPUnauthorized(headers=header) | |
|
241 | 243 | |
|
242 | 244 | def authenticate(self, environ): |
|
243 | authorization = AUTHORIZATION(environ) | |
|
245 | authorization = paste.httpheaders.AUTHORIZATION(environ) | |
|
244 | 246 | if not authorization: |
|
245 | 247 | return self.build_authentication() |
|
246 | (authmeth, auth) = authorization.split(' ', 1) | |
|
247 | if 'basic' != authmeth.lower(): | |
|
248 | (auth_meth, auth_creds_b64) = authorization.split(' ', 1) | |
|
249 | if 'basic' != auth_meth.lower(): | |
|
248 | 250 | return self.build_authentication() |
|
249 | auth = auth.strip().decode('base64') | |
|
250 | _parts = auth.split(':', 1) | |
|
251 | ||
|
252 | credentials = safe_str(base64.b64decode(auth_creds_b64.strip())) | |
|
253 | _parts = credentials.split(':', 1) | |
|
251 | 254 | if len(_parts) == 2: |
|
252 | 255 | username, password = _parts |
|
253 | 256 | auth_data = self.authfunc( |
|
254 | 257 | username, password, environ, VCS_TYPE, |
|
255 | 258 | registry=self.registry, acl_repo_name=self.acl_repo_name) |
|
256 | 259 | if auth_data: |
|
257 | 260 | return {'username': username, 'auth_data': auth_data} |
|
258 | 261 | if username and password: |
|
259 | 262 | # we mark that we actually executed authentication once, at |
|
260 | 263 | # that point we can use the alternative auth code |
|
261 | 264 | self.initial_call = False |
|
262 | 265 | |
|
263 | 266 | return self.build_authentication() |
|
264 | 267 | |
|
265 | 268 | __call__ = authenticate |
|
266 | 269 | |
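The heart of the Python 3 change in authenticate() is the credential decoding: str.decode('base64') is gone, so the base64 module is used instead. A self-contained sketch of the round trip (made-up credentials):

    import base64

    header = 'Basic ' + base64.b64encode(b'user:secret').decode('ascii')
    auth_meth, auth_creds_b64 = header.split(' ', 1)
    credentials = base64.b64decode(auth_creds_b64.strip()).decode('utf-8')
    assert credentials.split(':', 1) == ['user', 'secret']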
|
267 | 270 | |
|
268 | 271 | def calculate_version_hash(config): |
|
269 | 272 | return sha1( |
|
270 | config.get('beaker.session.secret', '') + | |
|
271 | rhodecode.__version__)[:8] | |
|
273 | config.get(b'beaker.session.secret', b'') + ascii_bytes(rhodecode.__version__) | |
|
274 | )[:8] | |
|
272 | 275 | |
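Under Python 3 the hash input has to stay in bytes, hence the b'' key/default and ascii_bytes() above. A minimal stdlib equivalent of the intent (the secret and version are made up):

    import hashlib

    secret = b'beaker-session-secret'
    version = '4.27.0'  # stand-in for rhodecode.__version__
    version_hash = hashlib.sha1(secret + version.encode('ascii')).hexdigest()[:8]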
|
273 | 276 | |
|
274 | 277 | def get_current_lang(request): |
|
275 | # NOTE(marcink): remove after pyramid move | |
|
276 | try: | |
|
277 | return translation.get_lang()[0] | |
|
278 | except: | |
|
279 | pass | |
|
280 | ||
|
281 | 278 | return getattr(request, '_LOCALE_', request.locale_name) |
|
282 | 279 | |
|
283 | 280 | |
|
284 | 281 | def attach_context_attributes(context, request, user_id=None, is_api=None): |
|
285 | 282 | """ |
|
286 | 283 | Attach variables into template context called `c`. |
|
287 | 284 | """ |
|
288 | 285 | config = request.registry.settings |
|
289 | 286 | |
|
290 | 287 | rc_config = SettingsModel().get_all_settings(cache=True, from_request=False) |
|
291 | 288 | context.rc_config = rc_config |
|
292 | 289 | context.rhodecode_version = rhodecode.__version__ |
|
293 | 290 | context.rhodecode_edition = config.get('rhodecode.edition') |
|
294 | 291 | context.rhodecode_edition_id = config.get('rhodecode.edition_id') |
|
295 | 292 | # unique secret + version does not leak the version but keep consistency |
|
296 | 293 | context.rhodecode_version_hash = calculate_version_hash(config) |
|
297 | 294 | |
|
298 | 295 | # Default language set for the incoming request |
|
299 | 296 | context.language = get_current_lang(request) |
|
300 | 297 | |
|
301 | 298 | # Visual options |
|
302 | 299 | context.visual = AttributeDict({}) |
|
303 | 300 | |
|
304 | 301 | # DB stored Visual Items |
|
305 | 302 | context.visual.show_public_icon = str2bool( |
|
306 | 303 | rc_config.get('rhodecode_show_public_icon')) |
|
307 | 304 | context.visual.show_private_icon = str2bool( |
|
308 | 305 | rc_config.get('rhodecode_show_private_icon')) |
|
309 | 306 | context.visual.stylify_metatags = str2bool( |
|
310 | 307 | rc_config.get('rhodecode_stylify_metatags')) |
|
311 | 308 | context.visual.dashboard_items = safe_int( |
|
312 | 309 | rc_config.get('rhodecode_dashboard_items', 100)) |
|
313 | 310 | context.visual.admin_grid_items = safe_int( |
|
314 | 311 | rc_config.get('rhodecode_admin_grid_items', 100)) |
|
315 | 312 | context.visual.show_revision_number = str2bool( |
|
316 | 313 | rc_config.get('rhodecode_show_revision_number', True)) |
|
317 | 314 | context.visual.show_sha_length = safe_int( |
|
318 | 315 | rc_config.get('rhodecode_show_sha_length', 100)) |
|
319 | 316 | context.visual.repository_fields = str2bool( |
|
320 | 317 | rc_config.get('rhodecode_repository_fields')) |
|
321 | 318 | context.visual.show_version = str2bool( |
|
322 | 319 | rc_config.get('rhodecode_show_version')) |
|
323 | 320 | context.visual.use_gravatar = str2bool( |
|
324 | 321 | rc_config.get('rhodecode_use_gravatar')) |
|
325 | 322 | context.visual.gravatar_url = rc_config.get('rhodecode_gravatar_url') |
|
326 | 323 | context.visual.default_renderer = rc_config.get( |
|
327 | 324 | 'rhodecode_markup_renderer', 'rst') |
|
328 | 325 | context.visual.comment_types = ChangesetComment.COMMENT_TYPES |
|
329 | 326 | context.visual.rhodecode_support_url = \ |
|
330 | 327 | rc_config.get('rhodecode_support_url') or h.route_url('rhodecode_support') |
|
331 | 328 | |
|
332 | 329 | context.visual.affected_files_cut_off = 60 |
|
333 | 330 | |
|
334 | 331 | context.pre_code = rc_config.get('rhodecode_pre_code') |
|
335 | 332 | context.post_code = rc_config.get('rhodecode_post_code') |
|
336 | 333 | context.rhodecode_name = rc_config.get('rhodecode_title') |
|
337 | 334 | context.default_encodings = aslist(config.get('default_encoding'), sep=',') |
|
338 | 335 | # if we have specified default_encoding in the request, it has more |
|
339 | 336 | # priority |
|
340 | 337 | if request.GET.get('default_encoding'): |
|
341 | 338 | context.default_encodings.insert(0, request.GET.get('default_encoding')) |
|
342 | 339 | context.clone_uri_tmpl = rc_config.get('rhodecode_clone_uri_tmpl') |
|
343 | 340 | context.clone_uri_id_tmpl = rc_config.get('rhodecode_clone_uri_id_tmpl') |
|
344 | 341 | context.clone_uri_ssh_tmpl = rc_config.get('rhodecode_clone_uri_ssh_tmpl') |
|
345 | 342 | |
|
346 | 343 | # INI stored |
|
347 | 344 | context.labs_active = str2bool( |
|
348 | 345 | config.get('labs_settings_active', 'false')) |
|
349 | 346 | context.ssh_enabled = str2bool( |
|
350 | 347 | config.get('ssh.generate_authorized_keyfile', 'false')) |
|
351 | 348 | context.ssh_key_generator_enabled = str2bool( |
|
352 | 349 | config.get('ssh.enable_ui_key_generator', 'true')) |
|
353 | 350 | |
|
354 | 351 | context.visual.allow_repo_location_change = str2bool( |
|
355 | 352 | config.get('allow_repo_location_change', True)) |
|
356 | 353 | context.visual.allow_custom_hooks_settings = str2bool( |
|
357 | 354 | config.get('allow_custom_hooks_settings', True)) |
|
358 | 355 | context.debug_style = str2bool(config.get('debug_style', False)) |
|
359 | 356 | |
|
360 | 357 | context.rhodecode_instanceid = config.get('instance_id') |
|
361 | 358 | |
|
362 | 359 | context.visual.cut_off_limit_diff = safe_int( |
|
363 | config.get('cut_off_limit_diff')) | |
|
360 | config.get('cut_off_limit_diff'), default=0) | |
|
364 | 361 | context.visual.cut_off_limit_file = safe_int( |
|
365 | config.get('cut_off_limit_file')) | |
|
362 | config.get('cut_off_limit_file'), default=0) | |
|
366 | 363 | |
|
367 | 364 | context.license = AttributeDict({}) |
|
368 | 365 | context.license.hide_license_info = str2bool( |
|
369 | 366 | config.get('license.hide_license_info', False)) |
|
370 | 367 | |
|
371 | 368 | # AppEnlight |
|
372 | 369 | context.appenlight_enabled = config.get('appenlight', False) |
|
373 | 370 | context.appenlight_api_public_key = config.get( |
|
374 | 371 | 'appenlight.api_public_key', '') |
|
375 | 372 | context.appenlight_server_url = config.get('appenlight.server_url', '') |
|
376 | 373 | |
|
377 | 374 | diffmode = { |
|
378 | 375 | "unified": "unified", |
|
379 | 376 | "sideside": "sideside" |
|
380 | 377 | }.get(request.GET.get('diffmode')) |
|
381 | 378 | |
|
382 | 379 | if is_api is not None: |
|
383 | 380 | is_api = hasattr(request, 'rpc_user') |
|
384 | 381 | session_attrs = { |
|
385 | 382 | # defaults |
|
386 | 383 | "clone_url_format": "http", |
|
387 | 384 | "diffmode": "sideside", |
|
388 | 385 | "license_fingerprint": request.session.get('license_fingerprint') |
|
389 | 386 | } |
|
390 | 387 | |
|
391 | 388 | if not is_api: |
|
392 | 389 | # don't access pyramid session for API calls |
|
393 | 390 | if diffmode and diffmode != request.session.get('rc_user_session_attr.diffmode'): |
|
394 | 391 | request.session['rc_user_session_attr.diffmode'] = diffmode |
|
395 | 392 | |
|
396 | 393 | # session settings per user |
|
397 | 394 | |
|
398 | for k, v in request.session.items(): | |
|
395 | for k, v in list(request.session.items()): | |
|
399 | 396 | pref = 'rc_user_session_attr.' |
|
400 | 397 | if k and k.startswith(pref): |
|
401 | 398 | k = k[len(pref):] |
|
402 | 399 | session_attrs[k] = v |
|
403 | 400 | |
|
404 | 401 | context.user_session_attrs = session_attrs |
|
405 | 402 | |
|
406 | 403 | # JS template context |
|
407 | 404 | context.template_context = { |
|
408 | 405 | 'repo_name': None, |
|
409 | 406 | 'repo_type': None, |
|
410 | 407 | 'repo_landing_commit': None, |
|
411 | 408 | 'rhodecode_user': { |
|
412 | 409 | 'username': None, |
|
413 | 410 | 'email': None, |
|
414 | 411 | 'notification_status': False |
|
415 | 412 | }, |
|
416 | 413 | 'session_attrs': session_attrs, |
|
417 | 414 | 'visual': { |
|
418 | 415 | 'default_renderer': None |
|
419 | 416 | }, |
|
420 | 417 | 'commit_data': { |
|
421 | 418 | 'commit_id': None |
|
422 | 419 | }, |
|
423 | 420 | 'pull_request_data': {'pull_request_id': None}, |
|
424 | 421 | 'timeago': { |
|
425 | 422 | 'refresh_time': 120 * 1000, |
|
426 | 423 | 'cutoff_limit': 1000 * 60 * 60 * 24 * 7 |
|
427 | 424 | }, |
|
428 | 425 | 'pyramid_dispatch': { |
|
429 | 426 | |
|
430 | 427 | }, |
|
431 | 428 | 'extra': {'plugins': {}} |
|
432 | 429 | } |
|
433 | 430 | # END CONFIG VARS |
|
434 | 431 | if is_api: |
|
435 | 432 | csrf_token = None |
|
436 | 433 | else: |
|
437 | 434 | csrf_token = auth.get_csrf_token(session=request.session) |
|
438 | 435 | |
|
439 | 436 | context.csrf_token = csrf_token |
|
440 | context.backends = rhodecode.BACKENDS.keys() | |
|
437 | context.backends = list(rhodecode.BACKENDS.keys()) | |
|
441 | 438 | |
|
442 | 439 | unread_count = 0 |
|
443 | 440 | user_bookmark_list = [] |
|
444 | 441 | if user_id: |
|
445 | 442 | unread_count = NotificationModel().get_unread_cnt_for_user(user_id) |
|
446 | 443 | user_bookmark_list = UserBookmark.get_bookmarks_for_user(user_id) |
|
447 | 444 | context.unread_notifications = unread_count |
|
448 | 445 | context.bookmark_items = user_bookmark_list |
|
449 | 446 | |
|
450 | 447 | # web case |
|
451 | 448 | if hasattr(request, 'user'): |
|
452 | 449 | context.auth_user = request.user |
|
453 | 450 | context.rhodecode_user = request.user |
|
454 | 451 | |
|
455 | 452 | # api case |
|
456 | 453 | if hasattr(request, 'rpc_user'): |
|
457 | 454 | context.auth_user = request.rpc_user |
|
458 | 455 | context.rhodecode_user = request.rpc_user |
|
459 | 456 | |
|
460 | 457 | # attach the whole call context to the request |
|
461 | 458 | request.set_call_context(context) |
|
462 | 459 | |
|
463 | 460 | |
|
464 | 461 | def get_auth_user(request): |
|
465 | 462 | environ = request.environ |
|
466 | 463 | session = request.session |
|
467 | 464 | |
|
468 | 465 | ip_addr = get_ip_addr(environ) |
|
469 | 466 | |
|
470 | 467 | # make sure that we update permissions each time we call controller |
|
471 | 468 | _auth_token = ( |
|
472 | 469 | # ?auth_token=XXX |
|
473 | 470 | request.GET.get('auth_token', '') |
|
474 | 471 | # ?api_key=XXX !LEGACY |
|
475 | 472 | or request.GET.get('api_key', '') |
|
476 | 473 | # or headers.... |
|
477 | 474 | or request.headers.get('X-Rc-Auth-Token', '') |
|
478 | 475 | ) |
|
479 | 476 | if not _auth_token and request.matchdict: |
|
480 | 477 | url_auth_token = request.matchdict.get('_auth_token') |
|
481 | 478 | _auth_token = url_auth_token |
|
482 | 479 | if _auth_token: |
|
483 | 480 | log.debug('Using URL extracted auth token `...%s`', _auth_token[-4:]) |
|
484 | 481 | |
|
485 | 482 | if _auth_token: |
|
486 | 483 | # when using API_KEY we assume user exists, and |
|
487 | 484 | # doesn't need auth based on cookies. |
|
488 | 485 | auth_user = AuthUser(api_key=_auth_token, ip_addr=ip_addr) |
|
489 | 486 | authenticated = False |
|
490 | 487 | else: |
|
491 | 488 | cookie_store = CookieStoreWrapper(session.get('rhodecode_user')) |
|
492 | 489 | try: |
|
493 | 490 | auth_user = AuthUser(user_id=cookie_store.get('user_id', None), |
|
494 | 491 | ip_addr=ip_addr) |
|
495 | 492 | except UserCreationError as e: |
|
496 | 493 | h.flash(e, 'error') |
|
497 | 494 | # container auth or other auth functions that create users |
|
498 | 495 | # on the fly can throw this exception signaling that there's |
|
499 | 496 | # issue with user creation, explanation should be provided |
|
500 | 497 | # in Exception itself. We then create a simple blank |
|
501 | 498 | # AuthUser |
|
502 | 499 | auth_user = AuthUser(ip_addr=ip_addr) |
|
503 | 500 | |
|
504 | 501 | # in case someone changes a password for user it triggers session |
|
505 | 502 | # flush and forces a re-login |
|
506 | 503 | if password_changed(auth_user, session): |
|
507 | 504 | session.invalidate() |
|
508 | 505 | cookie_store = CookieStoreWrapper(session.get('rhodecode_user')) |
|
509 | 506 | auth_user = AuthUser(ip_addr=ip_addr) |
|
510 | 507 | |
|
511 | 508 | authenticated = cookie_store.get('is_authenticated') |
|
512 | 509 | |
|
513 | 510 | if not auth_user.is_authenticated and auth_user.is_user_object: |
|
514 | 511 | # user is not authenticated and not empty |
|
515 | 512 | auth_user.set_authenticated(authenticated) |
|
516 | 513 | |
|
517 | 514 | return auth_user, _auth_token |
|
518 | 515 | |
|
519 | 516 | |
|
520 | 517 | def h_filter(s): |
|
521 | 518 | """ |
|
522 | 519 | Custom filter for Mako templates. Mako uses `markupsafe.escape` by default; |

523 | 520 | we wrap it with additional functionality that converts None to empty |

524 | 521 | strings |
|
525 | 522 | """ |
|
526 | 523 | if s is None: |
|
527 | 524 | return markupsafe.Markup() |
|
528 | 525 | return markupsafe.escape(s) |
|
529 | 526 | |
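Behavior sketch:

    assert h_filter(None) == ''            # None becomes an empty Markup
    assert h_filter('<b>') == '&lt;b&gt;'  # everything else is HTML-escaped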
|
530 | 527 | |
|
531 | 528 | def add_events_routes(config): |
|
532 | 529 | """ |
|
533 | 530 | Adds routing that can be used in events. Because some events are triggered |
|
534 | 531 | outside of the pyramid context, we need to bootstrap the request with |

535 | 532 | some routes registered |
|
536 | 533 | """ |
|
537 | 534 | |
|
538 | 535 | from rhodecode.apps._base import ADMIN_PREFIX |
|
539 | 536 | |
|
540 | 537 | config.add_route(name='home', pattern='/') |
|
541 | 538 | config.add_route(name='main_page_repos_data', pattern='/_home_repos') |
|
542 | 539 | config.add_route(name='main_page_repo_groups_data', pattern='/_home_repo_groups') |
|
543 | 540 | |
|
544 | 541 | config.add_route(name='login', pattern=ADMIN_PREFIX + '/login') |
|
545 | 542 | config.add_route(name='logout', pattern=ADMIN_PREFIX + '/logout') |
|
546 | 543 | config.add_route(name='repo_summary', pattern='/{repo_name}') |
|
547 | 544 | config.add_route(name='repo_summary_explicit', pattern='/{repo_name}/summary') |
|
548 | 545 | config.add_route(name='repo_group_home', pattern='/{repo_group_name}') |
|
549 | 546 | |
|
550 | 547 | config.add_route(name='pullrequest_show', |
|
551 | 548 | pattern='/{repo_name}/pull-request/{pull_request_id}') |
|
552 | 549 | config.add_route(name='pull_requests_global', |
|
553 | 550 | pattern='/pull-request/{pull_request_id}') |
|
554 | 551 | |
|
555 | 552 | config.add_route(name='repo_commit', |
|
556 | 553 | pattern='/{repo_name}/changeset/{commit_id}') |
|
557 | 554 | config.add_route(name='repo_files', |
|
558 | 555 | pattern='/{repo_name}/files/{commit_id}/{f_path}') |
|
559 | 556 | |
|
560 | 557 | config.add_route(name='hovercard_user', |
|
561 | 558 | pattern='/_hovercard/user/{user_id}') |
|
562 | 559 | |
|
563 | 560 | config.add_route(name='hovercard_user_group', |
|
564 | 561 | pattern='/_hovercard/user_group/{user_group_id}') |
|
565 | 562 | |
|
566 | 563 | config.add_route(name='hovercard_pull_request', |
|
567 | 564 | pattern='/_hovercard/pull_request/{pull_request_id}') |
|
568 | 565 | |
|
569 | 566 | config.add_route(name='hovercard_repo_commit', |
|
570 | 567 | pattern='/_hovercard/commit/{repo_name}/{commit_id}') |
|
571 | 568 | |
|
572 | 569 | |
|
573 | 570 | def bootstrap_config(request, registry_name='RcTestRegistry'): |
|
574 | 571 | import pyramid.testing |
|
575 | 572 | registry = pyramid.testing.Registry(registry_name) |
|
576 | 573 | |
|
577 | 574 | config = pyramid.testing.setUp(registry=registry, request=request) |
|
578 | 575 | |
|
579 | 576 | # allow pyramid lookup in testing |
|
580 | 577 | config.include('pyramid_mako') |
|
581 | 578 | config.include('rhodecode.lib.rc_beaker') |
|
582 | 579 | config.include('rhodecode.lib.rc_cache') |
|
583 | ||
|
580 | config.include('rhodecode.lib.rc_cache.archive_cache') | |
|
584 | 581 | add_events_routes(config) |
|
585 | 582 | |
|
586 | 583 | return config |
|
587 | 584 | |
|
588 | 585 | |
|
589 | 586 | def bootstrap_request(**kwargs): |
|
590 | 587 | """ |
|
591 | 588 | Returns a thin version of the Request object, used in non-web contexts like testing/celery |
|
592 | 589 | """ |
|
593 | 590 | |
|
594 | 591 | import pyramid.testing |
|
595 | 592 | from rhodecode.lib.request import ThinRequest as _ThinRequest |
|
596 | 593 | |
|
597 | 594 | class ThinRequest(_ThinRequest): |
|
598 | 595 | application_url = kwargs.pop('application_url', 'http://example.com') |
|
599 | 596 | host = kwargs.pop('host', 'example.com:80') |
|
600 | 597 | domain = kwargs.pop('domain', 'example.com') |
|
601 | 598 | |
|
602 | 599 | class ThinSession(pyramid.testing.DummySession): |
|
603 | 600 | def save(*arg, **kw): |
|
604 | 601 | pass |
|
605 | 602 | |
|
606 | 603 | request = ThinRequest(**kwargs) |
|
607 | 604 | request.session = ThinSession() |
|
608 | 605 | |
|
609 | 606 | return request |
|
610 |
@@ -1,248 +1,251 b'' | |||
|
1 | 1 | |
|
2 | 2 | # Copyright (C) 2010-2020 RhodeCode GmbH |
|
3 | 3 | # |
|
4 | 4 | # This program is free software: you can redistribute it and/or modify |
|
5 | 5 | # it under the terms of the GNU Affero General Public License, version 3 |
|
6 | 6 | # (only), as published by the Free Software Foundation. |
|
7 | 7 | # |
|
8 | 8 | # This program is distributed in the hope that it will be useful, |
|
9 | 9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of |
|
10 | 10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the |
|
11 | 11 | # GNU General Public License for more details. |
|
12 | 12 | # |
|
13 | 13 | # You should have received a copy of the GNU Affero General Public License |
|
14 | 14 | # along with this program. If not, see <http://www.gnu.org/licenses/>. |
|
15 | 15 | # |
|
16 | 16 | # This program is dual-licensed. If you wish to learn more about the |
|
17 | 17 | # RhodeCode Enterprise Edition, including its added features, Support services, |
|
18 | 18 | # and proprietary license terms, please see https://rhodecode.com/licenses/ |
|
19 | 19 | |
|
20 | 20 | """caching_query.py |
|
21 | 21 | |
|
22 | 22 | Represent functions and classes |
|
23 | 23 | which allow the usage of Dogpile caching with SQLAlchemy. |
|
24 | 24 | Introduces a query option called FromCache. |
|
25 | 25 | |
|
26 | 26 | .. versionchanged:: 1.4 the caching approach has been altered to work |
|
27 | 27 | based on a session event. |
|
28 | 28 | |
|
29 | 29 | |
|
30 | 30 | The three new concepts introduced here are: |
|
31 | 31 | |
|
32 | 32 | * ORMCache - an extension for an ORM :class:`.Session` |
|
33 | 33 | retrieves results in/from dogpile.cache. |
|
34 | 34 | * FromCache - a query option that establishes caching |
|
35 | 35 | parameters on a Query |
|
36 | 36 | * RelationshipCache - a variant of FromCache which is specific |
|
37 | 37 | to a query invoked during a lazy load. |
|
38 | 38 | |
|
39 | 39 | The rest of what's here are standard SQLAlchemy and |
|
40 | 40 | dogpile.cache constructs. |
|
41 | 41 | |
|
42 | 42 | """ |
|
43 | 43 | from dogpile.cache.api import NO_VALUE |
|
44 | 44 | |
|
45 | 45 | from sqlalchemy import event |
|
46 | 46 | from sqlalchemy.orm import loading |
|
47 | 47 | from sqlalchemy.orm.interfaces import UserDefinedOption |
|
48 | 48 | |
|
49 | 49 | |
|
50 | 50 | DEFAULT_REGION = "sql_cache_short" |
|
51 | 51 | |
|
52 | 52 | |
|
53 | 53 | class ORMCache: |
|
54 | 54 | |
|
55 | 55 | """An add-on for an ORM :class:`.Session` optionally loads full results |
|
56 | 56 | from a dogpile cache region. |
|
57 | 57 | |
|
58 | 58 | cache = ORMCache(regions={}) |
|
59 | 59 | cache.listen_on_session(Session) |
|
60 | 60 | |
|
61 | 61 | """ |
|
62 | 62 | |
|
63 | 63 | def __init__(self, regions): |
|
64 | 64 | self.cache_regions = regions or self._get_region() |
|
65 | 65 | self._statement_cache = {} |
|
66 | 66 | |
|
67 | 67 | @classmethod |
|
68 | 68 | def _get_region(cls): |
|
69 | 69 | from rhodecode.lib.rc_cache import region_meta |
|
70 | 70 | return region_meta.dogpile_cache_regions |
|
71 | 71 | |
|
72 | 72 | def listen_on_session(self, session_factory): |
|
73 | 73 | event.listen(session_factory, "do_orm_execute", self._do_orm_execute) |
|
74 | 74 | |
|
75 | 75 | def _do_orm_execute(self, orm_context): |
|
76 | ||
|
77 | 76 | for opt in orm_context.user_defined_options: |
|
78 | 77 | if isinstance(opt, RelationshipCache): |
|
79 | 78 | opt = opt._process_orm_context(orm_context) |
|
80 | 79 | if opt is None: |
|
81 | 80 | continue |
|
82 | 81 | |
|
83 | 82 | if isinstance(opt, FromCache): |
|
84 | 83 | dogpile_region = self.cache_regions[opt.region] |
|
85 | 84 | |
|
85 | if dogpile_region.expiration_time <= 0: | |
|
86 | # skip caching for regions configured with zero (or negative) expiration time | |
|
87 | continue | |
|
88 | ||
|
86 | 89 | if opt.cache_key: |
|
87 | 90 | our_cache_key = f'SQL_CACHE_{opt.cache_key}' |
|
88 | 91 | else: |
|
89 | 92 | our_cache_key = opt._generate_cache_key( |
|
90 | 93 | orm_context.statement, orm_context.parameters, self |
|
91 | 94 | ) |
|
92 | 95 | |
|
93 | 96 | if opt.ignore_expiration: |
|
94 | 97 | cached_value = dogpile_region.get( |
|
95 | 98 | our_cache_key, |
|
96 | 99 | expiration_time=opt.expiration_time, |
|
97 | 100 | ignore_expiration=opt.ignore_expiration, |
|
98 | 101 | ) |
|
99 | 102 | else: |
|
100 | 103 | |
|
101 | 104 | def createfunc(): |
|
102 | 105 | return orm_context.invoke_statement().freeze() |
|
103 | 106 | |
|
104 | 107 | cached_value = dogpile_region.get_or_create( |
|
105 | 108 | our_cache_key, |
|
106 | 109 | createfunc, |
|
107 | 110 | expiration_time=opt.expiration_time, |
|
108 | 111 | ) |
|
109 | 112 | |
|
110 | 113 | if cached_value is NO_VALUE: |
|
111 | 114 | # keyerror? this is bigger than a keyerror... |
|
112 | 115 | raise KeyError() |
|
113 | 116 | |
|
114 | 117 | orm_result = loading.merge_frozen_result( |
|
115 | 118 | orm_context.session, |
|
116 | 119 | orm_context.statement, |
|
117 | 120 | cached_value, |
|
118 | 121 | load=False, |
|
119 | 122 | ) |
|
120 | 123 | return orm_result() |
|
121 | 124 | |
|
122 | 125 | else: |
|
123 | 126 | return None |
|
124 | 127 | |
|
125 | 128 | def invalidate(self, statement, parameters, opt): |
|
126 | 129 | """Invalidate the cache value represented by a statement.""" |
|
127 | 130 | |
|
128 | 131 | statement = statement.__clause_element__() |
|
129 | 132 | |
|
130 | 133 | dogpile_region = self.cache_regions[opt.region] |
|
131 | 134 | |
|
132 | 135 | cache_key = opt._generate_cache_key(statement, parameters, self) |
|
133 | 136 | |
|
134 | 137 | dogpile_region.delete(cache_key) |
|
135 | 138 | |
|
136 | 139 | |
|
137 | 140 | class FromCache(UserDefinedOption): |
|
138 | 141 | """Specifies that a Query should load results from a cache.""" |
|
139 | 142 | |
|
140 | 143 | propagate_to_loaders = False |
|
141 | 144 | |
|
142 | 145 | def __init__( |
|
143 | 146 | self, |
|
144 | 147 | region=DEFAULT_REGION, |
|
145 | 148 | cache_key=None, |
|
146 | 149 | expiration_time=None, |
|
147 | 150 | ignore_expiration=False, |
|
148 | 151 | ): |
|
149 | 152 | """Construct a new FromCache. |
|
150 | 153 | |
|
151 | 154 | :param region: the cache region. Should be a |
|
152 | 155 | region configured in the dictionary of dogpile |
|
153 | 156 | regions. |
|
154 | 157 | |
|
155 | 158 | :param cache_key: optional. A string cache key |
|
156 | 159 | that will serve as the key to the query. Use this |
|
157 | 160 | if your query has a huge amount of parameters (such |
|
158 | 161 | as when using in_()) which correspond more simply to |
|
159 | 162 | some other identifier. |
|
160 | 163 | |
|
161 | 164 | """ |
|
162 | 165 | self.region = region |
|
163 | 166 | self.cache_key = cache_key |
|
164 | 167 | self.expiration_time = expiration_time |
|
165 | 168 | self.ignore_expiration = ignore_expiration |
|
166 | 169 | |
|
167 | 170 | # this is not needed as of SQLAlchemy 1.4.28; |
|
168 | 171 | # UserDefinedOption classes no longer participate in the SQL |
|
169 | 172 | # compilation cache key |
|
170 | 173 | def _gen_cache_key(self, anon_map, bindparams): |
|
171 | 174 | return None |
|
172 | 175 | |
|
173 | 176 | def _generate_cache_key(self, statement, parameters, orm_cache): |
|
174 | 177 | """generate a cache key with which to key the results of a statement. |
|
175 | 178 | |
|
176 | 179 | This leverages the use of the SQL compilation cache key which is |
|
177 | 180 | repurposed as a SQL results key. |
|
178 | 181 | |
|
179 | 182 | """ |
|
180 | 183 | statement_cache_key = statement._generate_cache_key() |
|
181 | 184 | |
|
182 | 185 | key = statement_cache_key.to_offline_string( |
|
183 | 186 | orm_cache._statement_cache, statement, parameters |
|
184 | 187 | ) + repr(self.cache_key) |
|
185 | 188 | # print("here's our key...%s" % key) |
|
186 | 189 | return key |
|
187 | 190 | |
|
188 | 191 | |
|
189 | 192 | class RelationshipCache(FromCache): |
|
190 | 193 | """Specifies that a Query as called within a "lazy load" |
|
191 | 194 | should load results from a cache.""" |
|
192 | 195 | |
|
193 | 196 | propagate_to_loaders = True |
|
194 | 197 | |
|
195 | 198 | def __init__( |
|
196 | 199 | self, |
|
197 | 200 | attribute, |
|
198 | 201 | region=DEFAULT_REGION, |
|
199 | 202 | cache_key=None, |
|
200 | 203 | expiration_time=None, |
|
201 | 204 | ignore_expiration=False, |
|
202 | 205 | ): |
|
203 | 206 | """Construct a new RelationshipCache. |
|
204 | 207 | |
|
205 | 208 | :param attribute: A Class.attribute which |
|
206 | 209 | indicates a particular class relationship() whose |
|
207 | 210 | lazy loader should be pulled from the cache. |
|
208 | 211 | |
|
209 | 212 | :param region: name of the cache region. |
|
210 | 213 | |
|
211 | 214 | :param cache_key: optional. A string cache key |
|
212 | 215 | that will serve as the key to the query, bypassing |
|
213 | 216 | the usual means of forming a key from the Query itself. |
|
214 | 217 | |
|
215 | 218 | """ |
|
216 | 219 | self.region = region |
|
217 | 220 | self.cache_key = cache_key |
|
218 | 221 | self.expiration_time = expiration_time |
|
219 | 222 | self.ignore_expiration = ignore_expiration |
|
220 | 223 | self._relationship_options = { |
|
221 | 224 | (attribute.property.parent.class_, attribute.property.key): self |
|
222 | 225 | } |
|
223 | 226 | |
|
224 | 227 | def _process_orm_context(self, orm_context): |
|
225 | 228 | current_path = orm_context.loader_strategy_path |
|
226 | 229 | |
|
227 | 230 | if current_path: |
|
228 | 231 | mapper, prop = current_path[-2:] |
|
229 | 232 | key = prop.key |
|
230 | 233 | |
|
231 | 234 | for cls in mapper.class_.__mro__: |
|
232 | 235 | if (cls, key) in self._relationship_options: |
|
233 | 236 | relationship_option = self._relationship_options[ |
|
234 | 237 | (cls, key) |
|
235 | 238 | ] |
|
236 | 239 | return relationship_option |
|
237 | 240 | |
|
238 | 241 | def and_(self, option): |
|
239 | 242 | """Chain another RelationshipCache option to this one. |
|
240 | 243 | |
|
241 | 244 | While many RelationshipCache objects can be specified on a single |
|
242 | 245 | Query separately, chaining them together allows for a more efficient |
|
243 | 246 | lookup during load. |
|
244 | 247 | |
|
245 | 248 | """ |
|
246 | 249 | self._relationship_options.update(option._relationship_options) |
|
247 | 250 | return self |
|
248 | 251 |
@@ -1,372 +1,372 b'' | |||
|
1 | 1 | |
|
2 | 2 | |
|
3 | 3 | # Copyright (C) 2016-2020 RhodeCode GmbH |
|
4 | 4 | # |
|
5 | 5 | # This program is free software: you can redistribute it and/or modify |
|
6 | 6 | # it under the terms of the GNU Affero General Public License, version 3 |
|
7 | 7 | # (only), as published by the Free Software Foundation. |
|
8 | 8 | # |
|
9 | 9 | # This program is distributed in the hope that it will be useful, |
|
10 | 10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of |
|
11 | 11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the |
|
12 | 12 | # GNU General Public License for more details. |
|
13 | 13 | # |
|
14 | 14 | # You should have received a copy of the GNU Affero General Public License |
|
15 | 15 | # along with this program. If not, see <http://www.gnu.org/licenses/>. |
|
16 | 16 | # |
|
17 | 17 | # This program is dual-licensed. If you wish to learn more about the |
|
18 | 18 | # RhodeCode Enterprise Edition, including its added features, Support services, |
|
19 | 19 | # and proprietary license terms, please see https://rhodecode.com/licenses/ |
|
20 | 20 | |
|
21 | 21 | import os |
|
22 | 22 | import itsdangerous |
|
23 | 23 | import logging |
|
24 | 24 | import requests |
|
25 | 25 | import datetime |
|
26 | 26 | |
|
27 | 27 | from dogpile.util.readwrite_lock import ReadWriteMutex |
|
28 | from pyramid.threadlocal import get_current_registry | |
|
29 | 28 | |
|
30 | 29 | import rhodecode.lib.helpers as h |
|
31 | 30 | from rhodecode.lib.auth import HasRepoPermissionAny |
|
32 | 31 | from rhodecode.lib.ext_json import json |
|
33 | 32 | from rhodecode.model.db import User |
|
34 | 33 | from rhodecode.lib.str_utils import ascii_str |
|
35 | 34 | from rhodecode.lib.hash_utils import sha1_safe |
|
36 | 35 | |
|
37 | 36 | log = logging.getLogger(__name__) |
|
38 | 37 | |
|
39 | 38 | LOCK = ReadWriteMutex() |
|
40 | 39 | |
|
41 | 40 | USER_STATE_PUBLIC_KEYS = [ |
|
42 | 41 | 'id', 'username', 'first_name', 'last_name', |
|
43 | 42 | 'icon_link', 'display_name', 'display_link'] |
|
44 | 43 | |
|
45 | 44 | |
|
46 | 45 | class ChannelstreamException(Exception): |
|
47 | 46 | pass |
|
48 | 47 | |
|
49 | 48 | |
|
50 | 49 | class ChannelstreamConnectionException(ChannelstreamException): |
|
51 | 50 | pass |
|
52 | 51 | |
|
53 | 52 | |
|
54 | 53 | class ChannelstreamPermissionException(ChannelstreamException): |
|
55 | 54 | pass |
|
56 | 55 | |
|
57 | 56 | |
|
58 | 57 | def get_channelstream_server_url(config, endpoint): |
|
59 | 58 | return 'http://{}{}'.format(config['server'], endpoint) |
|
60 | 59 | |
|
61 | 60 | |
|
62 | 61 | def channelstream_request(config, payload, endpoint, raise_exc=True): |
|
63 | 62 | signer = itsdangerous.TimestampSigner(config['secret']) |
|
64 | 63 | sig_for_server = signer.sign(endpoint) |
|
65 | 64 | secret_headers = {'x-channelstream-secret': sig_for_server, |
|
66 | 65 | 'x-channelstream-endpoint': endpoint, |
|
67 | 66 | 'Content-Type': 'application/json'} |
|
68 | 67 | req_url = get_channelstream_server_url(config, endpoint) |
|
69 | 68 | |
|
70 | 69 | log.debug('Sending a channelstream request to endpoint: `%s`', req_url) |
|
71 | 70 | response = None |
|
72 | 71 | try: |
|
73 | 72 | response = requests.post(req_url, data=json.dumps(payload), |
|
74 | 73 | headers=secret_headers).json() |
|
75 | 74 | except requests.ConnectionError: |
|
76 | 75 | log.exception('ConnectionError occurred for endpoint %s', req_url) |
|
77 | 76 | if raise_exc: |
|
78 | 77 | raise ChannelstreamConnectionException(req_url) |
|
79 | 78 | except Exception: |
|
80 | 79 | log.exception('Exception related to Channelstream happened') |
|
81 | 80 | if raise_exc: |
|
82 | 81 | raise ChannelstreamConnectionException() |
|
83 | 82 | log.debug('Got channelstream response: %s', response) |
|
84 | 83 | return response |
|
85 | 84 | |
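A minimal sketch of a signed call, assuming a running Channelstream server and a made-up shared secret:

    config = {'server': '127.0.0.1:8000', 'secret': 'channelstream-secret'}
    payload = [{'type': 'message', 'channel': 'broadcast',
                'message': {'message': 'hello'}}]
    response = channelstream_request(config, payload, '/message', raise_exc=False)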
|
86 | 85 | |
|
87 | 86 | def get_user_data(user_id): |
|
88 | 87 | user = User.get(user_id) |
|
89 | 88 | return { |
|
90 | 89 | 'id': user.user_id, |
|
91 | 90 | 'username': user.username, |
|
92 | 91 | 'first_name': user.first_name, |
|
93 | 92 | 'last_name': user.last_name, |
|
94 | 93 | 'icon_link': h.gravatar_url(user.email, 60), |
|
95 | 94 | 'display_name': h.person(user, 'username_or_name_or_email'), |
|
96 | 95 | 'display_link': h.link_to_user(user), |
|
97 | 96 | 'notifications': user.user_data.get('notification_status', True) |
|
98 | 97 | } |
|
99 | 98 | |
|
100 | 99 | |
|
101 | 100 | def broadcast_validator(channel_name): |
|
102 | 101 | """ checks if user can access the broadcast channel """ |
|
103 | 102 | if channel_name == 'broadcast': |
|
104 | 103 | return True |
|
105 | 104 | |
|
106 | 105 | |
|
107 | 106 | def repo_validator(channel_name): |
|
108 | 107 | """ checks if user can access the broadcast channel """ |
|
109 | 108 | channel_prefix = '/repo$' |
|
110 | 109 | if channel_name.startswith(channel_prefix): |
|
111 | 110 | elements = channel_name[len(channel_prefix):].split('$') |
|
112 | 111 | repo_name = elements[0] |
|
113 | 112 | can_access = HasRepoPermissionAny( |
|
114 | 113 | 'repository.read', |
|
115 | 114 | 'repository.write', |
|
116 | 115 | 'repository.admin')(repo_name) |
|
117 | 116 | log.debug( |
|
118 | 117 | 'permission check for %s channel resulted in %s', |
|
119 | 118 | repo_name, can_access) |
|
120 | 119 | if can_access: |
|
121 | 120 | return True |
|
122 | 121 | return False |
|
123 | 122 | |
|
124 | 123 | |
|
125 | 124 | def check_channel_permissions(channels, plugin_validators, should_raise=True): |
|
126 | 125 | valid_channels = [] |
|
127 | 126 | |
|
128 | 127 | validators = [broadcast_validator, repo_validator] |
|
129 | 128 | if plugin_validators: |
|
130 | 129 | validators.extend(plugin_validators) |
|
131 | 130 | for channel_name in channels: |
|
132 | 131 | is_valid = False |
|
133 | 132 | for validator in validators: |
|
134 | 133 | if validator(channel_name): |
|
135 | 134 | is_valid = True |
|
136 | 135 | break |
|
137 | 136 | if is_valid: |
|
138 | 137 | valid_channels.append(channel_name) |
|
139 | 138 | else: |
|
140 | 139 | if should_raise: |
|
141 | 140 | raise ChannelstreamPermissionException() |
|
142 | 141 | return valid_channels |
|
143 | 142 | |
|
144 | 143 | |
|
145 | 144 | def get_channels_info(self, channels): |
|
146 | 145 | payload = {'channels': channels} |
|
147 | 146 | # gather persistence info |
|
148 | 147 | return channelstream_request(self._config(), payload, '/info') |
|
149 | 148 | |
|
150 | 149 | |
|
151 | 150 | def parse_channels_info(info_result, include_channel_info=None): |
|
152 | 151 | """ |
|
153 | 152 | Returns data that contains only secure information that can be |
|
154 | 153 | presented to clients |
|
155 | 154 | """ |
|
156 | 155 | include_channel_info = include_channel_info or [] |
|
157 | 156 | |
|
158 | 157 | user_state_dict = {} |
|
159 | 158 | for userinfo in info_result['users']: |
|
160 | 159 | user_state_dict[userinfo['user']] = { |
|
161 | 160 | k: v for k, v in list(userinfo['state'].items()) |
|
162 | 161 | if k in USER_STATE_PUBLIC_KEYS |
|
163 | 162 | } |
|
164 | 163 | |
|
165 | 164 | channels_info = {} |
|
166 | 165 | |
|
167 | 166 | for c_name, c_info in list(info_result['channels'].items()): |
|
168 | 167 | if c_name not in include_channel_info: |
|
169 | 168 | continue |
|
170 | 169 | connected_list = [] |
|
171 | 170 | for username in c_info['users']: |
|
172 | 171 | connected_list.append({ |
|
173 | 172 | 'user': username, |
|
174 | 173 | 'state': user_state_dict[username] |
|
175 | 174 | }) |
|
176 | 175 | channels_info[c_name] = {'users': connected_list, |
|
177 | 176 | 'history': c_info['history']} |
|
178 | 177 | |
|
179 | 178 | return channels_info |
|
180 | 179 | |
|
181 | 180 | |
|
182 | 181 | def log_filepath(history_location, channel_name): |
|
183 | 182 | |
|
184 | 183 | channel_hash = ascii_str(sha1_safe(channel_name)) |
|
185 | 184 | filename = f'{channel_hash}.log' |
|
186 | 185 | filepath = os.path.join(history_location, filename) |
|
187 | 186 | return filepath |
|
188 | 187 | |
|
189 | 188 | |
|
190 | 189 | def read_history(history_location, channel_name): |
|
191 | 190 | filepath = log_filepath(history_location, channel_name) |
|
192 | 191 | if not os.path.exists(filepath): |
|
193 | 192 | return [] |
|
194 | 193 | history_lines_limit = -100 |
|
195 | 194 | history = [] |
|
196 | 195 | with open(filepath, 'rb') as f: |
|
197 | 196 | for line in f.readlines()[history_lines_limit:]: |
|
198 | 197 | try: |
|
199 | 198 | history.append(json.loads(line)) |
|
200 | 199 | except Exception: |
|
201 | 200 | log.exception('Failed to load history') |
|
202 | 201 | return history |
|
203 | 202 | |
|
204 | 203 | |
|
205 | 204 | def update_history_from_logs(config, channels, payload): |
|
206 | 205 | history_location = config.get('history.location') |
|
207 | 206 | for channel in channels: |
|
208 | 207 | history = read_history(history_location, channel) |
|
209 | 208 | payload['channels_info'][channel]['history'] = history |
|
210 | 209 | |
|
211 | 210 | |
|
212 | 211 | def write_history(config, message): |
|
213 | 212 | """ writes a message to a base64encoded filename """ |
|
214 | 213 | history_location = config.get('history.location') |
|
215 | 214 | if not os.path.exists(history_location): |
|
216 | 215 | return |
|
217 | 216 | try: |
|
218 | 217 | LOCK.acquire_write_lock() |
|
219 | 218 | filepath = log_filepath(history_location, message['channel']) |
|
220 | 219 | json_message = json.dumps(message) |
|
221 | 220 | with open(filepath, 'a') as f: |
|
222 | 221 | f.write(json_message) |
|
223 | 222 | f.write('\n') |
|
224 | 223 | finally: |
|
225 | 224 | LOCK.release_write_lock() |
|
226 | 225 | |
|
227 | 226 | |
|
228 | 227 | def get_connection_validators(registry): |
|
229 | 228 | validators = [] |
|
230 | 229 | for k, config in list(registry.rhodecode_plugins.items()): |
|
231 | 230 | validator = config.get('channelstream', {}).get('connect_validator') |
|
232 | 231 | if validator: |
|
233 | 232 | validators.append(validator) |
|
234 | 233 | return validators |
|
235 | 234 | |
|
236 | 235 | |
|
237 | 236 | def get_channelstream_config(registry=None): |
|
238 | 237 | if not registry: |
|
238 | from pyramid.threadlocal import get_current_registry | |
|
239 | 239 | registry = get_current_registry() |
|
240 | 240 | |
|
241 | 241 | rhodecode_plugins = getattr(registry, 'rhodecode_plugins', {}) |
|
242 | 242 | channelstream_config = rhodecode_plugins.get('channelstream', {}) |
|
243 | 243 | return channelstream_config |
|
244 | 244 | |
|
245 | 245 | |
|
246 | 246 | def post_message(channel, message, username, registry=None): |
|
247 | 247 | channelstream_config = get_channelstream_config(registry) |
|
248 | 248 | if not channelstream_config.get('enabled'): |
|
249 | 249 | return |
|
250 | 250 | |
|
251 | 251 | message_obj = message |
|
252 | 252 | if isinstance(message, str): |
|
253 | 253 | message_obj = { |
|
254 | 254 | 'message': message, |
|
255 | 255 | 'level': 'success', |
|
256 | 256 | 'topic': '/notifications' |
|
257 | 257 | } |
|
258 | 258 | |
|
259 | 259 | log.debug('Channelstream: sending notification to channel %s', channel) |
|
260 | 260 | payload = { |
|
261 | 261 | 'type': 'message', |
|
262 | 262 | 'timestamp': datetime.datetime.utcnow(), |
|
263 | 263 | 'user': 'system', |
|
264 | 264 | 'exclude_users': [username], |
|
265 | 265 | 'channel': channel, |
|
266 | 266 | 'message': message_obj |
|
267 | 267 | } |
|
268 | 268 | |
|
269 | 269 | try: |
|
270 | 270 | return channelstream_request( |
|
271 | 271 | channelstream_config, [payload], '/message', |
|
272 | 272 | raise_exc=False) |
|
273 | 273 | except ChannelstreamException: |
|
274 | 274 | log.exception('Failed to send channelstream data') |
|
275 | 275 | raise |
|
276 | 276 | |
|
277 | 277 | |
|
278 | 278 | def _reload_link(label): |
|
279 | 279 | return ( |
|
280 | 280 | '<a onclick="window.location.reload()">' |
|
281 | 281 | '<strong>{}</strong>' |
|
282 | 282 | '</a>'.format(label) |
|
283 | 283 | ) |
|
284 | 284 | |
|
285 | 285 | |
|
286 | 286 | def pr_channel(pull_request): |
|
287 | 287 | repo_name = pull_request.target_repo.repo_name |
|
288 | 288 | pull_request_id = pull_request.pull_request_id |
|
289 | 289 | channel = '/repo${}$/pr/{}'.format(repo_name, pull_request_id) |
|
290 | 290 | log.debug('Getting pull-request channelstream broadcast channel: %s', channel) |
|
291 | 291 | return channel |
|
292 | 292 | |
|
293 | 293 | |
|
294 | 294 | def comment_channel(repo_name, commit_obj=None, pull_request_obj=None): |
|
295 | 295 | channel = None |
|
296 | 296 | if commit_obj: |
|
297 | 297 | channel = '/repo${}$/commit/{}'.format( |
|
298 | 298 | repo_name, commit_obj.raw_id |
|
299 | 299 | ) |
|
300 | 300 | elif pull_request_obj: |
|
301 | 301 | channel = '/repo${}$/pr/{}'.format( |
|
302 | 302 | repo_name, pull_request_obj.pull_request_id |
|
303 | 303 | ) |
|
304 | 304 | log.debug('Getting comment channelstream broadcast channel: %s', channel) |
|
305 | 305 | |
|
306 | 306 | return channel |
|
307 | 307 | |
|
308 | 308 | |
|
309 | 309 | def pr_update_channelstream_push(request, pr_broadcast_channel, user, msg, **kwargs): |
|
310 | 310 | """ |
|
311 | 311 | Channel push on pull request update |
|
312 | 312 | """ |
|
313 | 313 | if not pr_broadcast_channel: |
|
314 | 314 | return |
|
315 | 315 | |
|
316 | 316 | _ = request.translate |
|
317 | 317 | |
|
318 | 318 | message = '{} {}'.format( |
|
319 | 319 | msg, |
|
320 | 320 | _reload_link(_(' Reload page to load changes'))) |
|
321 | 321 | |
|
322 | 322 | message_obj = { |
|
323 | 323 | 'message': message, |
|
324 | 324 | 'level': 'success', |
|
325 | 325 | 'topic': '/notifications' |
|
326 | 326 | } |
|
327 | 327 | |
|
328 | 328 | post_message( |
|
329 | 329 | pr_broadcast_channel, message_obj, user.username, |
|
330 | 330 | registry=request.registry) |
|
331 | 331 | |
|
332 | 332 | |
|
333 | 333 | def comment_channelstream_push(request, comment_broadcast_channel, user, msg, **kwargs): |
|
334 | 334 | """ |
|
335 | 335 | Channelstream push on comment action, on commit, or pull-request |
|
336 | 336 | """ |
|
337 | 337 | if not comment_broadcast_channel: |
|
338 | 338 | return |
|
339 | 339 | |
|
340 | 340 | _ = request.translate |
|
341 | 341 | |
|
342 | 342 | comment_data = kwargs.pop('comment_data', {}) |
|
343 | 343 | user_data = kwargs.pop('user_data', {}) |
|
344 | 344 | comment_id = list(comment_data.keys())[0] if comment_data else '' |
|
345 | 345 | |
|
346 | 346 | message = '<strong>{}</strong> {} #{}'.format( |
|
347 | 347 | user.username, |
|
348 | 348 | msg, |
|
349 | 349 | comment_id, |
|
350 | 350 | ) |
|
351 | 351 | |
|
352 | 352 | message_obj = { |
|
353 | 353 | 'message': message, |
|
354 | 354 | 'level': 'success', |
|
355 | 355 | 'topic': '/notifications' |
|
356 | 356 | } |
|
357 | 357 | |
|
358 | 358 | post_message( |
|
359 | 359 | comment_broadcast_channel, message_obj, user.username, |
|
360 | 360 | registry=request.registry) |
|
361 | 361 | |
|
362 | 362 | message_obj = { |
|
363 | 363 | 'message': None, |
|
364 | 364 | 'user': user.username, |
|
365 | 365 | 'comment_id': comment_id, |
|
366 | 366 | 'comment_data': comment_data, |
|
367 | 367 | 'user_data': user_data, |
|
368 | 368 | 'topic': '/comment' |
|
369 | 369 | } |
|
370 | 370 | post_message( |
|
371 | 371 | comment_broadcast_channel, message_obj, user.username, |
|
372 | 372 | registry=request.registry) |
@@ -1,797 +1,819 b'' | |||
|
1 | 1 | |
|
2 | 2 | |
|
3 | 3 | # Copyright (C) 2011-2020 RhodeCode GmbH |
|
4 | 4 | # |
|
5 | 5 | # This program is free software: you can redistribute it and/or modify |
|
6 | 6 | # it under the terms of the GNU Affero General Public License, version 3 |
|
7 | 7 | # (only), as published by the Free Software Foundation. |
|
8 | 8 | # |
|
9 | 9 | # This program is distributed in the hope that it will be useful, |
|
10 | 10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of |
|
11 | 11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the |
|
12 | 12 | # GNU General Public License for more details. |
|
13 | 13 | # |
|
14 | 14 | # You should have received a copy of the GNU Affero General Public License |
|
15 | 15 | # along with this program. If not, see <http://www.gnu.org/licenses/>. |
|
16 | 16 | # |
|
17 | 17 | # This program is dual-licensed. If you wish to learn more about the |
|
18 | 18 | # RhodeCode Enterprise Edition, including its added features, Support services, |
|
19 | 19 | # and proprietary license terms, please see https://rhodecode.com/licenses/ |
|
20 | 20 | |
|
21 | 21 | import logging |
|
22 | 22 | import difflib |
|
23 | from itertools import groupby | |
|
23 | import itertools | |
|
24 | 24 | |
|
25 | 25 | from pygments import lex |
|
26 | 26 | from pygments.formatters.html import _get_ttype_class as pygment_token_class |
|
27 | 27 | from pygments.lexers.special import TextLexer, Token |
|
28 | 28 | from pygments.lexers import get_lexer_by_name |
|
29 | 29 | |
|
30 | 30 | from rhodecode.lib.helpers import ( |
|
31 | 31 | get_lexer_for_filenode, html_escape, get_custom_lexer) |
|
32 | from rhodecode.lib.utils2 import AttributeDict, StrictAttributeDict, safe_unicode | |
|
32 | from rhodecode.lib.str_utils import safe_str | |
|
33 | from rhodecode.lib.utils2 import AttributeDict, StrictAttributeDict | |
|
33 | 34 | from rhodecode.lib.vcs.nodes import FileNode |
|
34 | from rhodecode.lib.vcs.exceptions import | 

35 | from rhodecode.lib.vcs.exceptions import NodeDoesNotExistError | |
|
35 | 36 | from rhodecode.lib.diff_match_patch import diff_match_patch |
|
36 | 37 | from rhodecode.lib.diffs import LimitedDiffContainer, DEL_FILENODE, BIN_FILENODE |
|
37 | 38 | |
|
38 | 39 | |
|
39 | 40 | plain_text_lexer = get_lexer_by_name( |
|
40 | 41 | 'text', stripall=False, stripnl=False, ensurenl=False) |
|
41 | 42 | |
|
42 | 43 | |
|
43 | 44 | log = logging.getLogger(__name__) |
|
44 | 45 | |
|
45 | 46 | |
|
46 | 47 | def filenode_as_lines_tokens(filenode, lexer=None): |
|
47 | 48 | org_lexer = lexer |
|
48 | 49 | lexer = lexer or get_lexer_for_filenode(filenode) |
|
49 | log.debug('Generating file node pygment tokens for %s, %s, org_lexer:%s', | |
|
50 | log.debug('Generating file node pygment tokens for %s, file=`%s`, org_lexer:%s', | |
|
50 | 51 | lexer, filenode, org_lexer) |
|
51 | content = filenode.content | |
|
52 | content = filenode.str_content | |
|
52 | 53 | tokens = tokenize_string(content, lexer) |
|
53 | 54 | lines = split_token_stream(tokens, content) |
|
54 | 55 | rv = list(lines) |
|
55 | 56 | return rv |
|
56 | 57 | |
|
57 | 58 | |
|
58 | 59 | def tokenize_string(content, lexer): |
|
59 | 60 | """ |
|
60 | 61 | Use pygments to tokenize some content based on a lexer |
|
61 | 62 | ensuring all original new lines and whitespace is preserved |
|
62 | 63 | """ |
|
63 | 64 | |
|
64 | 65 | lexer.stripall = False |
|
65 | 66 | lexer.stripnl = False |
|
66 | 67 | lexer.ensurenl = False |
|
67 | 68 | |
|
69 | # pygments needs to operate on str | |
|
70 | str_content = safe_str(content) | |
|
71 | ||
|
68 | 72 | if isinstance(lexer, TextLexer): |
|
69 | lexed = [(Token.Text, content)] | |
|
73 | # we convert content here to STR because pygments does that while tokenizing | |
|
74 | # if we DON'T get a lexer for unknown file type | |
|
75 | lexed = [(Token.Text, str_content)] | |
|
70 | 76 | else: |
|
71 | lexed = lex(content, lexer) | |
|
77 | lexed = lex(str_content, lexer) | |
|
72 | 78 | |
|
73 | 79 | for token_type, token_text in lexed: |
|
74 | 80 | yield pygment_token_class(token_type), token_text |
|
75 | 81 | |
|
76 | 82 | |
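
For reference, tokenize_string() above boils down to running pygments' lex() with all stripping disabled so whitespace and newlines survive a round trip. A standalone sketch using only the public pygments API (the 'python' lexer choice is arbitrary):

    from pygments import lex
    from pygments.lexers import get_lexer_by_name

    lexer = get_lexer_by_name('python', stripall=False, stripnl=False, ensurenl=False)
    for token_type, token_text in lex('x = 1\n', lexer):
        print(str(token_type), repr(token_text))  # e.g. Token.Name 'x'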
|
77 | 83 | def split_token_stream(tokens, content): |
|
78 | 84 | """ |
|
79 | 85 | Take a list of (TokenType, text) tuples and split them by a string |
|
80 | 86 | |
|
81 | 87 | split_token_stream([(TEXT, 'some\ntext'), (TEXT, 'more\n')]) |
|
82 | 88 | [(TEXT, 'some'), (TEXT, 'text'), |
|
83 | 89 | (TEXT, 'more'), (TEXT, 'text')] |
|
84 | 90 | """ |
|
85 | 91 | |
|
86 | 92 | token_buffer = [] |
|
93 | ||
|
87 | 94 | for token_class, token_text in tokens: |
|
95 | ||
|
96 | # token_text, should be str | |
|
88 | 97 | parts = token_text.split('\n') |
|
89 | 98 | for part in parts[:-1]: |
|
90 | 99 | token_buffer.append((token_class, part)) |
|
91 | 100 | yield token_buffer |
|
92 | 101 | token_buffer = [] |
|
93 | 102 | |
|
94 | 103 | token_buffer.append((token_class, parts[-1])) |
|
95 | 104 | |
|
96 | 105 | if token_buffer: |
|
97 | 106 | yield token_buffer |
|
98 | 107 | elif content: |
|
99 | 108 | # this is a special case, we have the content, but tokenization didn't produce |
|
100 | # any results. T | 

109 | # any results. This can happen if know file extensions like .css have some bogus | |
|
101 | 110 | # unicode content without any newline characters |
|
102 | 111 | yield [(pygment_token_class(Token.Text), content)] |
|
103 | 112 | |
|
104 | 113 | |
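
The newline-splitting contract of split_token_stream() is easy to miss from the docstring alone: a trailing '\n' still produces a final (possibly empty) line. A self-contained re-implementation of the same loop, with a hypothetical 'txt' token class:

    def split_tokens_by_newline(tokens):
        buf = []
        for token_class, token_text in tokens:
            parts = token_text.split('\n')
            for part in parts[:-1]:
                buf.append((token_class, part))
                yield buf  # one list of (class, text) pieces per source line
                buf = []
            buf.append((token_class, parts[-1]))
        if buf:
            yield buf

    print(list(split_tokens_by_newline([('txt', 'some\ntext'), ('txt', 'more\n')])))
    # [[('txt', 'some')], [('txt', 'text'), ('txt', 'more')], [('txt', '')]]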
|
105 | 114 | def filenode_as_annotated_lines_tokens(filenode): |
|
106 | 115 | """ |
|
107 | 116 | Take a file node and return a list of annotations => lines, if no annotation |
|
108 | 117 | is found, it will be None. |
|
109 | 118 | |
|
110 | 119 | eg: |
|
111 | 120 | |
|
112 | 121 | [ |
|
113 | 122 | (annotation1, [ |
|
114 | 123 | (1, line1_tokens_list), |
|
115 | 124 | (2, line2_tokens_list), |
|
116 | 125 | ]), |
|
117 | 126 | (annotation2, [ |
|
118 | 127 | (3, line1_tokens_list), |
|
119 | 128 | ]), |
|
120 | 129 | (None, [ |
|
121 | 130 | (4, line1_tokens_list), |
|
122 | 131 | ]), |
|
123 | 132 | (annotation1, [ |
|
124 | 133 | (5, line1_tokens_list), |
|
125 | 134 | (6, line2_tokens_list), |
|
126 | 135 | ]) |
|
127 | 136 | ] |
|
128 | 137 | """ |
|
129 | 138 | |
|
130 | 139 | commit_cache = {} # cache commit_getter lookups |
|
131 | 140 | |
|
132 | 141 | def _get_annotation(commit_id, commit_getter): |
|
133 | 142 | if commit_id not in commit_cache: |
|
134 | 143 | commit_cache[commit_id] = commit_getter() |
|
135 | 144 | return commit_cache[commit_id] |
|
136 | 145 | |
|
137 | 146 | annotation_lookup = { |
|
138 | 147 | line_no: _get_annotation(commit_id, commit_getter) |
|
139 | 148 | for line_no, commit_id, commit_getter, line_content |
|
140 | 149 | in filenode.annotate |
|
141 | 150 | } |
|
142 | 151 | |
|
143 | 152 | annotations_lines = ((annotation_lookup.get(line_no), line_no, tokens) |
|
144 | 153 | for line_no, tokens |
|
145 | 154 | in enumerate(filenode_as_lines_tokens(filenode), 1)) |
|
146 | 155 | |
|
147 | grouped_annotations_lines = groupby(annotations_lines, lambda x: x[0]) | |
|
156 | grouped_annotations_lines = itertools.groupby(annotations_lines, lambda x: x[0]) | |
|
148 | 157 | |
|
149 | 158 | for annotation, group in grouped_annotations_lines: |
|
150 | 159 | yield ( |
|
151 | 160 | annotation, [(line_no, tokens) |
|
152 | 161 | for (_, line_no, tokens) in group] |
|
153 | 162 | ) |
|
154 | 163 | |
|
155 | 164 | |
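
Because itertools.groupby() only groups consecutive keys, the same annotation can legitimately appear in several groups, which is exactly what the docstring example above shows. A quick demonstration with plain tuples standing in for (annotation, line_no, tokens):

    import itertools

    annotated = [('c1', 1, 't1'), ('c1', 2, 't2'), (None, 3, 't3'), ('c1', 4, 't4')]
    for annotation, group in itertools.groupby(annotated, lambda x: x[0]):
        print(annotation, [(line_no, tokens) for _, line_no, tokens in group])
    # c1 [(1, 't1'), (2, 't2')]
    # None [(3, 't3')]
    # c1 [(4, 't4')]   <- 'c1' again: groupby() does not merge across runs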
|
156 | 165 | def render_tokenstream(tokenstream): |
|
157 | 166 | result = [] |
|
158 | 167 | for token_class, token_ops_texts in rollup_tokenstream(tokenstream): |
|
159 | 168 | |
|
160 | 169 | if token_class: |
|
161 | result.append('<span class="%s">' % token_class) | 

170 | result.append(f'<span class="{token_class}">') | |
|
162 | 171 | else: |
|
163 | 172 | result.append('<span>') |
|
164 | 173 | |
|
165 | 174 | for op_tag, token_text in token_ops_texts: |
|
166 | 175 | |
|
167 | 176 | if op_tag: |
|
168 | result.append('<%s>' % op_tag) | 

177 | result.append(f'<{op_tag}>') | |
|
169 | 178 | |
|
170 | 179 | # NOTE(marcink): in some cases of mixed encodings, we might run into |
|
171 | 180 | # troubles in the html_escape, in this case we say unicode force on token_text |
|
172 | 181 | # that would ensure "correct" data even with the cost of rendered |
|
173 | 182 | try: |
|
174 | 183 | escaped_text = html_escape(token_text) |
|
175 | 184 | except TypeError: |
|
176 | escaped_text = html_escape(safe_unicode(token_text)) | 

185 | escaped_text = html_escape(safe_str(token_text)) | |
|
177 | 186 | |
|
178 | 187 | # TODO: dan: investigate showing hidden characters like space/nl/tab |
|
179 | 188 | # escaped_text = escaped_text.replace(' ', '<sp> </sp>') |
|
180 | 189 | # escaped_text = escaped_text.replace('\n', '<nl>\n</nl>') |
|
181 | 190 | # escaped_text = escaped_text.replace('\t', '<tab>\t</tab>') |
|
182 | 191 | |
|
183 | 192 | result.append(escaped_text) |
|
184 | 193 | |
|
185 | 194 | if op_tag: |
|
186 | result.append('</%s>' % op_tag) | 

195 | result.append(f'</{op_tag}>') | |
|
187 | 196 | |
|
188 | 197 | result.append('</span>') |
|
189 | 198 | |
|
190 | 199 | html = ''.join(result) |
|
191 | 200 | return html |
|
192 | 201 | |
|
193 | 202 | |
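
A stripped-down version of the rendering above, assuming already-rolled-up (class, op, text) triples; unlike render_tokenstream() it emits one span per triple and skips the mixed-encoding fallback:

    import html

    def render(tokens):
        out = []
        for css_class, op_tag, text in tokens:
            body = html.escape(text)
            if op_tag:
                body = f'<{op_tag}>{body}</{op_tag}>'
            out.append(f'<span class="{css_class}">{body}</span>')
        return ''.join(out)

    print(render([('k', '', 'def'), ('n', 'ins', 'foo')]))
    # <span class="k">def</span><span class="n"><ins>foo</ins></span>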
|
194 | 203 | def rollup_tokenstream(tokenstream): |
|
195 | 204 | """ |
|
196 | 205 | Group a token stream of the format: |
|
197 | 206 | |
|
198 | 207 | ('class', 'op', 'text') |
|
199 | 208 | or |
|
200 | 209 | ('class', 'text') |
|
201 | 210 | |
|
202 | 211 | into |
|
203 | 212 | |
|
204 | 213 | [('class1', |
|
205 | 214 | [('op1', 'text'), |
|
206 | 215 | ('op2', 'text')]), |
|
207 | 216 | ('class2', |
|
208 | 217 | [('op3', 'text')])] |
|
209 | 218 | |
|
210 | 219 | This is used to get the minimal tags necessary when |
|
211 | 220 | rendering to html eg for a token stream ie. |
|
212 | 221 | |
|
213 | 222 | <span class="A"><ins>he</ins>llo</span> |
|
214 | 223 | vs |
|
215 | 224 | <span class="A"><ins>he</ins></span><span class="A">llo</span> |
|
216 | 225 | |
|
217 | 226 | If a 2 tuple is passed in, the output op will be an empty string. |
|
218 | 227 | |
|
219 | 228 | eg: |
|
220 | 229 | |
|
221 | 230 | >>> rollup_tokenstream([('classA', '', 'h'), |
|
222 | 231 | ('classA', 'del', 'ell'), |
|
223 | 232 | ('classA', '', 'o'), |
|
224 | 233 | ('classB', '', ' '), |
|
225 | 234 | ('classA', '', 'the'), |
|
226 | 235 | ('classA', '', 're'), |
|
227 | 236 | ]) |
|
228 | 237 | |
|
229 | 238 | [('classA', [('', 'h'), ('del', 'ell'), ('', 'o')], |
|
230 | 239 | ('classB', [('', ' ')], |
|
231 | 240 | ('classA', [('', 'there')]] |
|
232 | 241 | |
|
233 | 242 | """ |
|
234 | 243 | if tokenstream and len(tokenstream[0]) == 2: |
|
235 | 244 | tokenstream = ((t[0], '', t[1]) for t in tokenstream) |
|
236 | 245 | |
|
237 | 246 | result = [] |
|
238 | for token_class, op_list in groupby(tokenstream, lambda t: t[0]): | |
|
247 | for token_class, op_list in itertools.groupby(tokenstream, lambda t: t[0]): | |
|
239 | 248 | ops = [] |
|
240 | for token_op, token_text_list in groupby(op_list, lambda o: o[1]): | |
|
249 | for token_op, token_text_list in itertools.groupby(op_list, lambda o: o[1]): | |
|
241 | 250 | text_buffer = [] |
|
242 | 251 | for t_class, t_op, t_text in token_text_list: |
|
243 | 252 | text_buffer.append(t_text) |
|
253 | ||
|
244 | 254 | ops.append((token_op, ''.join(text_buffer))) |
|
245 | 255 | result.append((token_class, ops)) |
|
246 | 256 | return result |
|
247 | 257 | |
|
248 | 258 | |
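
The two nested groupby() passes are the whole trick: group by css class first, then by op inside each class run, concatenating the text. A compact sketch reproducing the docstring example:

    import itertools

    stream = [('classA', '', 'h'), ('classA', 'del', 'ell'),
              ('classA', '', 'o'), ('classB', '', ' ')]
    rolled = []
    for token_class, run in itertools.groupby(stream, lambda t: t[0]):
        ops = [(op, ''.join(t[2] for t in chunk))
               for op, chunk in itertools.groupby(list(run), lambda t: t[1])]
        rolled.append((token_class, ops))
    print(rolled)
    # [('classA', [('', 'h'), ('del', 'ell'), ('', 'o')]), ('classB', [('', ' ')])]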
|
249 | 259 | def tokens_diff(old_tokens, new_tokens, use_diff_match_patch=True): |
|
250 | 260 | """ |
|
251 | 261 | Converts a list of (token_class, token_text) tuples to a list of |
|
252 | 262 | (token_class, token_op, token_text) tuples where token_op is one of |
|
253 | 263 | ('ins', 'del', '') |
|
254 | 264 | |
|
255 | 265 | :param old_tokens: list of (token_class, token_text) tuples of old line |
|
256 | 266 | :param new_tokens: list of (token_class, token_text) tuples of new line |
|
257 | 267 | :param use_diff_match_patch: boolean, will use google's diff match patch |
|
258 | 268 | library which has options to 'smooth' out the character by character |
|
259 | 269 | differences making nicer ins/del blocks |
|
260 | 270 | """ |
|
261 | 271 | |
|
262 | 272 | old_tokens_result = [] |
|
263 | 273 | new_tokens_result = [] |
|
264 | 274 | |
|
265 | similarity = difflib.SequenceMatcher(None, | |
|
275 | def int_convert(val): | |
|
276 | if isinstance(val, int): | |
|
277 | return str(val) | |
|
278 | return val | |
|
279 | ||
|
280 | similarity = difflib.SequenceMatcher( | |
|
281 | None, | |
|
266 | 282 | ''.join(token_text for token_class, token_text in old_tokens), |
|
267 | 283 | ''.join(token_text for token_class, token_text in new_tokens) |
|
268 | 284 | ).ratio() |
|
269 | 285 | |
|
270 | if similarity < 0.6: # return, the blocks are too different | |
|
286 | if similarity < 0.6: # return, the blocks are too different | |
|
271 | 287 | for token_class, token_text in old_tokens: |
|
272 | 288 | old_tokens_result.append((token_class, '', token_text)) |
|
273 | 289 | for token_class, token_text in new_tokens: |
|
274 | 290 | new_tokens_result.append((token_class, '', token_text)) |
|
275 | 291 | return old_tokens_result, new_tokens_result, similarity |
|
276 | 292 | |
|
277 | token_sequence_matcher = difflib.SequenceMatcher(None, | 

293 | token_sequence_matcher = difflib.SequenceMatcher( | |
|
294 | None, | |
|
278 | 295 | [x[1] for x in old_tokens], |
|
279 | 296 | [x[1] for x in new_tokens]) |
|
280 | 297 | |
|
281 | 298 | for tag, o1, o2, n1, n2 in token_sequence_matcher.get_opcodes(): |
|
282 | # check the differences by token block types first to give a | 

299 | # check the differences by token block types first to give a | |
|
283 | 300 | # nicer "block" level replacement vs character diffs |
|
284 | 301 | |
|
285 | 302 | if tag == 'equal': |
|
286 | 303 | for token_class, token_text in old_tokens[o1:o2]: |
|
287 | 304 | old_tokens_result.append((token_class, '', token_text)) |
|
288 | 305 | for token_class, token_text in new_tokens[n1:n2]: |
|
289 | 306 | new_tokens_result.append((token_class, '', token_text)) |
|
290 | 307 | elif tag == 'delete': |
|
291 | 308 | for token_class, token_text in old_tokens[o1:o2]: |
|
292 | old_tokens_result.append((token_class, 'del', token_text)) | |
|
309 | old_tokens_result.append((token_class, 'del', int_convert(token_text))) | |
|
293 | 310 | elif tag == 'insert': |
|
294 | 311 | for token_class, token_text in new_tokens[n1:n2]: |
|
295 | new_tokens_result.append((token_class, 'ins', token_text)) | |
|
312 | new_tokens_result.append((token_class, 'ins', int_convert(token_text))) | |
|
296 | 313 | elif tag == 'replace': |
|
297 | 314 | # if same type token blocks must be replaced, do a diff on the |
|
298 | 315 | # characters in the token blocks to show individual changes |
|
299 | 316 | |
|
300 | 317 | old_char_tokens = [] |
|
301 | 318 | new_char_tokens = [] |
|
302 | 319 | for token_class, token_text in old_tokens[o1:o2]: |
|
303 | for char in token_text: | |
|
320 | for char in map(lambda i: i, token_text): | |
|
304 | 321 | old_char_tokens.append((token_class, char)) |
|
305 | 322 | |
|
306 | 323 | for token_class, token_text in new_tokens[n1:n2]: |
|
307 | for char in token_text: | |
|
324 | for char in map(lambda i: i, token_text): | |
|
308 | 325 | new_char_tokens.append((token_class, char)) |
|
309 | 326 | |
|
310 | 327 | old_string = ''.join([token_text for |
|
311 | token_class, token_text in old_char_tokens]) | |
|
328 | token_class, token_text in old_char_tokens]) | |
|
312 | 329 | new_string = ''.join([token_text for |
|
313 | token_class, token_text in new_char_tokens]) | |
|
330 | token_class, token_text in new_char_tokens]) | |
|
314 | 331 | |
|
315 | 332 | char_sequence = difflib.SequenceMatcher( |
|
316 | 333 | None, old_string, new_string) |
|
317 | 334 | copcodes = char_sequence.get_opcodes() |
|
318 | 335 | obuffer, nbuffer = [], [] |
|
319 | 336 | |
|
320 | 337 | if use_diff_match_patch: |
|
321 | 338 | dmp = diff_match_patch() |
|
322 | 339 | dmp.Diff_EditCost = 11 # TODO: dan: extract this to a setting |
|
323 | 340 | reps = dmp.diff_main(old_string, new_string) |
|
324 | 341 | dmp.diff_cleanupEfficiency(reps) |
|
325 | 342 | |
|
326 | 343 | a, b = 0, 0 |
|
327 | 344 | for op, rep in reps: |
|
328 | 345 | l = len(rep) |
|
329 | 346 | if op == 0: |
|
330 | 347 | for i, c in enumerate(rep): |
|
331 | 348 | obuffer.append((old_char_tokens[a+i][0], '', c)) |
|
332 | 349 | nbuffer.append((new_char_tokens[b+i][0], '', c)) |
|
333 | 350 | a += l |
|
334 | 351 | b += l |
|
335 | 352 | elif op == -1: |
|
336 | 353 | for i, c in enumerate(rep): |
|
337 | obuffer.append((old_char_tokens[a+i][0], 'del', c)) | |
|
354 | obuffer.append((old_char_tokens[a+i][0], 'del', int_convert(c))) | |
|
338 | 355 | a += l |
|
339 | 356 | elif op == 1: |
|
340 | 357 | for i, c in enumerate(rep): |
|
341 | nbuffer.append((new_char_tokens[b+i][0], 'ins', c)) | |
|
358 | nbuffer.append((new_char_tokens[b+i][0], 'ins', int_convert(c))) | |
|
342 | 359 | b += l |
|
343 | 360 | else: |
|
344 | 361 | for ctag, co1, co2, cn1, cn2 in copcodes: |
|
345 | 362 | if ctag == 'equal': |
|
346 | 363 | for token_class, token_text in old_char_tokens[co1:co2]: |
|
347 | 364 | obuffer.append((token_class, '', token_text)) |
|
348 | 365 | for token_class, token_text in new_char_tokens[cn1:cn2]: |
|
349 | 366 | nbuffer.append((token_class, '', token_text)) |
|
350 | 367 | elif ctag == 'delete': |
|
351 | 368 | for token_class, token_text in old_char_tokens[co1:co2]: |
|
352 | obuffer.append((token_class, 'del', token_text)) | |
|
369 | obuffer.append((token_class, 'del', int_convert(token_text))) | |
|
353 | 370 | elif ctag == 'insert': |
|
354 | 371 | for token_class, token_text in new_char_tokens[cn1:cn2]: |
|
355 | nbuffer.append((token_class, 'ins', token_text)) | |
|
372 | nbuffer.append((token_class, 'ins', int_convert(token_text))) | |
|
356 | 373 | elif ctag == 'replace': |
|
357 | 374 | for token_class, token_text in old_char_tokens[co1:co2]: |
|
358 | obuffer.append((token_class, 'del', token_text)) | |
|
375 | obuffer.append((token_class, 'del', int_convert(token_text))) | |
|
359 | 376 | for token_class, token_text in new_char_tokens[cn1:cn2]: |
|
360 | nbuffer.append((token_class, 'ins', token_text)) | |
|
377 | nbuffer.append((token_class, 'ins', int_convert(token_text))) | |
|
361 | 378 | |
|
362 | 379 | old_tokens_result.extend(obuffer) |
|
363 | 380 | new_tokens_result.extend(nbuffer) |
|
364 | 381 | |
|
365 | 382 | return old_tokens_result, new_tokens_result, similarity |
|
366 | 383 | |
|
367 | 384 | |
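
The heavy lifting in tokens_diff() is delegated to difflib.SequenceMatcher opcodes: 'delete' maps to the 'del' op, 'insert' to 'ins', and 'replace' triggers the per-character second pass. The opcode shape itself, shown on two plain strings:

    import difflib

    old, new = 'hello', 'hallo'
    for tag, o1, o2, n1, n2 in difflib.SequenceMatcher(None, old, new).get_opcodes():
        print(tag, repr(old[o1:o2]), '->', repr(new[n1:n2]))
    # equal 'h' -> 'h'
    # replace 'e' -> 'a'
    # equal 'llo' -> 'llo'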
|
368 | 385 | def diffset_node_getter(commit): |
|
369 | def get_node(fname): | |
|
386 | def get_diff_node(file_name): | |
|
387 | ||
|
370 | 388 | try: |
|
371 | return commit.get_node(fname) | |
|
389 | return commit.get_node(file_name, pre_load=['size', 'flags', 'data']) | |
|
372 | 390 | except NodeDoesNotExistError: |
|
373 | 391 | return None |
|
374 | 392 | |
|
375 | return get_node | |
|
393 | return get_diff_node | |
|
376 | 394 | |
|
377 | 395 | |
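
diffset_node_getter() returns a closure so node lookups stay lazy and a missing path degrades to None instead of raising. The same pattern with a dict-backed stand-in for the commit object (FakeCommit and KeyError are illustrative substitutes for the vcs API):

    class FakeCommit:
        def __init__(self, files):
            self.files = files

        def get_node(self, file_name):
            return self.files[file_name]  # raises KeyError when missing

    def node_getter(commit):
        def get_diff_node(file_name):
            try:
                return commit.get_node(file_name)
            except KeyError:  # stand-in for NodeDoesNotExistError
                return None
        return get_diff_node

    get_node = node_getter(FakeCommit({'a.py': 'print(1)'}))
    print(get_node('a.py'), get_node('gone.py'))  # print(1) None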
|
378 | 396 | class DiffSet(object): |
|
379 | 397 | """ |
|
380 | 398 | An object for parsing the diff result from diffs.DiffProcessor and |
|
381 | 399 | adding highlighting, side by side/unified renderings and line diffs |
|
382 | 400 | """ |
|
383 | 401 | |
|
384 | 402 | HL_REAL = 'REAL' # highlights using original file, slow |
|
385 | 403 | HL_FAST = 'FAST' # highlights using just the line, fast but not correct |
|
386 | 404 | # in the case of multiline code |
|
387 | 405 | HL_NONE = 'NONE' # no highlighting, fastest |
|
388 | 406 | |
|
389 | 407 | def __init__(self, highlight_mode=HL_REAL, repo_name=None, |
|
390 | 408 | source_repo_name=None, |
|
391 | 409 | source_node_getter=lambda filename: None, |
|
392 | 410 | target_repo_name=None, |
|
393 | 411 | target_node_getter=lambda filename: None, |
|
394 | 412 | source_nodes=None, target_nodes=None, |
|
395 | 413 | # files over this size will use fast highlighting |
|
396 | 414 | max_file_size_limit=150 * 1024, |
|
397 | 415 | ): |
|
398 | 416 | |
|
399 | 417 | self.highlight_mode = highlight_mode |
|
400 | 418 | self.highlighted_filenodes = { |
|
401 | 419 | 'before': {}, |
|
402 | 420 | 'after': {} |
|
403 | 421 | } |
|
404 | 422 | self.source_node_getter = source_node_getter |
|
405 | 423 | self.target_node_getter = target_node_getter |
|
406 | 424 | self.source_nodes = source_nodes or {} |
|
407 | 425 | self.target_nodes = target_nodes or {} |
|
408 | 426 | self.repo_name = repo_name |
|
409 | 427 | self.target_repo_name = target_repo_name or repo_name |
|
410 | 428 | self.source_repo_name = source_repo_name or repo_name |
|
411 | 429 | self.max_file_size_limit = max_file_size_limit |
|
412 | 430 | |
|
413 | 431 | def render_patchset(self, patchset, source_ref=None, target_ref=None): |
|
414 | 432 | diffset = AttributeDict(dict( |
|
415 | 433 | lines_added=0, |
|
416 | 434 | lines_deleted=0, |
|
417 | 435 | changed_files=0, |
|
418 | 436 | files=[], |
|
419 | 437 | file_stats={}, |
|
420 | 438 | limited_diff=isinstance(patchset, LimitedDiffContainer), |
|
421 | 439 | repo_name=self.repo_name, |
|
422 | 440 | target_repo_name=self.target_repo_name, |
|
423 | 441 | source_repo_name=self.source_repo_name, |
|
424 | 442 | source_ref=source_ref, |
|
425 | 443 | target_ref=target_ref, |
|
426 | 444 | )) |
|
427 | 445 | for patch in patchset: |
|
428 | 446 | diffset.file_stats[patch['filename']] = patch['stats'] |
|
429 | 447 | filediff = self.render_patch(patch) |
|
430 | 448 | filediff.diffset = StrictAttributeDict(dict( |
|
431 | 449 | source_ref=diffset.source_ref, |
|
432 | 450 | target_ref=diffset.target_ref, |
|
433 | 451 | repo_name=diffset.repo_name, |
|
434 | 452 | source_repo_name=diffset.source_repo_name, |
|
435 | 453 | target_repo_name=diffset.target_repo_name, |
|
436 | 454 | )) |
|
437 | 455 | diffset.files.append(filediff) |
|
438 | 456 | diffset.changed_files += 1 |
|
439 | 457 | if not patch['stats']['binary']: |
|
440 | 458 | diffset.lines_added += patch['stats']['added'] |
|
441 | 459 | diffset.lines_deleted += patch['stats']['deleted'] |
|
442 | 460 | |
|
443 | 461 | return diffset |
|
444 | 462 | |
|
445 | 463 | _lexer_cache = {} |
|
446 | 464 | |
|
447 | 465 | def _get_lexer_for_filename(self, filename, filenode=None): |
|
448 | 466 | # cached because we might need to call it twice for source/target |
|
449 | 467 | if filename not in self._lexer_cache: |
|
450 | 468 | if filenode: |
|
451 | 469 | lexer = filenode.lexer |
|
452 | 470 | extension = filenode.extension |
|
453 | 471 | else: |
|
454 | 472 | lexer = FileNode.get_lexer(filename=filename) |
|
455 | 473 | extension = filename.split('.')[-1] |
|
456 | 474 | |
|
457 | 475 | lexer = get_custom_lexer(extension) or lexer |
|
458 | 476 | self._lexer_cache[filename] = lexer |
|
459 | 477 | return self._lexer_cache[filename] |
|
460 | 478 | |
|
461 | 479 | def render_patch(self, patch): |
|
462 | 480 | log.debug('rendering diff for %r', patch['filename']) |
|
463 | 481 | |
|
464 | 482 | source_filename = patch['original_filename'] |
|
465 | 483 | target_filename = patch['filename'] |
|
466 | 484 | |
|
467 | 485 | source_lexer = plain_text_lexer |
|
468 | 486 | target_lexer = plain_text_lexer |
|
469 | 487 | |
|
470 | 488 | if not patch['stats']['binary']: |
|
471 | 489 | node_hl_mode = self.HL_NONE if patch['chunks'] == [] else None |
|
472 | 490 | hl_mode = node_hl_mode or self.highlight_mode |
|
473 | 491 | |
|
474 | 492 | if hl_mode == self.HL_REAL: |
|
475 | 493 | if (source_filename and patch['operation'] in ('D', 'M') |
|
476 | 494 | and source_filename not in self.source_nodes): |
|
477 | 495 | self.source_nodes[source_filename] = ( |
|
478 | 496 | self.source_node_getter(source_filename)) |
|
479 | 497 | |
|
480 | 498 | if (target_filename and patch['operation'] in ('A', 'M') |
|
481 | 499 | and target_filename not in self.target_nodes): |
|
482 | 500 | self.target_nodes[target_filename] = ( |
|
483 | 501 | self.target_node_getter(target_filename)) |
|
484 | 502 | |
|
485 | 503 | elif hl_mode == self.HL_FAST: |
|
486 | 504 | source_lexer = self._get_lexer_for_filename(source_filename) |
|
487 | 505 | target_lexer = self._get_lexer_for_filename(target_filename) |
|
488 | 506 | |
|
489 | 507 | source_file = self.source_nodes.get(source_filename, source_filename) |
|
490 | 508 | target_file = self.target_nodes.get(target_filename, target_filename) |
|
491 | 509 | raw_id_uid = '' |
|
492 | 510 | if self.source_nodes.get(source_filename): |
|
493 | 511 | raw_id_uid = self.source_nodes[source_filename].commit.raw_id |
|
494 | 512 | |
|
495 | 513 | if not raw_id_uid and self.target_nodes.get(target_filename): |
|
496 | 514 | # in case this is a new file we only have it in target |
|
497 | 515 | raw_id_uid = self.target_nodes[target_filename].commit.raw_id |
|
498 | 516 | |
|
499 | 517 | source_filenode, target_filenode = None, None |
|
500 | 518 | |
|
501 | 519 | # TODO: dan: FileNode.lexer works on the content of the file - which |
|
502 | 520 | # can be slow - issue #4289 explains a lexer clean up - which once |
|
503 | 521 | # done can allow caching a lexer for a filenode to avoid the file lookup |
|
504 | 522 | if isinstance(source_file, FileNode): |
|
505 | 523 | source_filenode = source_file |
|
506 | 524 | #source_lexer = source_file.lexer |
|
507 | 525 | source_lexer = self._get_lexer_for_filename(source_filename) |
|
508 | 526 | source_file.lexer = source_lexer |
|
509 | 527 | |
|
510 | 528 | if isinstance(target_file, FileNode): |
|
511 | 529 | target_filenode = target_file |
|
512 | 530 | #target_lexer = target_file.lexer |
|
513 | 531 | target_lexer = self._get_lexer_for_filename(target_filename) |
|
514 | 532 | target_file.lexer = target_lexer |
|
515 | 533 | |
|
516 | 534 | source_file_path, target_file_path = None, None |
|
517 | 535 | |
|
518 | 536 | if source_filename != '/dev/null': |
|
519 | 537 | source_file_path = source_filename |
|
520 | 538 | if target_filename != '/dev/null': |
|
521 | 539 | target_file_path = target_filename |
|
522 | 540 | |
|
523 | 541 | source_file_type = source_lexer.name |
|
524 | 542 | target_file_type = target_lexer.name |
|
525 | 543 | |
|
526 | 544 | filediff = AttributeDict({ |
|
527 | 545 | 'source_file_path': source_file_path, |
|
528 | 546 | 'target_file_path': target_file_path, |
|
529 | 547 | 'source_filenode': source_filenode, |
|
530 | 548 | 'target_filenode': target_filenode, |
|
531 | 549 | 'source_file_type': target_file_type, |
|
532 | 550 | 'target_file_type': source_file_type, |
|
533 | 551 | 'patch': {'filename': patch['filename'], 'stats': patch['stats']}, |
|
534 | 552 | 'operation': patch['operation'], |
|
535 | 553 | 'source_mode': patch['stats']['old_mode'], |
|
536 | 554 | 'target_mode': patch['stats']['new_mode'], |
|
537 | 555 | 'limited_diff': patch['is_limited_diff'], |
|
538 | 556 | 'hunks': [], |
|
539 | 557 | 'hunk_ops': None, |
|
540 | 558 | 'diffset': self, |
|
541 | 559 | 'raw_id': raw_id_uid, |
|
542 | 560 | }) |
|
543 | 561 | |
|
544 | 562 | file_chunks = patch['chunks'][1:] |
|
545 | 563 | for i, hunk in enumerate(file_chunks, 1): |
|
546 | 564 | hunkbit = self.parse_hunk(hunk, source_file, target_file) |
|
547 | 565 | hunkbit.source_file_path = source_file_path |
|
548 | 566 | hunkbit.target_file_path = target_file_path |
|
549 | 567 | hunkbit.index = i |
|
550 | 568 | filediff.hunks.append(hunkbit) |
|
551 | 569 | |
|
552 | 570 | # Simulate hunk on OPS type line which doesn't really contain any diff |
|
553 | 571 | # this allows commenting on those |
|
554 | 572 | if not file_chunks: |
|
555 | 573 | actions = [] |
|
556 | for op_id, op_text in filediff.patch['stats']['ops'].items(): | |
|
574 | for op_id, op_text in list(filediff.patch['stats']['ops'].items()): | |
|
557 | 575 | if op_id == DEL_FILENODE: |
|
558 | 576 | actions.append('file was removed') |
|
559 | 577 | elif op_id == BIN_FILENODE: |
|
560 | 578 | actions.append('binary diff hidden') |
|
561 | 579 | else: |
|
562 | actions.append(safe_unicode(op_text)) | 

580 | actions.append(safe_str(op_text)) | |
|
563 | 581 | action_line = 'NO CONTENT: ' + \ |
|
564 | 582 | ', '.join(actions) or 'UNDEFINED_ACTION' |
|
565 | 583 | |
|
566 | 584 | hunk_ops = {'source_length': 0, 'source_start': 0, |
|
567 | 585 | 'lines': [ |
|
568 | 586 | {'new_lineno': 0, 'old_lineno': 1, |
|
569 | 587 | 'action': 'unmod-no-hl', 'line': action_line} |
|
570 | 588 | ], |
|
571 | 589 | 'section_header': '', 'target_start': 1, 'target_length': 1} |
|
572 | 590 | |
|
573 | 591 | hunkbit = self.parse_hunk(hunk_ops, source_file, target_file) |
|
574 | 592 | hunkbit.source_file_path = source_file_path |
|
575 | 593 | hunkbit.target_file_path = target_file_path |
|
576 | 594 | filediff.hunk_ops = hunkbit |
|
577 | 595 | return filediff |
|
578 | 596 | |
|
579 | 597 | def parse_hunk(self, hunk, source_file, target_file): |
|
580 | 598 | result = AttributeDict(dict( |
|
581 | 599 | source_start=hunk['source_start'], |
|
582 | 600 | source_length=hunk['source_length'], |
|
583 | 601 | target_start=hunk['target_start'], |
|
584 | 602 | target_length=hunk['target_length'], |
|
585 | 603 | section_header=hunk['section_header'], |
|
586 | 604 | lines=[], |
|
587 | 605 | )) |
|
588 | 606 | before, after = [], [] |
|
589 | 607 | |
|
590 | 608 | for line in hunk['lines']: |
|
609 | ||
|
591 | 610 | if line['action'] in ['unmod', 'unmod-no-hl']: |
|
592 | 611 | no_hl = line['action'] == 'unmod-no-hl' |
|
593 | result.lines.extend( | |
|
594 | self.parse_lines(before, after, source_file, target_file, no_hl=no_hl)) | |
|
612 | parsed_lines = self.parse_lines(before, after, source_file, target_file, no_hl=no_hl) | |
|
613 | result.lines.extend(parsed_lines) | |
|
595 | 614 | after.append(line) |
|
596 | 615 | before.append(line) |
|
597 | 616 | elif line['action'] == 'add': |
|
598 | 617 | after.append(line) |
|
599 | 618 | elif line['action'] == 'del': |
|
600 | 619 | before.append(line) |
|
601 | 620 | elif line['action'] == 'old-no-nl': |
|
602 | 621 | before.append(line) |
|
622 | #line['line'] = safe_str(line['line']) | |
|
603 | 623 | elif line['action'] == 'new-no-nl': |
|
624 | #line['line'] = safe_str(line['line']) | |
|
604 | 625 | after.append(line) |
|
605 | 626 | |
|
606 | 627 | all_actions = [x['action'] for x in after] + [x['action'] for x in before] |
|
607 | 628 | no_hl = {x for x in all_actions} == {'unmod-no-hl'} |
|
608 | result.lines.extend( | |
|
609 | self.parse_lines(before, after, source_file, target_file, no_hl=no_hl)) | |
|
610 | # NOTE(marcink): we must keep list() call here so we can cache the result... | |
|
629 | parsed_no_hl_lines = self.parse_lines(before, after, source_file, target_file, no_hl=no_hl) | |
|
630 | result.lines.extend(parsed_no_hl_lines) | |
|
631 | ||
|
632 | # NOTE(marcink): we must keep list() call here, so we can cache the result... | |
|
611 | 633 | result.unified = list(self.as_unified(result.lines)) |
|
612 | 634 | result.sideside = result.lines |
|
613 | 635 | |
|
614 | 636 | return result |
|
615 | 637 | |
|
616 | 638 | def parse_lines(self, before_lines, after_lines, source_file, target_file, |
|
617 | 639 | no_hl=False): |
|
618 | 640 | # TODO: dan: investigate doing the diff comparison and fast highlighting |
|
619 | 641 | # on the entire before and after buffered block lines rather than by |
|
620 | 642 | # line, this means we can get better 'fast' highlighting if the context |
|
621 | 643 | # allows it - eg. |
|
622 | 644 | # line 4: """ |
|
623 | 645 | # line 5: this gets highlighted as a string |
|
624 | 646 | # line 6: """ |
|
625 | 647 | |
|
626 | 648 | lines = [] |
|
627 | 649 | |
|
628 | 650 | before_newline = AttributeDict() |
|
629 | 651 | after_newline = AttributeDict() |
|
630 | 652 | if before_lines and before_lines[-1]['action'] == 'old-no-nl': |
|
631 | 653 | before_newline_line = before_lines.pop(-1) |
|
632 | 654 | before_newline.content = '\n {}'.format( |
|
633 | 655 | render_tokenstream( |
|
634 | [(x[0], '', x[1]) | |
|
656 | [(x[0], '', safe_str(x[1])) | |
|
635 | 657 | for x in [('nonl', before_newline_line['line'])]])) |
|
636 | 658 | |
|
637 | 659 | if after_lines and after_lines[-1]['action'] == 'new-no-nl': |
|
638 | 660 | after_newline_line = after_lines.pop(-1) |
|
639 | 661 | after_newline.content = '\n {}'.format( |
|
640 | 662 | render_tokenstream( |
|
641 | [(x[0], '', x[1]) | |
|
663 | [(x[0], '', safe_str(x[1])) | |
|
642 | 664 | for x in [('nonl', after_newline_line['line'])]])) |
|
643 | 665 | |
|
644 | 666 | while before_lines or after_lines: |
|
645 | 667 | before, after = None, None |
|
646 | 668 | before_tokens, after_tokens = None, None |
|
647 | 669 | |
|
648 | 670 | if before_lines: |
|
649 | 671 | before = before_lines.pop(0) |
|
650 | 672 | if after_lines: |
|
651 | 673 | after = after_lines.pop(0) |
|
652 | 674 | |
|
653 | 675 | original = AttributeDict() |
|
654 | 676 | modified = AttributeDict() |
|
655 | 677 | |
|
656 | 678 | if before: |
|
657 | 679 | if before['action'] == 'old-no-nl': |
|
658 | before_tokens = [('nonl', before['line'])] | |
|
680 | before_tokens = [('nonl', safe_str(before['line']))] | |
|
659 | 681 | else: |
|
660 | 682 | before_tokens = self.get_line_tokens( |
|
661 | 683 | line_text=before['line'], line_number=before['old_lineno'], |
|
662 | 684 | input_file=source_file, no_hl=no_hl, source='before') |
|
663 | 685 | original.lineno = before['old_lineno'] |
|
664 | 686 | original.content = before['line'] |
|
665 | 687 | original.action = self.action_to_op(before['action']) |
|
666 | 688 | |
|
667 | 689 | original.get_comment_args = ( |
|
668 | 690 | source_file, 'o', before['old_lineno']) |
|
669 | 691 | |
|
670 | 692 | if after: |
|
671 | 693 | if after['action'] == 'new-no-nl': |
|
672 | after_tokens = [('nonl', after['line'])] | |
|
694 | after_tokens = [('nonl', safe_str(after['line']))] | |
|
673 | 695 | else: |
|
674 | 696 | after_tokens = self.get_line_tokens( |
|
675 | 697 | line_text=after['line'], line_number=after['new_lineno'], |
|
676 | 698 | input_file=target_file, no_hl=no_hl, source='after') |
|
677 | 699 | modified.lineno = after['new_lineno'] |
|
678 | 700 | modified.content = after['line'] |
|
679 | 701 | modified.action = self.action_to_op(after['action']) |
|
680 | 702 | |
|
681 | 703 | modified.get_comment_args = (target_file, 'n', after['new_lineno']) |
|
682 | 704 | |
|
683 | 705 | # diff the lines |
|
684 | 706 | if before_tokens and after_tokens: |
|
685 | 707 | o_tokens, m_tokens, similarity = tokens_diff( |
|
686 | 708 | before_tokens, after_tokens) |
|
687 | 709 | original.content = render_tokenstream(o_tokens) |
|
688 | 710 | modified.content = render_tokenstream(m_tokens) |
|
689 | 711 | elif before_tokens: |
|
690 | 712 | original.content = render_tokenstream( |
|
691 | 713 | [(x[0], '', x[1]) for x in before_tokens]) |
|
692 | 714 | elif after_tokens: |
|
693 | 715 | modified.content = render_tokenstream( |
|
694 | 716 | [(x[0], '', x[1]) for x in after_tokens]) |
|
695 | 717 | |
|
696 | 718 | if not before_lines and before_newline: |
|
697 | 719 | original.content += before_newline.content |
|
698 | 720 | before_newline = None |
|
699 | 721 | if not after_lines and after_newline: |
|
700 | 722 | modified.content += after_newline.content |
|
701 | 723 | after_newline = None |
|
702 | 724 | |
|
703 | 725 | lines.append(AttributeDict({ |
|
704 | 726 | 'original': original, |
|
705 | 727 | 'modified': modified, |
|
706 | 728 | })) |
|
707 | 729 | |
|
708 | 730 | return lines |
|
709 | 731 | |
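
The pairing loop in parse_lines() consumes the buffered 'before' and 'after' lines in lock-step, so an unbalanced hunk simply leaves None on the shorter side. The consumption pattern in isolation:

    before_lines = ['old A', 'old B']
    after_lines = ['new A']
    while before_lines or after_lines:
        before = before_lines.pop(0) if before_lines else None
        after = after_lines.pop(0) if after_lines else None
        print(before, '|', after)
    # old A | new A
    # old B | None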
|
710 | 732 | def get_line_tokens(self, line_text, line_number, input_file=None, no_hl=False, source=''): |
|
711 | 733 | filenode = None |
|
712 | 734 | filename = None |
|
713 | 735 | |
|
714 | 736 | if isinstance(input_file, str): |
|
715 | 737 | filename = input_file |
|
716 | 738 | elif isinstance(input_file, FileNode): |
|
717 | 739 | filenode = input_file |
|
718 | filename = input_file. | 

740 | filename = input_file.str_path | |
|
719 | 741 | |
|
720 | 742 | hl_mode = self.HL_NONE if no_hl else self.highlight_mode |
|
721 | 743 | if hl_mode == self.HL_REAL and filenode: |
|
722 | 744 | lexer = self._get_lexer_for_filename(filename) |
|
723 | file_size_allowed = | 

745 | file_size_allowed = filenode.size < self.max_file_size_limit | |
|
724 | 746 | if line_number and file_size_allowed: |
|
725 | return self.get_tokenized_filenode_line( | 

747 | return self.get_tokenized_filenode_line(filenode, line_number, lexer, source) | |
|
726 | 748 | |
|
727 | 749 | if hl_mode in (self.HL_REAL, self.HL_FAST) and filename: |
|
728 | 750 | lexer = self._get_lexer_for_filename(filename) |
|
729 | 751 | return list(tokenize_string(line_text, lexer)) |
|
730 | 752 | |
|
731 | 753 | return list(tokenize_string(line_text, plain_text_lexer)) |
|
732 | 754 | |
|
733 | 755 | def get_tokenized_filenode_line(self, filenode, line_number, lexer=None, source=''): |
|
756 | name_hash = hash(filenode) | |
|
734 | 757 | |
|
735 | def tokenize(_filenode): | |
|
736 | self.highlighted_filenodes[source][filenode] = filenode_as_lines_tokens(filenode, lexer) | |
|
758 | hl_node_code = self.highlighted_filenodes[source] | |
|
737 | 759 | |
|
738 | if filenode not in self.highlighted_filenodes[source]: | |
|
739 | tokenize(filenode) | |
|
760 | if name_hash not in hl_node_code: | |
|
761 | hl_node_code[name_hash] = filenode_as_lines_tokens(filenode, lexer) | |
|
740 | 762 | |
|
741 | 763 | try: |
|
742 | return | 

764 | return hl_node_code[name_hash][line_number - 1] | |
|
743 | 765 | except Exception: |
|
744 | log.exception('diff rendering error') | |
|
766 | log.exception('diff rendering error on L:%s and file=%s', line_number - 1, filenode.name) | |
|
745 | 767 | return [('', 'L{}: rhodecode diff rendering error'.format(line_number))] |
|
746 | 768 | |
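
get_tokenized_filenode_line() now keys its per-side cache on hash(filenode) and tokenizes a whole file at most once per side; subsequent lines are plain list lookups. The memoization shape, with string uppercasing standing in for the tokenization step:

    cache = {}

    def lines_for(node):
        key = hash(node)  # mirrors name_hash = hash(filenode) above
        if key not in cache:
            cache[key] = node.upper().splitlines()  # expensive step, done once
        return cache[key]

    print(lines_for('a\nb')[0])  # A  (tokenizes)
    print(lines_for('a\nb')[1])  # B  (cache hit)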
|
747 | 769 | def action_to_op(self, action): |
|
748 | 770 | return { |
|
749 | 771 | 'add': '+', |
|
750 | 772 | 'del': '-', |
|
751 | 773 | 'unmod': ' ', |
|
752 | 774 | 'unmod-no-hl': ' ', |
|
753 | 775 | 'old-no-nl': ' ', |
|
754 | 776 | 'new-no-nl': ' ', |
|
755 | 777 | }.get(action, action) |
|
756 | 778 | |
|
757 | 779 | def as_unified(self, lines): |
|
758 | 780 | """ |
|
759 | 781 | Return a generator that yields the lines of a diff in unified order |
|
760 | 782 | """ |
|
761 | 783 | def generator(): |
|
762 | 784 | buf = [] |
|
763 | 785 | for line in lines: |
|
764 | 786 | |
|
765 | 787 | if buf and not line.original or line.original.action == ' ': |
|
766 | 788 | for b in buf: |
|
767 | 789 | yield b |
|
768 | 790 | buf = [] |
|
769 | 791 | |
|
770 | 792 | if line.original: |
|
771 | 793 | if line.original.action == ' ': |
|
772 | 794 | yield (line.original.lineno, line.modified.lineno, |
|
773 | 795 | line.original.action, line.original.content, |
|
774 | 796 | line.original.get_comment_args) |
|
775 | 797 | continue |
|
776 | 798 | |
|
777 | 799 | if line.original.action == '-': |
|
778 | 800 | yield (line.original.lineno, None, |
|
779 | 801 | line.original.action, line.original.content, |
|
780 | 802 | line.original.get_comment_args) |
|
781 | 803 | |
|
782 | 804 | if line.modified.action == '+': |
|
783 | 805 | buf.append(( |
|
784 | 806 | None, line.modified.lineno, |
|
785 | 807 | line.modified.action, line.modified.content, |
|
786 | 808 | line.modified.get_comment_args)) |
|
787 | 809 | continue |
|
788 | 810 | |
|
789 | 811 | if line.modified: |
|
790 | 812 | yield (None, line.modified.lineno, |
|
791 | 813 | line.modified.action, line.modified.content, |
|
792 | 814 | line.modified.get_comment_args) |
|
793 | 815 | |
|
794 | 816 | for b in buf: |
|
795 | 817 | yield b |
|
796 | 818 | |
|
797 | 819 | return generator() |
@@ -1,679 +1,688 b'' | |||
|
1 | 1 | |
|
2 | 2 | # Copyright (C) 2010-2020 RhodeCode GmbH |
|
3 | 3 | # |
|
4 | 4 | # This program is free software: you can redistribute it and/or modify |
|
5 | 5 | # it under the terms of the GNU Affero General Public License, version 3 |
|
6 | 6 | # (only), as published by the Free Software Foundation. |
|
7 | 7 | # |
|
8 | 8 | # This program is distributed in the hope that it will be useful, |
|
9 | 9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of |
|
10 | 10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the |
|
11 | 11 | # GNU General Public License for more details. |
|
12 | 12 | # |
|
13 | 13 | # You should have received a copy of the GNU Affero General Public License |
|
14 | 14 | # along with this program. If not, see <http://www.gnu.org/licenses/>. |
|
15 | 15 | # |
|
16 | 16 | # This program is dual-licensed. If you wish to learn more about the |
|
17 | 17 | # RhodeCode Enterprise Edition, including its added features, Support services, |
|
18 | 18 | # and proprietary license terms, please see https://rhodecode.com/licenses/ |
|
19 | 19 | |
|
20 | 20 | """ |
|
21 | 21 | Database creation, and setup module for RhodeCode Enterprise. Used for creation |
|
22 | 22 | of database as well as for migration operations |
|
23 | 23 | """ |
|
24 | 24 | |
|
25 | 25 | import os |
|
26 | 26 | import sys |
|
27 | 27 | import time |
|
28 | 28 | import uuid |
|
29 | 29 | import logging |
|
30 | 30 | import getpass |
|
31 | 31 | from os.path import dirname as dn, join as jn |
|
32 | 32 | |
|
33 | 33 | from sqlalchemy.engine import create_engine |
|
34 | 34 | |
|
35 | 35 | from rhodecode import __dbversion__ |
|
36 | 36 | from rhodecode.model import init_model |
|
37 | 37 | from rhodecode.model.user import UserModel |
|
38 | 38 | from rhodecode.model.db import ( |
|
39 | 39 | User, Permission, RhodeCodeUi, RhodeCodeSetting, UserToPerm, |
|
40 | 40 | DbMigrateVersion, RepoGroup, UserRepoGroupToPerm, CacheKey, Repository) |
|
41 | 41 | from rhodecode.model.meta import Session, Base |
|
42 | 42 | from rhodecode.model.permission import PermissionModel |
|
43 | 43 | from rhodecode.model.repo import RepoModel |
|
44 | 44 | from rhodecode.model.repo_group import RepoGroupModel |
|
45 | 45 | from rhodecode.model.settings import SettingsModel |
|
46 | 46 | |
|
47 | 47 | |
|
48 | 48 | log = logging.getLogger(__name__) |
|
49 | 49 | |
|
50 | 50 | |
|
51 | 51 | def notify(msg): |
|
52 | 52 | """ |
|
53 | 53 | Notification for migrations messages |
|
54 | 54 | """ |
|
55 | 55 | ml = len(msg) + (4 * 2) |
|
56 | print(('\n%s\n*** %s ***\n%s' % ('*' * ml, msg, '*' * ml)).upper()) | |
|
56 | print((('\n%s\n*** %s ***\n%s' % ('*' * ml, msg, '*' * ml)).upper())) | |
|
57 | 57 | |
|
58 | 58 | |
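
notify() just frames and uppercases its message; the doubled parentheses in the new line look like a harmless 2to3 artifact. What the banner produces for a sample message:

    msg = 'upgrade done'
    ml = len(msg) + (4 * 2)
    print((('\n%s\n*** %s ***\n%s' % ('*' * ml, msg, '*' * ml)).upper()))
    # ********************
    # *** UPGRADE DONE ***
    # ********************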
|
59 | 59 | class DbManage(object): |
|
60 | 60 | |
|
61 | 61 | def __init__(self, log_sql, dbconf, root, tests=False, |
|
62 | SESSION=None, cli_args=None): | |
|
62 | SESSION=None, cli_args=None, enc_key=b''): | |
|
63 | ||
|
63 | 64 | self.dbname = dbconf.split('/')[-1] |
|
64 | 65 | self.tests = tests |
|
65 | 66 | self.root = root |
|
66 | 67 | self.dburi = dbconf |
|
67 | 68 | self.log_sql = log_sql |
|
68 | 69 | self.cli_args = cli_args or {} |
|
70 | self.sa = None | |
|
71 | self.engine = None | |
|
72 | self.enc_key = enc_key | |
|
73 | # sets .sa .engine | |
|
69 | 74 | self.init_db(SESSION=SESSION) |
|
75 | ||
|
70 | 76 | self.ask_ok = self.get_ask_ok_func(self.cli_args.get('force_ask')) |
|
71 | 77 | |
|
72 | 78 | def db_exists(self): |
|
73 | 79 | if not self.sa: |
|
74 | 80 | self.init_db() |
|
75 | 81 | try: |
|
76 | 82 | self.sa.query(RhodeCodeUi)\ |
|
77 | 83 | .filter(RhodeCodeUi.ui_key == '/')\ |
|
78 | 84 | .scalar() |
|
79 | 85 | return True |
|
80 | 86 | except Exception: |
|
81 | 87 | return False |
|
82 | 88 | finally: |
|
83 | 89 | self.sa.rollback() |
|
84 | 90 | |
|
85 | 91 | def get_ask_ok_func(self, param): |
|
86 | 92 | if param not in [None]: |
|
87 | 93 | # return a function lambda that has a default set to param |
|
88 | 94 | return lambda *args, **kwargs: param |
|
89 | 95 | else: |
|
90 | 96 | from rhodecode.lib.utils import ask_ok |
|
91 | 97 | return ask_ok |
|
92 | 98 | |
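
get_ask_ok_func() lets force-ask style flags short-circuit every confirmation: a non-None param becomes a constant-answer callable. A reduced sketch (the interactive fallback here is a simplified stand-in for rhodecode's ask_ok):

    def get_ask_ok_func(param):
        if param is not None:
            return lambda *args, **kwargs: param  # forced answer, no prompt
        return lambda question: input(question).lower() in ('y', 'yes')

    force_yes = get_ask_ok_func(True)
    print(force_yes('destroy the old database? [y/n]'))  # True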
|
93 | 99 | def init_db(self, SESSION=None): |
|
100 | ||
|
94 | 101 | if SESSION: |
|
95 | 102 | self.sa = SESSION |
|
103 | self.engine = SESSION.bind | |
|
96 | 104 | else: |
|
97 | 105 | # init new sessions |
|
98 | 106 | engine = create_engine(self.dburi, echo=self.log_sql) |
|
99 | init_model(engine) | |
|
107 | init_model(engine, encryption_key=self.enc_key) | |
|
100 | 108 | self.sa = Session() |
|
109 | self.engine = engine | |
|
101 | 110 | |
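
init_db() now records the engine in both branches: from SESSION.bind when a session is injected (tests), or from the freshly created engine otherwise. A minimal sketch of that wiring with plain SQLAlchemy and an in-memory SQLite URL (init_model and encryption_key are RhodeCode specifics, omitted here):

    from sqlalchemy import create_engine
    from sqlalchemy.orm import sessionmaker

    engine = create_engine('sqlite://', echo=False)
    sa = sessionmaker(bind=engine)()
    print(sa.bind is engine)  # True - the engine is recoverable from the session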
|
102 | 111 | def create_tables(self, override=False): |
|
103 | 112 | """ |
|
104 | 113 | Create a auth database |
|
105 | 114 | """ |
|
106 | 115 | |
|
107 | 116 | log.info("Existing database with the same name is going to be destroyed.") |
|
108 | 117 | log.info("Setup command will run DROP ALL command on that database.") |
|
118 | engine = self.engine | |
|
119 | ||
|
109 | 120 | if self.tests: |
|
110 | 121 | destroy = True |
|
111 | 122 | else: |
|
112 | 123 | destroy = self.ask_ok('Are you sure that you want to destroy the old database? [y/n]') |
|
113 | 124 | if not destroy: |
|
114 | 125 | log.info('db tables bootstrap: Nothing done.') |
|
115 | 126 | sys.exit(0) |
|
116 | 127 | if destroy: |
|
117 | Base.metadata.drop_all() | |
|
128 | Base.metadata.drop_all(bind=engine) | |
|
118 | 129 | |
|
119 | 130 | checkfirst = not override |
|
120 | Base.metadata.create_all(checkfirst=checkfirst) | |
|
131 | Base.metadata.create_all(bind=engine, checkfirst=checkfirst) | |
|
121 | 132 | log.info('Created tables for %s', self.dbname) |
|
122 | 133 | |
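
The create_tables() change passes bind=engine explicitly instead of relying on metadata-level binding, which newer SQLAlchemy removed. The equivalent calls against a throwaway in-memory database:

    from sqlalchemy import create_engine
    from sqlalchemy.orm import declarative_base

    Base = declarative_base()
    engine = create_engine('sqlite://')
    Base.metadata.drop_all(bind=engine)                      # the DROP ALL step
    Base.metadata.create_all(bind=engine, checkfirst=True)   # checkfirst = not override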
|
123 | 134 | def set_db_version(self): |
|
124 | 135 | ver = DbMigrateVersion() |
|
125 | 136 | ver.version = __dbversion__ |
|
126 | 137 | ver.repository_id = 'rhodecode_db_migrations' |
|
127 | 138 | ver.repository_path = 'versions' |
|
128 | 139 | self.sa.add(ver) |
|
129 | 140 | log.info('db version set to: %s', __dbversion__) |
|
130 | 141 | |
|
131 | 142 | def run_post_migration_tasks(self): |
|
132 | 143 | """ |
|
133 | 144 | Run various tasks before actually doing migrations |
|
134 | 145 | """ |
|
135 | 146 | # delete cache keys on each upgrade |
|
136 | 147 | total = CacheKey.query().count() |
|
137 | 148 | log.info("Deleting (%s) cache keys now...", total) |
|
138 | 149 | CacheKey.delete_all_cache() |
|
139 | 150 | |
|
140 | 151 | def upgrade(self, version=None): |
|
141 | 152 | """ |
|
142 | 153 | Upgrades given database schema to given revision following |
|
143 | 154 | all needed steps, to perform the upgrade |
|
144 | 155 | |
|
145 | 156 | """ |
|
146 | 157 | |
|
147 | 158 | from rhodecode.lib.dbmigrate.migrate.versioning import api |
|
148 | from rhodecode.lib.dbmigrate.migrate.exceptions import | 

149 | DatabaseNotControlledError | |
|
159 | from rhodecode.lib.dbmigrate.migrate.exceptions import DatabaseNotControlledError | |
|
150 | 160 | |
|
151 | 161 | if 'sqlite' in self.dburi: |
|
152 | 162 | print( |
|
153 | 163 | '********************** WARNING **********************\n' |
|
154 | 164 | 'Make sure your version of sqlite is at least 3.7.X. \n' |
|
155 | 165 | 'Earlier versions are known to fail on some migrations\n' |
|
156 | 166 | '*****************************************************\n') |
|
157 | 167 | |
|
158 | 168 | upgrade = self.ask_ok( |
|
159 | 169 | 'You are about to perform a database upgrade. Make ' |
|
160 | 170 | 'sure you have backed up your database. ' |
|
161 | 171 | 'Continue ? [y/n]') |
|
162 | 172 | if not upgrade: |
|
163 | 173 | log.info('No upgrade performed') |
|
164 | 174 | sys.exit(0) |
|
165 | 175 | |
|
166 | 176 | repository_path = jn(dn(dn(dn(os.path.realpath(__file__)))), |
|
167 | 177 | 'rhodecode/lib/dbmigrate') |
|
168 | 178 | db_uri = self.dburi |
|
169 | 179 | |
|
170 | 180 | if version: |
|
171 | 181 | DbMigrateVersion.set_version(version) |
|
172 | 182 | |
|
173 | 183 | try: |
|
174 | 184 | curr_version = api.db_version(db_uri, repository_path) |
|
175 | msg = ('Found current database db_uri under version ' | |
|
176 | 'control with version {}'.format(curr_version)) | 

185 | msg = (f'Found current database db_uri under version ' | |
|
186 | f'control with version {curr_version}') | |
|
177 | 187 | |
|
178 | 188 | except (RuntimeError, DatabaseNotControlledError): |
|
179 | 189 | curr_version = 1 |
|
180 | msg = ('Current database is not under version control. ' | 

181 | 'Setting as version %s' % curr_version) | 

190 | msg = f'Current database is not under version control. ' \ | |
|
191 | f'Setting as version {curr_version}' | |
|
182 | 192 | api.version_control(db_uri, repository_path, curr_version) |
|
183 | 193 | |
|
184 | 194 | notify(msg) |
|
185 | 195 | |
|
186 | ||
|
187 | 196 | if curr_version == __dbversion__: |
|
188 | 197 | log.info('This database is already at the newest version') |
|
189 | 198 | sys.exit(0) |
|
190 | 199 | |
|
191 | upgrade_steps = range(curr_version + 1, __dbversion__ + 1) | |
|
192 | notify('attempting to upgrade database from ' | |
|
193 | 'version %s to version %s' % (curr_version, __dbversion__)) | 

200 | upgrade_steps = list(range(curr_version + 1, __dbversion__ + 1)) | |
|
201 | notify(f'attempting to upgrade database from ' | |
|
202 | f'version {curr_version} to version {__dbversion__}') | |
|
194 | 203 | |
|
195 | 204 | # CALL THE PROPER ORDER OF STEPS TO PERFORM FULL UPGRADE |
|
196 | 205 | _step = None |
|
197 | 206 | for step in upgrade_steps: |
|
198 | notify('performing upgrade step %s' % step) | 

207 | notify(f'performing upgrade step {step}') | |
|
199 | 208 | time.sleep(0.5) |
|
200 | 209 | |
|
201 | 210 | api.upgrade(db_uri, repository_path, step) |
|
202 | 211 | self.sa.rollback() |
|
203 | notify('schema upgrade for step %s completed' % step) | 

212 | notify(f'schema upgrade for step {step} completed') | |
|
204 | 213 | |
|
205 | 214 | _step = step |
|
206 | 215 | |
|
207 | 216 | self.run_post_migration_tasks() |
|
208 | notify('upgrade to version %s successful' % step) | 

217 | notify(f'upgrade to version {step} successful') | |
|
209 | 218 | |
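
upgrade() walks every schema revision between the current and target versions, one api.upgrade() call per step, with notify() banners in between. The step arithmetic on its own (sample version numbers):

    curr_version, target_version = 2, 5
    upgrade_steps = list(range(curr_version + 1, target_version + 1))
    for step in upgrade_steps:
        print(f'performing upgrade step {step}')  # 3, 4, 5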
|
210 | 219 | def fix_repo_paths(self): |
|
211 | 220 | """ |
|
212 | 221 | Fixes an old RhodeCode version path into new one without a '*' |
|
213 | 222 | """ |
|
214 | 223 | |
|
215 | 224 | paths = self.sa.query(RhodeCodeUi)\ |
|
216 | 225 | .filter(RhodeCodeUi.ui_key == '/')\ |
|
217 | 226 | .scalar() |
|
218 | 227 | |
|
219 | 228 | paths.ui_value = paths.ui_value.replace('*', '') |
|
220 | 229 | |
|
221 | 230 | try: |
|
222 | 231 | self.sa.add(paths) |
|
223 | 232 | self.sa.commit() |
|
224 | 233 | except Exception: |
|
225 | 234 | self.sa.rollback() |
|
226 | 235 | raise |
|
227 | 236 | |
|
228 | 237 | def fix_default_user(self): |
|
229 | 238 | """ |
|
230 | 239 | Fixes an old default user with some 'nicer' default values, |
|
231 | 240 | used mostly for anonymous access |
|
232 | 241 | """ |
|
233 | 242 | def_user = self.sa.query(User)\ |
|
234 | .filter(User.username == User.DEFAULT_USER)\ | 

235 | .one() | 

243 | .filter(User.username == User.DEFAULT_USER)\ | |
|
244 | .one() | |
|
236 | 245 | |
|
237 | 246 | def_user.name = 'Anonymous' |
|
238 | 247 | def_user.lastname = 'User' |
|
239 | 248 | def_user.email = User.DEFAULT_USER_EMAIL |
|
240 | 249 | |
|
241 | 250 | try: |
|
242 | 251 | self.sa.add(def_user) |
|
243 | 252 | self.sa.commit() |
|
244 | 253 | except Exception: |
|
245 | 254 | self.sa.rollback() |
|
246 | 255 | raise |
|
247 | 256 | |
|
248 | 257 | def fix_settings(self): |
|
249 | 258 | """ |
|
250 | 259 | Fixes rhodecode settings and adds ga_code key for google analytics |
|
251 | 260 | """ |
|
252 | 261 | |
|
253 | 262 | hgsettings3 = RhodeCodeSetting('ga_code', '') |
|
254 | 263 | |
|
255 | 264 | try: |
|
256 | 265 | self.sa.add(hgsettings3) |
|
257 | 266 | self.sa.commit() |
|
258 | 267 | except Exception: |
|
259 | 268 | self.sa.rollback() |
|
260 | 269 | raise |
|
261 | 270 | |
|
262 | 271 | def create_admin_and_prompt(self): |
|
263 | 272 | |
|
264 | 273 | # defaults |
|
265 | 274 | defaults = self.cli_args |
|
266 | 275 | username = defaults.get('username') |
|
267 | 276 | password = defaults.get('password') |
|
268 | 277 | email = defaults.get('email') |
|
269 | 278 | |
|
270 | 279 | if username is None: |
|
271 | 280 | username = eval(input('Specify admin username:')) |
|
272 | 281 | if password is None: |
|
273 | 282 | password = self._get_admin_password() |
|
274 | 283 | if not password: |
|
275 | 284 | # second try |
|
276 | 285 | password = self._get_admin_password() |
|
277 | 286 | if not password: |
|
278 | 287 | sys.exit() |
|
279 | 288 | if email is None: |
|
280 | 289 | email = eval(input('Specify admin email:')) |
|
281 | 290 | api_key = self.cli_args.get('api_key') |
|
282 | 291 | self.create_user(username, password, email, True, |
|
283 | 292 | strict_creation_check=False, |
|
284 | 293 | api_key=api_key) |
|
285 | 294 | |
|
286 | 295 | def _get_admin_password(self): |
|
287 | 296 | password = getpass.getpass('Specify admin password ' |
|
288 | 297 | '(min 6 chars):') |
|
289 | 298 | confirm = getpass.getpass('Confirm password:') |
|
290 | 299 | |
|
291 | 300 | if password != confirm: |
|
292 | 301 | log.error('passwords mismatch') |
|
293 | 302 | return False |
|
294 | 303 | if len(password) < 6: |
|
295 | 304 | log.error('password is too short - use at least 6 characters') |
|
296 | 305 | return False |
|
297 | 306 | |
|
298 | 307 | return password |
|
299 | 308 | |
|
300 | 309 | def create_test_admin_and_users(self): |
|
301 | 310 | log.info('creating admin and regular test users') |
|
302 | 311 | from rhodecode.tests import TEST_USER_ADMIN_LOGIN, \ |
|
303 | 312 | TEST_USER_ADMIN_PASS, TEST_USER_ADMIN_EMAIL, \ |
|
304 | 313 | TEST_USER_REGULAR_LOGIN, TEST_USER_REGULAR_PASS, \ |
|
305 | 314 | TEST_USER_REGULAR_EMAIL, TEST_USER_REGULAR2_LOGIN, \ |
|
306 | 315 | TEST_USER_REGULAR2_PASS, TEST_USER_REGULAR2_EMAIL |
|
307 | 316 | |
|
308 | 317 | self.create_user(TEST_USER_ADMIN_LOGIN, TEST_USER_ADMIN_PASS, |
|
309 | 318 | TEST_USER_ADMIN_EMAIL, True, api_key=True) |
|
310 | 319 | |
|
311 | 320 | self.create_user(TEST_USER_REGULAR_LOGIN, TEST_USER_REGULAR_PASS, |
|
312 | 321 | TEST_USER_REGULAR_EMAIL, False, api_key=True) |
|
313 | 322 | |
|
314 | 323 | self.create_user(TEST_USER_REGULAR2_LOGIN, TEST_USER_REGULAR2_PASS, |
|
315 | 324 | TEST_USER_REGULAR2_EMAIL, False, api_key=True) |
|
316 | 325 | |
|
317 | 326 | def create_ui_settings(self, repo_store_path): |
|
318 | 327 | """ |
|
319 | 328 | Creates ui settings, fills out hooks |
|
320 | 329 | and disables dotencode |
|
321 | 330 | """ |
|
322 | 331 | settings_model = SettingsModel(sa=self.sa) |
|
323 | 332 | from rhodecode.lib.vcs.backends.hg import largefiles_store |
|
324 | 333 | from rhodecode.lib.vcs.backends.git import lfs_store |
|
325 | 334 | |
|
326 | 335 | # Build HOOKS |
|
327 | 336 | hooks = [ |
|
328 | 337 | (RhodeCodeUi.HOOK_REPO_SIZE, 'python:vcsserver.hooks.repo_size'), |
|
329 | 338 | |
|
330 | 339 | # HG |
|
331 | 340 | (RhodeCodeUi.HOOK_PRE_PULL, 'python:vcsserver.hooks.pre_pull'), |
|
332 | 341 | (RhodeCodeUi.HOOK_PULL, 'python:vcsserver.hooks.log_pull_action'), |
|
333 | 342 | (RhodeCodeUi.HOOK_PRE_PUSH, 'python:vcsserver.hooks.pre_push'), |
|
334 | 343 | (RhodeCodeUi.HOOK_PRETX_PUSH, 'python:vcsserver.hooks.pre_push'), |
|
335 | 344 | (RhodeCodeUi.HOOK_PUSH, 'python:vcsserver.hooks.log_push_action'), |
|
336 | 345 | (RhodeCodeUi.HOOK_PUSH_KEY, 'python:vcsserver.hooks.key_push'), |
|
337 | 346 | |
|
338 | 347 | ] |
|
339 | 348 | |
|
340 | 349 | for key, value in hooks: |
|
341 | 350 | hook_obj = settings_model.get_ui_by_key(key) |
|
342 | 351 | hooks2 = hook_obj if hook_obj else RhodeCodeUi() |
|
343 | 352 | hooks2.ui_section = 'hooks' |
|
344 | 353 | hooks2.ui_key = key |
|
345 | 354 | hooks2.ui_value = value |
|
346 | 355 | self.sa.add(hooks2) |
|
347 | 356 | |
|
348 | 357 | # enable largefiles |
|
349 | 358 | largefiles = RhodeCodeUi() |
|
350 | 359 | largefiles.ui_section = 'extensions' |
|
351 | 360 | largefiles.ui_key = 'largefiles' |
|
352 | 361 | largefiles.ui_value = '' |
|
353 | 362 | self.sa.add(largefiles) |
|
354 | 363 | |
|
355 | 364 | # set default largefiles cache dir, defaults to |
|
356 | 365 | # /repo_store_location/.cache/largefiles |
|
357 | 366 | largefiles = RhodeCodeUi() |
|
358 | 367 | largefiles.ui_section = 'largefiles' |
|
359 | 368 | largefiles.ui_key = 'usercache' |
|
360 | 369 | largefiles.ui_value = largefiles_store(repo_store_path) |
|
361 | 370 | |
|
362 | 371 | self.sa.add(largefiles) |
|
363 | 372 | |
|
364 | 373 | # set default lfs cache dir, defaults to |
|
365 | 374 | # /repo_store_location/.cache/lfs_store |
|
366 | 375 | lfsstore = RhodeCodeUi() |
|
367 | 376 | lfsstore.ui_section = 'vcs_git_lfs' |
|
368 | 377 | lfsstore.ui_key = 'store_location' |
|
369 | 378 | lfsstore.ui_value = lfs_store(repo_store_path) |
|
370 | 379 | |
|
371 | 380 | self.sa.add(lfsstore) |
|
372 | 381 | |
|
373 | 382 | # enable hgsubversion disabled by default |
|
374 | 383 | hgsubversion = RhodeCodeUi() |
|
375 | 384 | hgsubversion.ui_section = 'extensions' |
|
376 | 385 | hgsubversion.ui_key = 'hgsubversion' |
|
377 | 386 | hgsubversion.ui_value = '' |
|
378 | 387 | hgsubversion.ui_active = False |
|
379 | 388 | self.sa.add(hgsubversion) |
|
380 | 389 | |
|
381 | 390 | # enable hgevolve disabled by default |
|
382 | 391 | hgevolve = RhodeCodeUi() |
|
383 | 392 | hgevolve.ui_section = 'extensions' |
|
384 | 393 | hgevolve.ui_key = 'evolve' |
|
385 | 394 | hgevolve.ui_value = '' |
|
386 | 395 | hgevolve.ui_active = False |
|
387 | 396 | self.sa.add(hgevolve) |
|
388 | 397 | |
|
389 | 398 | hgevolve = RhodeCodeUi() |
|
390 | 399 | hgevolve.ui_section = 'experimental' |
|
391 | 400 | hgevolve.ui_key = 'evolution' |
|
392 | 401 | hgevolve.ui_value = '' |
|
393 | 402 | hgevolve.ui_active = False |
|
394 | 403 | self.sa.add(hgevolve) |
|
395 | 404 | |
|
396 | 405 | hgevolve = RhodeCodeUi() |
|
397 | 406 | hgevolve.ui_section = 'experimental' |
|
398 | 407 | hgevolve.ui_key = 'evolution.exchange' |
|
399 | 408 | hgevolve.ui_value = '' |
|
400 | 409 | hgevolve.ui_active = False |
|
401 | 410 | self.sa.add(hgevolve) |
|
402 | 411 | |
|
403 | 412 | hgevolve = RhodeCodeUi() |
|
404 | 413 | hgevolve.ui_section = 'extensions' |
|
405 | 414 | hgevolve.ui_key = 'topic' |
|
406 | 415 | hgevolve.ui_value = '' |
|
407 | 416 | hgevolve.ui_active = False |
|
408 | 417 | self.sa.add(hgevolve) |
|
409 | 418 | |
|
410 | 419 | # enable hggit disabled by default |
|
411 | 420 | hggit = RhodeCodeUi() |
|
412 | 421 | hggit.ui_section = 'extensions' |
|
413 | 422 | hggit.ui_key = 'hggit' |
|
414 | 423 | hggit.ui_value = '' |
|
415 | 424 | hggit.ui_active = False |
|
416 | 425 | self.sa.add(hggit) |
|
417 | 426 | |
|
418 | 427 | # set svn branch defaults |
|
419 | 428 | branches = ["/branches/*", "/trunk"] |
|
420 | 429 | tags = ["/tags/*"] |
|
421 | 430 | |
|
422 | 431 | for branch in branches: |
|
423 | 432 | settings_model.create_ui_section_value( |
|
424 | 433 | RhodeCodeUi.SVN_BRANCH_ID, branch) |
|
425 | 434 | |
|
426 | 435 | for tag in tags: |
|
427 | 436 | settings_model.create_ui_section_value(RhodeCodeUi.SVN_TAG_ID, tag) |
|
428 | 437 | |
|
429 | 438 | def create_auth_plugin_options(self, skip_existing=False): |
|
430 | 439 | """ |
|
431 | 440 | Create default auth plugin settings, and make them active 
|
432 | 441 | |
|
433 | 442 | :param skip_existing: |
|
434 | 443 | """ |
|
435 | 444 | defaults = [ |
|
436 | 445 | ('auth_plugins', |
|
437 | 446 | 'egg:rhodecode-enterprise-ce#token,egg:rhodecode-enterprise-ce#rhodecode', |
|
438 | 447 | 'list'), |
|
439 | 448 | |
|
440 | 449 | ('auth_authtoken_enabled', |
|
441 | 450 | 'True', |
|
442 | 451 | 'bool'), |
|
443 | 452 | |
|
444 | 453 | ('auth_rhodecode_enabled', |
|
445 | 454 | 'True', |
|
446 | 455 | 'bool'), |
|
447 | 456 | ] |
|
448 | 457 | for k, v, t in defaults: |
|
449 | 458 | if (skip_existing and |
|
450 | 459 | SettingsModel().get_setting_by_name(k) is not None): |
|
451 | 460 | log.debug('Skipping option %s', k) |
|
452 | 461 | continue |
|
453 | 462 | setting = RhodeCodeSetting(k, v, t) |
|
454 | 463 | self.sa.add(setting) |
|
455 | 464 | |
|
456 | 465 | def create_default_options(self, skip_existing=False): |
|
457 | 466 | """Creates default settings""" |
|
458 | 467 | |
|
459 | 468 | for k, v, t in [ |
|
460 | 469 | ('default_repo_enable_locking', False, 'bool'), |
|
461 | 470 | ('default_repo_enable_downloads', False, 'bool'), |
|
462 | 471 | ('default_repo_enable_statistics', False, 'bool'), |
|
463 | 472 | ('default_repo_private', False, 'bool'), |
|
464 | 473 | ('default_repo_type', 'hg', 'unicode')]: |
|
465 | 474 | |
|
466 | 475 | if (skip_existing and |
|
467 | 476 | SettingsModel().get_setting_by_name(k) is not None): |
|
468 | 477 | log.debug('Skipping option %s', k) |
|
469 | 478 | continue |
|
470 | 479 | setting = RhodeCodeSetting(k, v, t) |
|
471 | 480 | self.sa.add(setting) |
|
472 | 481 | |
|
473 | 482 | def fixup_groups(self): |
|
474 | 483 | def_usr = User.get_default_user() |
|
475 | 484 | for g in RepoGroup.query().all(): |
|
476 | 485 | g.group_name = g.get_new_name(g.name) |
|
477 | 486 | self.sa.add(g) |
|
478 | 487 | # get default perm |
|
479 | 488 | default = UserRepoGroupToPerm.query()\ |
|
480 | 489 | .filter(UserRepoGroupToPerm.group == g)\ |
|
481 | 490 | .filter(UserRepoGroupToPerm.user == def_usr)\ |
|
482 | 491 | .scalar() |
|
483 | 492 | |
|
484 | 493 | if default is None: |
|
485 | 494 | log.debug('missing default permission for group %s, adding', g) 
|
486 | 495 | perm_obj = RepoGroupModel()._create_default_perms(g) |
|
487 | 496 | self.sa.add(perm_obj) |
|
488 | 497 | |
|
489 | 498 | def reset_permissions(self, username): |
|
490 | 499 | """ |
|
491 | 500 | Resets permissions to the default state. Useful when old systems had 

492 | 501 | bad permissions that must be cleaned up 
|
493 | 502 | |
|
494 | 503 | :param username: |
|
495 | 504 | """ |
|
496 | 505 | default_user = User.get_by_username(username) |
|
497 | 506 | if not default_user: |
|
498 | 507 | return |
|
499 | 508 | |
|
500 | 509 | u2p = UserToPerm.query()\ |
|
501 | 510 | .filter(UserToPerm.user == default_user).all() |
|
502 | 511 | fixed = False |
|
503 | 512 | if len(u2p) != len(Permission.DEFAULT_USER_PERMISSIONS): |
|
504 | 513 | for p in u2p: |
|
505 | 514 | Session().delete(p) |
|
506 | 515 | fixed = True |
|
507 | 516 | self.populate_default_permissions() |
|
508 | 517 | return fixed |
|
509 | 518 | |
|
510 | 519 | def config_prompt(self, test_repo_path='', retries=3): |
|
511 | 520 | defaults = self.cli_args |
|
512 | 521 | _path = defaults.get('repos_location') |
|
513 | 522 | if retries == 3: |
|
514 | 523 | log.info('Setting up repositories config') |
|
515 | 524 | |
|
516 | 525 | if _path is not None: |
|
517 | 526 | path = _path |
|
518 | 527 | elif not self.tests and not test_repo_path: |
|
519 | 528 | path = input( 
|
520 | 529 | 'Enter a valid absolute path to store repositories. ' |
|
521 | 530 | 'All repositories in that path will be added automatically:' |
|
522 | 531 | ) 
|
523 | 532 | else: |
|
524 | 533 | path = test_repo_path |
|
525 | 534 | path_ok = True |
|
526 | 535 | |
|
527 | 536 | # check proper dir |
|
528 | 537 | if not os.path.isdir(path): |
|
529 | 538 | path_ok = False |
|
530 | 539 | log.error('Given path %s is not a valid directory', path) |
|
531 | 540 | |
|
532 | 541 | elif not os.path.isabs(path): |
|
533 | 542 | path_ok = False |
|
534 | 543 | log.error('Given path %s is not an absolute path', path) |
|
535 | 544 | |
|
536 | 545 | # check if path is at least readable. |
|
537 | 546 | if not os.access(path, os.R_OK): |
|
538 | 547 | path_ok = False |
|
539 | 548 | log.error('Given path %s is not readable', path) |
|
540 | 549 | |
|
541 | 550 | # check write access, warn user about non writeable paths |
|
542 | 551 | elif not os.access(path, os.W_OK) and path_ok: |
|
543 | 552 | log.warning('No write permission to given path %s', path) |
|
544 | 553 | |
|
545 | q = ('Given path %s is not writeable, do you want to ' | 

546 | 'continue with read only mode ? [y/n]' % (path,)) | 
|
554 | q = (f'Given path {path} is not writeable, do you want to ' | |
|
555 | f'continue with read only mode ? [y/n]') | |
|
547 | 556 | if not self.ask_ok(q): |
|
548 | 557 | log.error('Canceled by user') |
|
549 | 558 | sys.exit(-1) |
|
550 | 559 | |
|
551 | 560 | if retries == 0: |
|
552 | 561 | sys.exit('max retries reached') |
|
553 | 562 | if not path_ok: |
|
554 | 563 | retries -= 1 |
|
555 | 564 | return self.config_prompt(test_repo_path, retries) |
|
556 | 565 | |
|
557 | 566 | real_path = os.path.normpath(os.path.realpath(path)) |
|
558 | 567 | |
|
559 | 568 | if real_path != os.path.normpath(path): |
|
560 | q = ('Path looks like a symlink, RhodeCode Enterprise will store ' | |
|
561 | 'given path as %s ? [y/n]' % (real_path,)) | 
|
569 | q = (f'Path looks like a symlink, RhodeCode Enterprise will store ' | |
|
570 | f'given path as {real_path} ? [y/n]') | |
|
562 | 571 | if not self.ask_ok(q): |
|
563 | 572 | log.error('Canceled by user') |
|
564 | 573 | sys.exit(-1) |
|
565 | 574 | |
|
566 | 575 | return real_path |
|
567 | 576 | |
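The symlink check above works because os.path.realpath resolves symlinks while os.path.normpath only cleans up separators, so comparing the two detects a symlinked repository store. A minimal sketch, with illustrative paths:

    import os

    path = '/srv/repos/'                             # illustrative input
    normalized = os.path.normpath(path)              # '/srv/repos'
    real = os.path.normpath(os.path.realpath(path))  # e.g. '/data/repos' if symlinked
    if real != normalized:
        # this is the case where config_prompt() asks the user to confirm
        # storing the resolved path instead of the symlink
        pass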
|
568 | 577 | def create_settings(self, path): |
|
569 | 578 | |
|
570 | 579 | self.create_ui_settings(path) |
|
571 | 580 | |
|
572 | 581 | ui_config = [ |
|
573 | 582 | ('web', 'push_ssl', 'False'), |
|
574 | 583 | ('web', 'allow_archive', 'gz zip bz2'), |
|
575 | 584 | ('web', 'allow_push', '*'), |
|
576 | 585 | ('web', 'baseurl', '/'), |
|
577 | 586 | ('paths', '/', path), |
|
578 | 587 | ('phases', 'publish', 'True') |
|
579 | 588 | ] |
|
580 | 589 | for section, key, value in ui_config: |
|
581 | 590 | ui_conf = RhodeCodeUi() |
|
582 | 591 | setattr(ui_conf, 'ui_section', section) |
|
583 | 592 | setattr(ui_conf, 'ui_key', key) |
|
584 | 593 | setattr(ui_conf, 'ui_value', value) |
|
585 | 594 | self.sa.add(ui_conf) |
|
586 | 595 | |
|
587 | 596 | # rhodecode app settings |
|
588 | 597 | settings = [ |
|
589 | 598 | ('realm', 'RhodeCode', 'unicode'), |
|
590 | 599 | ('title', '', 'unicode'), |
|
591 | 600 | ('pre_code', '', 'unicode'), |
|
592 | 601 | ('post_code', '', 'unicode'), |
|
593 | 602 | |
|
594 | 603 | # Visual |
|
595 | 604 | ('show_public_icon', True, 'bool'), |
|
596 | 605 | ('show_private_icon', True, 'bool'), |
|
597 | 606 | ('stylify_metatags', True, 'bool'), |
|
598 | 607 | ('dashboard_items', 100, 'int'), |
|
599 | 608 | ('admin_grid_items', 25, 'int'), |
|
600 | 609 | |
|
601 | 610 | ('markup_renderer', 'markdown', 'unicode'), |
|
602 | 611 | |
|
603 | 612 | ('repository_fields', True, 'bool'), |
|
604 | 613 | ('show_version', True, 'bool'), |
|
605 | 614 | ('show_revision_number', True, 'bool'), |
|
606 | 615 | ('show_sha_length', 12, 'int'), |
|
607 | 616 | |
|
608 | 617 | ('use_gravatar', False, 'bool'), |
|
609 | 618 | ('gravatar_url', User.DEFAULT_GRAVATAR_URL, 'unicode'), |
|
610 | 619 | |
|
611 | 620 | ('clone_uri_tmpl', Repository.DEFAULT_CLONE_URI, 'unicode'), |
|
612 | 621 | ('clone_uri_id_tmpl', Repository.DEFAULT_CLONE_URI_ID, 'unicode'), |
|
613 | 622 | ('clone_uri_ssh_tmpl', Repository.DEFAULT_CLONE_URI_SSH, 'unicode'), |
|
614 | 623 | ('support_url', '', 'unicode'), |
|
615 | 624 | ('update_url', RhodeCodeSetting.DEFAULT_UPDATE_URL, 'unicode'), |
|
616 | 625 | |
|
617 | 626 | # VCS Settings |
|
618 | 627 | ('pr_merge_enabled', True, 'bool'), |
|
619 | 628 | ('use_outdated_comments', True, 'bool'), |
|
620 | 629 | ('diff_cache', True, 'bool'), |
|
621 | 630 | ] |
|
622 | 631 | |
|
623 | 632 | for key, val, type_ in settings: |
|
624 | 633 | sett = RhodeCodeSetting(key, val, type_) |
|
625 | 634 | self.sa.add(sett) |
|
626 | 635 | |
|
627 | 636 | self.create_auth_plugin_options() |
|
628 | 637 | self.create_default_options() |
|
629 | 638 | |
|
630 | 639 | log.info('created ui config') |
|
631 | 640 | |
|
632 | 641 | def create_user(self, username, password, email='', admin=False, |
|
633 | 642 | strict_creation_check=True, api_key=None): |
|
634 | 643 | log.info('creating user `%s`', username) |
|
635 | 644 | user = UserModel().create_or_update( |
|
636 | 645 | username, password, email, firstname='RhodeCode', lastname='Admin', |
|
637 | 646 | active=True, admin=admin, extern_type="rhodecode", |
|
638 | 647 | strict_creation_check=strict_creation_check) |
|
639 | 648 | |
|
640 | 649 | if api_key: |
|
641 | 650 | log.info('setting a new default auth token for user `%s`', username) |
|
642 | 651 | UserModel().add_auth_token( |
|
643 | 652 | user=user, lifetime_minutes=-1, |
|
644 | 653 | role=UserModel.auth_token_role.ROLE_ALL, |
|
645 | 654 | description='BUILTIN TOKEN') |
|
646 | 655 | |
|
647 | 656 | def create_default_user(self): |
|
648 | 657 | log.info('creating default user') |
|
649 | 658 | # create default user for handling default permissions. |
|
650 | 659 | user = UserModel().create_or_update(username=User.DEFAULT_USER, |
|
651 | 660 | password=str(uuid.uuid1())[:20], |
|
652 | 661 | email=User.DEFAULT_USER_EMAIL, |
|
653 | 662 | firstname='Anonymous', |
|
654 | 663 | lastname='User', |
|
655 | 664 | strict_creation_check=False) |
|
656 | 665 | # based on configuration options activate/de-activate this user which |
|
657 | 666 | # controls anonymous access |
|
658 | 667 | if self.cli_args.get('public_access') is False: |
|
659 | 668 | log.info('Public access disabled') |
|
660 | 669 | user.active = False |
|
661 | 670 | Session().add(user) |
|
662 | 671 | Session().commit() |
|
663 | 672 | |
|
664 | 673 | def create_permissions(self): |
|
665 | 674 | """ |
|
666 | 675 | Creates all permissions defined in the system |
|
667 | 676 | """ |
|
668 | 677 | # module.(access|create|change|delete)_[name] |
|
669 | 678 | # module.(none|read|write|admin) |
|
670 | 679 | log.info('creating permissions') |
|
671 | 680 | PermissionModel(self.sa).create_permissions() |
|
672 | 681 | |
|
673 | 682 | def populate_default_permissions(self): |
|
674 | 683 | """ |
|
675 | 684 | Populate default permissions. It will create only the default |
|
676 | 685 | permissions that are missing, and not alter already defined ones |
|
677 | 686 | """ |
|
678 | 687 | log.info('creating default user permissions') |
|
679 | 688 | PermissionModel(self.sa).create_default_user_permissions(user=User.DEFAULT_USER) |
@@ -1,182 +1,202 b'' | |||
|
1 | 1 | |
|
2 | 2 | # Copyright (C) 2010-2020 RhodeCode GmbH |
|
3 | 3 | # |
|
4 | 4 | # This program is free software: you can redistribute it and/or modify |
|
5 | 5 | # it under the terms of the GNU Affero General Public License, version 3 |
|
6 | 6 | # (only), as published by the Free Software Foundation. |
|
7 | 7 | # |
|
8 | 8 | # This program is distributed in the hope that it will be useful, |
|
9 | 9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of |
|
10 | 10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the |
|
11 | 11 | # GNU General Public License for more details. |
|
12 | 12 | # |
|
13 | 13 | # You should have received a copy of the GNU Affero General Public License |
|
14 | 14 | # along with this program. If not, see <http://www.gnu.org/licenses/>. |
|
15 | 15 | # |
|
16 | 16 | # This program is dual-licensed. If you wish to learn more about the |
|
17 | 17 | # RhodeCode Enterprise Edition, including its added features, Support services, |
|
18 | 18 | # and proprietary license terms, please see https://rhodecode.com/licenses/ |
|
19 | 19 | |
|
20 | 20 | """ |
|
21 | 21 | Set of custom exceptions used in RhodeCode |
|
22 | 22 | """ |
|
23 | 23 | |
|
24 | 24 | from webob.exc import HTTPClientError |
|
25 | 25 | from pyramid.httpexceptions import HTTPBadGateway |
|
26 | 26 | |
|
27 | 27 | |
|
28 | 28 | class LdapUsernameError(Exception): |
|
29 | 29 | pass |
|
30 | 30 | |
|
31 | 31 | |
|
32 | 32 | class LdapPasswordError(Exception): |
|
33 | 33 | pass |
|
34 | 34 | |
|
35 | 35 | |
|
36 | 36 | class LdapConnectionError(Exception): |
|
37 | 37 | pass |
|
38 | 38 | |
|
39 | 39 | |
|
40 | 40 | class LdapImportError(Exception): |
|
41 | 41 | pass |
|
42 | 42 | |
|
43 | 43 | |
|
44 | 44 | class DefaultUserException(Exception): |
|
45 | 45 | pass |
|
46 | 46 | |
|
47 | 47 | |
|
48 | 48 | class UserOwnsReposException(Exception): |
|
49 | 49 | pass |
|
50 | 50 | |
|
51 | 51 | |
|
52 | 52 | class UserOwnsRepoGroupsException(Exception): |
|
53 | 53 | pass |
|
54 | 54 | |
|
55 | 55 | |
|
56 | 56 | class UserOwnsUserGroupsException(Exception): |
|
57 | 57 | pass |
|
58 | 58 | |
|
59 | 59 | |
|
60 | 60 | class UserOwnsPullRequestsException(Exception): |
|
61 | 61 | pass |
|
62 | 62 | |
|
63 | 63 | |
|
64 | 64 | class UserOwnsArtifactsException(Exception): |
|
65 | 65 | pass |
|
66 | 66 | |
|
67 | 67 | |
|
68 | 68 | class UserGroupAssignedException(Exception): |
|
69 | 69 | pass |
|
70 | 70 | |
|
71 | 71 | |
|
72 | 72 | class StatusChangeOnClosedPullRequestError(Exception): |
|
73 | 73 | pass |
|
74 | 74 | |
|
75 | 75 | |
|
76 | 76 | class AttachedForksError(Exception): |
|
77 | 77 | pass |
|
78 | 78 | |
|
79 | 79 | |
|
80 | 80 | class AttachedPullRequestsError(Exception): |
|
81 | 81 | pass |
|
82 | 82 | |
|
83 | 83 | |
|
84 | 84 | class RepoGroupAssignmentError(Exception): |
|
85 | 85 | pass |
|
86 | 86 | |
|
87 | 87 | |
|
88 | 88 | class NonRelativePathError(Exception): |
|
89 | 89 | pass |
|
90 | 90 | |
|
91 | 91 | |
|
92 | 92 | class HTTPRequirementError(HTTPClientError): |
|
93 | 93 | title = explanation = 'Repository Requirement Missing' |
|
94 | 94 | reason = None |
|
95 | 95 | |
|
96 | 96 | def __init__(self, message, *args, **kwargs): |
|
97 | 97 | self.title = self.explanation = message |
|
98 | 98 | super(HTTPRequirementError, self).__init__(*args, **kwargs) |
|
99 | 99 | self.args = (message, ) |
|
100 | 100 | |
|
101 | 101 | |
|
102 | 102 | class HTTPLockedRC(HTTPClientError): |
|
103 | 103 | """ |
|
104 | 104 | Special Exception For locked Repos in RhodeCode, the return code can |
|
105 | 105 | be overwritten by _code keyword argument passed into constructors |
|
106 | 106 | """ |
|
107 | 107 | code = 423 |
|
108 | 108 | title = explanation = 'Repository Locked' |
|
109 | 109 | reason = None |
|
110 | 110 | |
|
111 | 111 | def __init__(self, message, *args, **kwargs): |
|
112 | from rhodecode import CONFIG | 
|
113 | from rhodecode.lib.utils2 import safe_int | |
|
114 | _code = CONFIG.get('lock_ret_code') | |
|
115 | self.code = safe_int(_code, self.code) | |
|
112 | import rhodecode | |
|
113 | ||
|
114 | self.code = rhodecode.ConfigGet().get_int('lock_ret_code', missing=self.code) | |
|
115 | ||
|
116 | 116 | self.title = self.explanation = message |
|
117 | 117 | super(HTTPLockedRC, self).__init__(*args, **kwargs) |
|
118 | 118 | self.args = (message, ) |
|
119 | 119 | |
|
120 | 120 | |
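The change above replaces the module-level CONFIG dict plus safe_int with rhodecode.ConfigGet().get_int, which returns the missing= fallback when the key is absent or not coercible to an integer. A rough sketch of the equivalent behaviour, assuming the fallback semantics shown in the diff:

    def get_int_with_fallback(config, key, missing):
        # approximates ConfigGet().get_int(key, missing=...): read the raw
        # value and coerce it to int, falling back on failure
        raw = config.get(key)
        try:
            return int(raw)
        except (TypeError, ValueError):
            return missing

    get_int_with_fallback({'lock_ret_code': '423'}, 'lock_ret_code', 423)  # 423
    get_int_with_fallback({}, 'lock_ret_code', 423)                        # 423 (fallback)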
|
121 | 121 | class HTTPBranchProtected(HTTPClientError): |
|
122 | 122 | """ |
|
123 | 123 | Special Exception For Indicating that branch is protected in RhodeCode, the |
|
124 | 124 | return code can be overwritten by _code keyword argument passed into constructors |
|
125 | 125 | """ |
|
126 | 126 | code = 403 |
|
127 | 127 | title = explanation = 'Branch Protected' |
|
128 | 128 | reason = None |
|
129 | 129 | |
|
130 | 130 | def __init__(self, message, *args, **kwargs): |
|
131 | 131 | self.title = self.explanation = message |
|
132 | 132 | super(HTTPBranchProtected, self).__init__(*args, **kwargs) |
|
133 | 133 | self.args = (message, ) |
|
134 | 134 | |
|
135 | 135 | |
|
136 | 136 | class IMCCommitError(Exception): |
|
137 | 137 | pass |
|
138 | 138 | |
|
139 | 139 | |
|
140 | 140 | class UserCreationError(Exception): |
|
141 | 141 | pass |
|
142 | 142 | |
|
143 | 143 | |
|
144 | 144 | class NotAllowedToCreateUserError(Exception): |
|
145 | 145 | pass |
|
146 | 146 | |
|
147 | 147 | |
|
148 | 148 | class RepositoryCreationError(Exception): |
|
149 | 149 | pass |
|
150 | 150 | |
|
151 | 151 | |
|
152 | 152 | class VCSServerUnavailable(HTTPBadGateway): |
|
153 | 153 | """ HTTP Exception class for VCS Server errors """ |
|
154 | 154 | code = 502 |
|
155 | 155 | title = 'VCS Server Error' |
|
156 | 156 | causes = [ |
|
157 | 157 | 'VCS Server is not running', |
|
158 | 158 | 'Incorrect vcs.server=host:port', |
|
159 | 159 | 'Incorrect vcs.server.protocol', |
|
160 | 160 | ] |
|
161 | 161 | |
|
162 | 162 | def __init__(self, message=''): |
|
163 | 163 | self.explanation = 'Could not connect to VCS Server' |
|
164 | 164 | if message: |
|
165 | 165 | self.explanation += ': ' + message |
|
166 | 166 | super(VCSServerUnavailable, self).__init__() |
|
167 | 167 | |
|
168 | 168 | |
|
169 | 169 | class ArtifactMetadataDuplicate(ValueError): |
|
170 | 170 | |
|
171 | 171 | def __init__(self, *args, **kwargs): |
|
172 | 172 | self.err_section = kwargs.pop('err_section', None) |
|
173 | 173 | self.err_key = kwargs.pop('err_key', None) |
|
174 | 174 | super(ArtifactMetadataDuplicate, self).__init__(*args, **kwargs) |
|
175 | 175 | |
|
176 | 176 | |
|
177 | 177 | class ArtifactMetadataBadValueType(ValueError): |
|
178 | 178 | pass |
|
179 | 179 | |
|
180 | 180 | |
|
181 | 181 | class CommentVersionMismatch(ValueError): |
|
182 | 182 | pass |
|
183 | ||
|
184 | ||
|
185 | class SignatureVerificationError(ValueError): | |
|
186 | pass | |
|
187 | ||
|
188 | ||
|
189 | def signature_verification_error(msg): | |
|
190 | details = """ | |
|
191 | Encryption signature verification failed. | |
|
192 | Please check your secret key value and/or the stored encrypted values. | 
|
193 | Secret key stored inside .ini file: | |
|
194 | `rhodecode.encrypted_values.secret` or defaults to | |
|
195 | `beaker.session.secret` | |
|
196 | ||
|
197 | Probably the stored values were encrypted using a different secret than the one currently set in the .ini file | 
|
198 | """ | |
|
199 | ||
|
200 | final_msg = f'{msg}\n{details}' | |
|
201 | return SignatureVerificationError(final_msg) | |
|
202 |
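A hypothetical call site for the helper above, wrapping a decryption failure so the operator sees the secret-key hint (decrypt here is a placeholder for the real routine):

    def decrypt(stored_value, secret):
        # placeholder: the real decryption lives in rhodecode's crypto helpers
        raise ValueError('signature mismatch')

    try:
        value = decrypt('enc$...', 'wrong-secret')
    except ValueError as exc:
        # attach the troubleshooting details defined above before re-raising
        raise signature_verification_error(str(exc))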
@@ -1,444 +1,443 b'' | |||
|
1 | 1 | # Copyright (c) Django Software Foundation and individual contributors. |
|
2 | 2 | # All rights reserved. |
|
3 | 3 | # |
|
4 | 4 | # Redistribution and use in source and binary forms, with or without modification, |
|
5 | 5 | # are permitted provided that the following conditions are met: |
|
6 | 6 | # |
|
7 | 7 | # 1. Redistributions of source code must retain the above copyright notice, |
|
8 | 8 | # this list of conditions and the following disclaimer. |
|
9 | 9 | # |
|
10 | 10 | # 2. Redistributions in binary form must reproduce the above copyright |
|
11 | 11 | # notice, this list of conditions and the following disclaimer in the |
|
12 | 12 | # documentation and/or other materials provided with the distribution. |
|
13 | 13 | # |
|
14 | 14 | # 3. Neither the name of Django nor the names of its contributors may be used |
|
15 | 15 | # to endorse or promote products derived from this software without |
|
16 | 16 | # specific prior written permission. |
|
17 | 17 | # |
|
18 | 18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND |
|
19 | 19 | # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED |
|
20 | 20 | # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE |
|
21 | 21 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR |
|
22 | 22 | # ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES |
|
23 | 23 | # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; |
|
24 | 24 | # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON |
|
25 | 25 | # ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT |
|
26 | 26 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS |
|
27 | 27 | # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. |
|
28 | 28 | |
|
29 | 29 | """ |
|
30 | 30 | For definitions of the different versions of RSS, see: |
|
31 | 31 | http://web.archive.org/web/20110718035220/http://diveintomark.org/archives/2004/02/04/incompatible-rss |
|
32 | 32 | """ |
|
33 | 33 | |
|
34 | 34 | |
|
35 | 35 | import datetime |
|
36 | 36 | import io |
|
37 | 37 | |
|
38 | import pytz | |
|
39 | 38 | from six.moves.urllib import parse as urlparse |
|
40 | 39 | |
|
41 | 40 | from rhodecode.lib.feedgenerator import datetime_safe |
|
42 | 41 | from rhodecode.lib.feedgenerator.utils import SimplerXMLGenerator, iri_to_uri, force_text |
|
43 | 42 | |
|
44 | 43 | |
|
45 | 44 | #### The following code comes from ``django.utils.feedgenerator`` #### |
|
46 | 45 | |
|
47 | 46 | |
|
48 | 47 | def rfc2822_date(date): |
|
49 | 48 | # We can't use strftime() because it produces locale-dependent results, so |
|
50 | 49 | # we have to map english month and day names manually |
|
51 | 50 | months = ('Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun', 'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec',) |
|
52 | 51 | days = ('Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun') |
|
53 | 52 | # Support datetime objects older than 1900 |
|
54 | 53 | date = datetime_safe.new_datetime(date) |
|
55 | 54 | # We do this ourselves to be timezone aware, email.Utils is not tz aware. |
|
56 | 55 | dow = days[date.weekday()] |
|
57 | 56 | month = months[date.month - 1] |
|
58 | 57 | time_str = date.strftime('%s, %%d %s %%Y %%H:%%M:%%S ' % (dow, month)) |
|
59 | 58 | |
|
60 | 59 | offset = date.utcoffset() |
|
61 | 60 | # Historically, this function assumes that naive datetimes are in UTC. |
|
62 | 61 | if offset is None: |
|
63 | 62 | return time_str + '-0000' |
|
64 | 63 | else: |
|
65 | 64 | timezone = (offset.days * 24 * 60) + (offset.seconds // 60) |
|
66 | 65 | hour, minute = divmod(timezone, 60) |
|
67 | 66 | return time_str + '%+03d%02d' % (hour, minute) |
|
68 | 67 | |
|
69 | 68 | |
|
70 | 69 | def rfc3339_date(date): |
|
71 | 70 | # Support datetime objects older than 1900 |
|
72 | 71 | date = datetime_safe.new_datetime(date) |
|
73 | 72 | time_str = date.strftime('%Y-%m-%dT%H:%M:%S') |
|
74 | 73 | |
|
75 | 74 | offset = date.utcoffset() |
|
76 | 75 | # Historically, this function assumes that naive datetimes are in UTC. |
|
77 | 76 | if offset is None: |
|
78 | 77 | return time_str + 'Z' |
|
79 | 78 | else: |
|
80 | 79 | timezone = (offset.days * 24 * 60) + (offset.seconds // 60) |
|
81 | 80 | hour, minute = divmod(timezone, 60) |
|
82 | 81 | return time_str + '%+03d:%02d' % (hour, minute) |
|
83 | 82 | |
|
84 | 83 | |
|
85 | 84 | def get_tag_uri(url, date): |
|
86 | 85 | """ |
|
87 | 86 | Creates a TagURI. |
|
88 | 87 | |
|
89 | 88 | See http://web.archive.org/web/20110514113830/http://diveintomark.org/archives/2004/05/28/howto-atom-id |
|
90 | 89 | """ |
|
91 | 90 | bits = urlparse.urlparse(url) 
|
92 | 91 | d = '' |
|
93 | 92 | if date is not None: |
|
94 | 93 | d = ',%s' % datetime_safe.new_datetime(date).strftime('%Y-%m-%d') |
|
95 | 94 | return 'tag:%s%s:%s/%s' % (bits.hostname, d, bits.path, bits.fragment) |
|
96 | 95 | |
|
97 | 96 | |
|
98 | 97 | class SyndicationFeed(object): |
|
99 | 98 | """Base class for all syndication feeds. Subclasses should provide write()""" |
|
100 | 99 | |
|
101 | 100 | def __init__(self, title, link, description, language=None, author_email=None, |
|
102 | 101 | author_name=None, author_link=None, subtitle=None, categories=None, |
|
103 | 102 | feed_url=None, feed_copyright=None, feed_guid=None, ttl=None, **kwargs): |
|
104 | 103 | def to_unicode(s): |
|
105 | 104 | return force_text(s, strings_only=True) |
|
106 | 105 | if categories: |
|
107 | 106 | categories = [force_text(c) for c in categories] |
|
108 | 107 | if ttl is not None: |
|
109 | 108 | # Force ints to unicode |
|
110 | 109 | ttl = force_text(ttl) |
|
111 | 110 | self.feed = { |
|
112 | 111 | 'title': to_unicode(title), |
|
113 | 112 | 'link': iri_to_uri(link), |
|
114 | 113 | 'description': to_unicode(description), |
|
115 | 114 | 'language': to_unicode(language), |
|
116 | 115 | 'author_email': to_unicode(author_email), |
|
117 | 116 | 'author_name': to_unicode(author_name), |
|
118 | 117 | 'author_link': iri_to_uri(author_link), |
|
119 | 118 | 'subtitle': to_unicode(subtitle), |
|
120 | 119 | 'categories': categories or (), |
|
121 | 120 | 'feed_url': iri_to_uri(feed_url), |
|
122 | 121 | 'feed_copyright': to_unicode(feed_copyright), |
|
123 | 122 | 'id': feed_guid or link, |
|
124 | 123 | 'ttl': ttl, |
|
125 | 124 | } |
|
126 | 125 | self.feed.update(kwargs) |
|
127 | 126 | self.items = [] |
|
128 | 127 | |
|
129 | 128 | def add_item(self, title, link, description, author_email=None, |
|
130 | 129 | author_name=None, author_link=None, pubdate=None, comments=None, |
|
131 | 130 | unique_id=None, unique_id_is_permalink=None, enclosure=None, |
|
132 | 131 | categories=(), item_copyright=None, ttl=None, updateddate=None, |
|
133 | 132 | enclosures=None, **kwargs): |
|
134 | 133 | """ |
|
135 | 134 | Adds an item to the feed. All args are expected to be Python Unicode |
|
136 | 135 | objects except pubdate and updateddate, which are datetime.datetime |
|
137 | 136 | objects, and enclosures, which is an iterable of instances of the |
|
138 | 137 | Enclosure class. |
|
139 | 138 | """ |
|
140 | 139 | def to_unicode(s): |
|
141 | 140 | return force_text(s, strings_only=True) |
|
142 | 141 | if categories: |
|
143 | 142 | categories = [to_unicode(c) for c in categories] |
|
144 | 143 | if ttl is not None: |
|
145 | 144 | # Force ints to unicode |
|
146 | 145 | ttl = force_text(ttl) |
|
147 | 146 | if enclosure is None: |
|
148 | 147 | enclosures = [] if enclosures is None else enclosures |
|
149 | 148 | |
|
150 | 149 | item = { |
|
151 | 150 | 'title': to_unicode(title), |
|
152 | 151 | 'link': iri_to_uri(link), |
|
153 | 152 | 'description': to_unicode(description), |
|
154 | 153 | 'author_email': to_unicode(author_email), |
|
155 | 154 | 'author_name': to_unicode(author_name), |
|
156 | 155 | 'author_link': iri_to_uri(author_link), |
|
157 | 156 | 'pubdate': pubdate, |
|
158 | 157 | 'updateddate': updateddate, |
|
159 | 158 | 'comments': to_unicode(comments), |
|
160 | 159 | 'unique_id': to_unicode(unique_id), |
|
161 | 160 | 'unique_id_is_permalink': unique_id_is_permalink, |
|
162 | 161 | 'enclosures': enclosures, |
|
163 | 162 | 'categories': categories or (), |
|
164 | 163 | 'item_copyright': to_unicode(item_copyright), |
|
165 | 164 | 'ttl': ttl, |
|
166 | 165 | } |
|
167 | 166 | item.update(kwargs) |
|
168 | 167 | self.items.append(item) |
|
169 | 168 | |
|
170 | 169 | def num_items(self): |
|
171 | 170 | return len(self.items) |
|
172 | 171 | |
|
173 | 172 | def root_attributes(self): |
|
174 | 173 | """ |
|
175 | 174 | Return extra attributes to place on the root (i.e. feed/channel) element. |
|
176 | 175 | Called from write(). |
|
177 | 176 | """ |
|
178 | 177 | return {} |
|
179 | 178 | |
|
180 | 179 | def add_root_elements(self, handler): |
|
181 | 180 | """ |
|
182 | 181 | Add elements in the root (i.e. feed/channel) element. Called |
|
183 | 182 | from write(). |
|
184 | 183 | """ |
|
185 | 184 | pass |
|
186 | 185 | |
|
187 | 186 | def item_attributes(self, item): |
|
188 | 187 | """ |
|
189 | 188 | Return extra attributes to place on each item (i.e. item/entry) element. |
|
190 | 189 | """ |
|
191 | 190 | return {} |
|
192 | 191 | |
|
193 | 192 | def add_item_elements(self, handler, item): |
|
194 | 193 | """ |
|
195 | 194 | Add elements on each item (i.e. item/entry) element. |
|
196 | 195 | """ |
|
197 | 196 | pass |
|
198 | 197 | |
|
199 | 198 | def write(self, outfile, encoding): |
|
200 | 199 | """ |
|
201 | 200 | Outputs the feed in the given encoding to outfile, which is a file-like |
|
202 | 201 | object. Subclasses should override this. |
|
203 | 202 | """ |
|
204 | 203 | raise NotImplementedError('subclasses of SyndicationFeed must provide a write() method') |
|
205 | 204 | |
|
206 | 205 | def writeString(self, encoding): |
|
207 | 206 | """ |
|
208 | 207 | Returns the feed in the given encoding as a string. |
|
209 | 208 | """ |
|
210 | 209 | s = io.StringIO() |
|
211 | 210 | self.write(s, encoding) |
|
212 | 211 | return s.getvalue() |
|
213 | 212 | |
|
214 | 213 | def latest_post_date(self): |
|
215 | 214 | """ |
|
216 | 215 | Returns the latest item's pubdate or updateddate. If no items |
|
217 | 216 | have either of these attributes this returns the current UTC date/time. |
|
218 | 217 | """ |
|
219 | 218 | latest_date = None |
|
220 | 219 | date_keys = ('updateddate', 'pubdate') |
|
221 | 220 | |
|
222 | 221 | for item in self.items: |
|
223 | 222 | for date_key in date_keys: |
|
224 | 223 | item_date = item.get(date_key) |
|
225 | 224 | if item_date: |
|
226 | 225 | if latest_date is None or item_date > latest_date: |
|
227 | 226 | latest_date = item_date |
|
228 | 227 | |
|
229 | 228 | # datetime.now(tz=utc) is slower, as documented in django.utils.timezone.now |
|
230 | return latest_date or datetime.datetime.utcnow().replace(tzinfo=pytz.utc) | 
|
229 | return latest_date or datetime.datetime.utcnow().replace(tzinfo=datetime.timezone.utc) | |
|
231 | 230 | |
|
232 | 231 | |
|
233 | 232 | class Enclosure(object): |
|
234 | 233 | """Represents an RSS enclosure""" |
|
235 | 234 | def __init__(self, url, length, mime_type): |
|
236 | 235 | """All args are expected to be Python Unicode objects""" |
|
237 | 236 | self.length, self.mime_type = length, mime_type |
|
238 | 237 | self.url = iri_to_uri(url) |
|
239 | 238 | |
|
240 | 239 | |
|
241 | 240 | class RssFeed(SyndicationFeed): |
|
242 | 241 | content_type = 'application/rss+xml; charset=utf-8' |
|
243 | 242 | |
|
244 | 243 | def write(self, outfile, encoding): |
|
245 | 244 | handler = SimplerXMLGenerator(outfile, encoding) |
|
246 | 245 | handler.startDocument() |
|
247 | 246 | handler.startElement("rss", self.rss_attributes()) |
|
248 | 247 | handler.startElement("channel", self.root_attributes()) |
|
249 | 248 | self.add_root_elements(handler) |
|
250 | 249 | self.write_items(handler) |
|
251 | 250 | self.endChannelElement(handler) |
|
252 | 251 | handler.endElement("rss") |
|
253 | 252 | |
|
254 | 253 | def rss_attributes(self): |
|
255 | 254 | return {"version": self._version, |
|
256 | 255 | "xmlns:atom": "http://www.w3.org/2005/Atom"} |
|
257 | 256 | |
|
258 | 257 | def write_items(self, handler): |
|
259 | 258 | for item in self.items: |
|
260 | 259 | handler.startElement('item', self.item_attributes(item)) |
|
261 | 260 | self.add_item_elements(handler, item) |
|
262 | 261 | handler.endElement("item") |
|
263 | 262 | |
|
264 | 263 | def add_root_elements(self, handler): |
|
265 | 264 | handler.addQuickElement("title", self.feed['title']) |
|
266 | 265 | handler.addQuickElement("link", self.feed['link']) |
|
267 | 266 | handler.addQuickElement("description", self.feed['description']) |
|
268 | 267 | if self.feed['feed_url'] is not None: |
|
269 | 268 | handler.addQuickElement("atom:link", None, {"rel": "self", "href": self.feed['feed_url']}) |
|
270 | 269 | if self.feed['language'] is not None: |
|
271 | 270 | handler.addQuickElement("language", self.feed['language']) |
|
272 | 271 | for cat in self.feed['categories']: |
|
273 | 272 | handler.addQuickElement("category", cat) |
|
274 | 273 | if self.feed['feed_copyright'] is not None: |
|
275 | 274 | handler.addQuickElement("copyright", self.feed['feed_copyright']) |
|
276 | 275 | handler.addQuickElement("lastBuildDate", rfc2822_date(self.latest_post_date())) |
|
277 | 276 | if self.feed['ttl'] is not None: |
|
278 | 277 | handler.addQuickElement("ttl", self.feed['ttl']) |
|
279 | 278 | |
|
280 | 279 | def endChannelElement(self, handler): |
|
281 | 280 | handler.endElement("channel") |
|
282 | 281 | |
|
283 | 282 | |
|
284 | 283 | class RssUserland091Feed(RssFeed): |
|
285 | 284 | _version = "0.91" |
|
286 | 285 | |
|
287 | 286 | def add_item_elements(self, handler, item): |
|
288 | 287 | handler.addQuickElement("title", item['title']) |
|
289 | 288 | handler.addQuickElement("link", item['link']) |
|
290 | 289 | if item['description'] is not None: |
|
291 | 290 | handler.addQuickElement("description", item['description']) |
|
292 | 291 | |
|
293 | 292 | |
|
294 | 293 | class Rss201rev2Feed(RssFeed): |
|
295 | 294 | # Spec: http://blogs.law.harvard.edu/tech/rss |
|
296 | 295 | _version = "2.0" |
|
297 | 296 | |
|
298 | 297 | def add_item_elements(self, handler, item): |
|
299 | 298 | handler.addQuickElement("title", item['title']) |
|
300 | 299 | handler.addQuickElement("link", item['link']) |
|
301 | 300 | if item['description'] is not None: |
|
302 | 301 | handler.addQuickElement("description", item['description']) |
|
303 | 302 | |
|
304 | 303 | # Author information. |
|
305 | 304 | if item["author_name"] and item["author_email"]: |
|
306 | 305 | handler.addQuickElement("author", "%s (%s)" % (item['author_email'], item['author_name'])) |
|
307 | 306 | elif item["author_email"]: |
|
308 | 307 | handler.addQuickElement("author", item["author_email"]) |
|
309 | 308 | elif item["author_name"]: |
|
310 | 309 | handler.addQuickElement( |
|
311 | 310 | "dc:creator", item["author_name"], {"xmlns:dc": "http://purl.org/dc/elements/1.1/"} |
|
312 | 311 | ) |
|
313 | 312 | |
|
314 | 313 | if item['pubdate'] is not None: |
|
315 | 314 | handler.addQuickElement("pubDate", rfc2822_date(item['pubdate'])) |
|
316 | 315 | if item['comments'] is not None: |
|
317 | 316 | handler.addQuickElement("comments", item['comments']) |
|
318 | 317 | if item['unique_id'] is not None: |
|
319 | 318 | guid_attrs = {} |
|
320 | 319 | if isinstance(item.get('unique_id_is_permalink'), bool): |
|
321 | 320 | guid_attrs['isPermaLink'] = str(item['unique_id_is_permalink']).lower() |
|
322 | 321 | handler.addQuickElement("guid", item['unique_id'], guid_attrs) |
|
323 | 322 | if item['ttl'] is not None: |
|
324 | 323 | handler.addQuickElement("ttl", item['ttl']) |
|
325 | 324 | |
|
326 | 325 | # Enclosure. |
|
327 | 326 | if item['enclosures']: |
|
328 | 327 | enclosures = list(item['enclosures']) |
|
329 | 328 | if len(enclosures) > 1: |
|
330 | 329 | raise ValueError( |
|
331 | 330 | "RSS feed items may only have one enclosure, see " |
|
332 | 331 | "http://www.rssboard.org/rss-profile#element-channel-item-enclosure" |
|
333 | 332 | ) |
|
334 | 333 | enclosure = enclosures[0] |
|
335 | 334 | handler.addQuickElement('enclosure', '', { |
|
336 | 335 | 'url': enclosure.url, |
|
337 | 336 | 'length': enclosure.length, |
|
338 | 337 | 'type': enclosure.mime_type, |
|
339 | 338 | }) |
|
340 | 339 | |
|
341 | 340 | # Categories. |
|
342 | 341 | for cat in item['categories']: |
|
343 | 342 | handler.addQuickElement("category", cat) |
|
344 | 343 | |
|
345 | 344 | |
|
346 | 345 | class Atom1Feed(SyndicationFeed): |
|
347 | 346 | # Spec: https://tools.ietf.org/html/rfc4287 |
|
348 | 347 | content_type = 'application/atom+xml; charset=utf-8' |
|
349 | 348 | ns = "http://www.w3.org/2005/Atom" |
|
350 | 349 | |
|
351 | 350 | def write(self, outfile, encoding): |
|
352 | 351 | handler = SimplerXMLGenerator(outfile, encoding) |
|
353 | 352 | handler.startDocument() |
|
354 | 353 | handler.startElement('feed', self.root_attributes()) |
|
355 | 354 | self.add_root_elements(handler) |
|
356 | 355 | self.write_items(handler) |
|
357 | 356 | handler.endElement("feed") |
|
358 | 357 | |
|
359 | 358 | def root_attributes(self): |
|
360 | 359 | if self.feed['language'] is not None: |
|
361 | 360 | return {"xmlns": self.ns, "xml:lang": self.feed['language']} |
|
362 | 361 | else: |
|
363 | 362 | return {"xmlns": self.ns} |
|
364 | 363 | |
|
365 | 364 | def add_root_elements(self, handler): |
|
366 | 365 | handler.addQuickElement("title", self.feed['title']) |
|
367 | 366 | handler.addQuickElement("link", "", {"rel": "alternate", "href": self.feed['link']}) |
|
368 | 367 | if self.feed['feed_url'] is not None: |
|
369 | 368 | handler.addQuickElement("link", "", {"rel": "self", "href": self.feed['feed_url']}) |
|
370 | 369 | handler.addQuickElement("id", self.feed['id']) |
|
371 | 370 | handler.addQuickElement("updated", rfc3339_date(self.latest_post_date())) |
|
372 | 371 | if self.feed['author_name'] is not None: |
|
373 | 372 | handler.startElement("author", {}) |
|
374 | 373 | handler.addQuickElement("name", self.feed['author_name']) |
|
375 | 374 | if self.feed['author_email'] is not None: |
|
376 | 375 | handler.addQuickElement("email", self.feed['author_email']) |
|
377 | 376 | if self.feed['author_link'] is not None: |
|
378 | 377 | handler.addQuickElement("uri", self.feed['author_link']) |
|
379 | 378 | handler.endElement("author") |
|
380 | 379 | if self.feed['subtitle'] is not None: |
|
381 | 380 | handler.addQuickElement("subtitle", self.feed['subtitle']) |
|
382 | 381 | for cat in self.feed['categories']: |
|
383 | 382 | handler.addQuickElement("category", "", {"term": cat}) |
|
384 | 383 | if self.feed['feed_copyright'] is not None: |
|
385 | 384 | handler.addQuickElement("rights", self.feed['feed_copyright']) |
|
386 | 385 | |
|
387 | 386 | def write_items(self, handler): |
|
388 | 387 | for item in self.items: |
|
389 | 388 | handler.startElement("entry", self.item_attributes(item)) |
|
390 | 389 | self.add_item_elements(handler, item) |
|
391 | 390 | handler.endElement("entry") |
|
392 | 391 | |
|
393 | 392 | def add_item_elements(self, handler, item): |
|
394 | 393 | handler.addQuickElement("title", item['title']) |
|
395 | 394 | handler.addQuickElement("link", "", {"href": item['link'], "rel": "alternate"}) |
|
396 | 395 | |
|
397 | 396 | if item['pubdate'] is not None: |
|
398 | 397 | handler.addQuickElement('published', rfc3339_date(item['pubdate'])) |
|
399 | 398 | |
|
400 | 399 | if item['updateddate'] is not None: |
|
401 | 400 | handler.addQuickElement('updated', rfc3339_date(item['updateddate'])) |
|
402 | 401 | |
|
403 | 402 | # Author information. |
|
404 | 403 | if item['author_name'] is not None: |
|
405 | 404 | handler.startElement("author", {}) |
|
406 | 405 | handler.addQuickElement("name", item['author_name']) |
|
407 | 406 | if item['author_email'] is not None: |
|
408 | 407 | handler.addQuickElement("email", item['author_email']) |
|
409 | 408 | if item['author_link'] is not None: |
|
410 | 409 | handler.addQuickElement("uri", item['author_link']) |
|
411 | 410 | handler.endElement("author") |
|
412 | 411 | |
|
413 | 412 | # Unique ID. |
|
414 | 413 | if item['unique_id'] is not None: |
|
415 | 414 | unique_id = item['unique_id'] |
|
416 | 415 | else: |
|
417 | 416 | unique_id = get_tag_uri(item['link'], item['pubdate']) |
|
418 | 417 | handler.addQuickElement("id", unique_id) |
|
419 | 418 | |
|
420 | 419 | # Summary. |
|
421 | 420 | if item['description'] is not None: |
|
422 | 421 | handler.addQuickElement("summary", item['description'], {"type": "html"}) |
|
423 | 422 | |
|
424 | 423 | # Enclosures. |
|
425 | 424 | for enclosure in item['enclosures']: |
|
426 | 425 | handler.addQuickElement('link', '', { |
|
427 | 426 | 'rel': 'enclosure', |
|
428 | 427 | 'href': enclosure.url, |
|
429 | 428 | 'length': enclosure.length, |
|
430 | 429 | 'type': enclosure.mime_type, |
|
431 | 430 | }) |
|
432 | 431 | |
|
433 | 432 | # Categories. |
|
434 | 433 | for cat in item['categories']: |
|
435 | 434 | handler.addQuickElement("category", "", {"term": cat}) |
|
436 | 435 | |
|
437 | 436 | # Rights. |
|
438 | 437 | if item['item_copyright'] is not None: |
|
439 | 438 | handler.addQuickElement("rights", item['item_copyright']) |
|
440 | 439 | |
|
441 | 440 | |
|
442 | 441 | # This isolates the decision of what the system default is, so calling code can |
|
443 | 442 | # do "feedgenerator.DefaultFeed" instead of "feedgenerator.Rss201rev2Feed". |
|
444 | 443 | DefaultFeed = Rss201rev2Feed No newline at end of file |
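A minimal end-to-end sketch of the feed API in this file, with illustrative titles and URLs:

    import datetime

    feed = DefaultFeed(title='Repo commits',
                       link='https://example.com/repo',
                       description='Latest commits')
    feed.add_item(title='commit 1',
                  link='https://example.com/repo/changeset/1',
                  description='initial commit',
                  pubdate=datetime.datetime(2020, 1, 2, 3, 4, 5))
    xml = feed.writeString('utf-8')  # RSS 2.0 document as a string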
@@ -1,155 +1,155 b'' | |||
|
1 | 1 | |
|
2 | 2 | |
|
3 | 3 | # Copyright (C) 2012-2020 RhodeCode GmbH |
|
4 | 4 | # |
|
5 | 5 | # This program is free software: you can redistribute it and/or modify |
|
6 | 6 | # it under the terms of the GNU Affero General Public License, version 3 |
|
7 | 7 | # (only), as published by the Free Software Foundation. |
|
8 | 8 | # |
|
9 | 9 | # This program is distributed in the hope that it will be useful, |
|
10 | 10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of |
|
11 | 11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the |
|
12 | 12 | # GNU General Public License for more details. |
|
13 | 13 | # |
|
14 | 14 | # You should have received a copy of the GNU Affero General Public License |
|
15 | 15 | # along with this program. If not, see <http://www.gnu.org/licenses/>. |
|
16 | 16 | # |
|
17 | 17 | # This program is dual-licensed. If you wish to learn more about the |
|
18 | 18 | # RhodeCode Enterprise Edition, including its added features, Support services, |
|
19 | 19 | # and proprietary license terms, please see https://rhodecode.com/licenses/ |
|
20 | 20 | |
|
21 | 21 | """ |
|
22 | 22 | Index schema for RhodeCode |
|
23 | 23 | """ |
|
24 | 24 | |
|
25 | 25 | import importlib |
|
26 | 26 | import logging |
|
27 | 27 | |
|
28 | 28 | from rhodecode.lib.index.search_utils import normalize_text_for_matching |
|
29 | 29 | |
|
30 | 30 | log = logging.getLogger(__name__) |
|
31 | 31 | |
|
32 | 32 | # leave defaults for backward compat |
|
33 | 33 | default_searcher = 'rhodecode.lib.index.whoosh' |
|
34 | 34 | default_location = '%(here)s/data/index' |
|
35 | 35 | |
|
36 | 36 | ES_VERSION_2 = '2' |
|
37 | 37 | ES_VERSION_6 = '6' |
|
38 | 38 | # for legacy reasons we keep 2 compat as default |
|
39 | 39 | DEFAULT_ES_VERSION = ES_VERSION_2 |
|
40 | 40 | |
|
41 | 41 | try: |
|
42 | 42 | from rhodecode_tools.lib.fts_index.elasticsearch_engine_6 import ES_CONFIG # pragma: no cover |
|
43 | 43 | except ImportError: |
|
44 | 44 | log.warning('rhodecode_tools not available, use of full text search is limited') |
|
45 | 45 | pass |
|
46 | 46 | |
|
47 | 47 | |
|
48 | 48 | class BaseSearcher(object): |
|
49 | 49 | query_lang_doc = '' |
|
50 | 50 | es_version = None |
|
51 | 51 | name = None |
|
52 | 52 | DIRECTION_ASC = 'asc' |
|
53 | 53 | DIRECTION_DESC = 'desc' |
|
54 | 54 | |
|
55 | 55 | def __init__(self): |
|
56 | 56 | pass |
|
57 | 57 | |
|
58 | 58 | def cleanup(self): |
|
59 | 59 | pass |
|
60 | 60 | |
|
61 | 61 | def search(self, query, document_type, search_user, |
|
62 | 62 | repo_name=None, repo_group_name=None, |
|
63 | 63 | raise_on_exc=True): |
|
64 | 64 | raise NotImplementedError() 
|
65 | 65 | |
|
66 | 66 | @staticmethod |
|
67 | 67 | def query_to_mark(query, default_field=None): |
|
68 | 68 | """ |
|
69 | 69 | Formats the query to mark token for jquery.mark.js highlighting. ES could |
|
70 | 70 | have a different format optionally. |
|
71 | 71 | |
|
72 | 72 | :param default_field: |
|
73 | 73 | :param query: |
|
74 | 74 | """ |
|
75 | 75 | return ' '.join(normalize_text_for_matching(query).split()) |
|
76 | 76 | |
|
77 | 77 | @property |
|
78 | 78 | def is_es_6(self): |
|
79 | 79 | return self.es_version == ES_VERSION_6 |
|
80 | 80 | |
|
81 | 81 | def get_handlers(self): |
|
82 | 82 | return {} |
|
83 | 83 | |
|
84 | 84 | @staticmethod |
|
85 | 85 | def extract_search_tags(query): |
|
86 | 86 | return [] |
|
87 | 87 | |
|
88 | 88 | @staticmethod |
|
89 | 89 | def escape_specials(val): |
|
90 | 90 | """ |
|
91 | 91 | Handle and escape reserved chars for search |
|
92 | 92 | """ |
|
93 | 93 | return val |
|
94 | 94 | |
|
95 | 95 | def sort_def(self, search_type, direction, sort_field): |
|
96 | 96 | """ |
|
97 | 97 | Defines sorting for search. This function should decide if for given |
|
98 | 98 | search_type, sorting can be done with sort_field. |
|
99 | 99 | |
|
100 | 100 | It also should translate common sort fields into backend specific. e.g elasticsearch |
|
101 | 101 | """ |
|
102 | 102 | raise NotImplementedError() |
|
103 | 103 | |
|
104 | 104 | @staticmethod |
|
105 | 105 | def get_sort(search_type, search_val): |
|
106 | 106 | """ |
|
107 | 107 | Method used to parse the GET search sort value to a field and direction. |
|
108 | 108 | e.g asc:lines == asc, lines |
|
109 | 109 | |
|
110 | 110 | There's also a legacy support for newfirst/oldfirst which defines commit |
|
111 | 111 | sorting only |
|
112 | 112 | """ |
|
113 | 113 | |
|
114 | 114 | direction = BaseSearcher.DIRECTION_ASC |
|
115 | 115 | sort_field = None |
|
116 | 116 | |
|
117 | 117 | if not search_val: |
|
118 | 118 | return direction, sort_field |
|
119 | 119 | |
|
120 | 120 | if search_val.startswith('asc:'): |
|
121 | 121 | sort_field = search_val[4:] |
|
122 | 122 | direction = BaseSearcher.DIRECTION_ASC |
|
123 | 123 | elif search_val.startswith('desc:'): |
|
124 | 124 | sort_field = search_val[5:] |
|
125 | 125 | direction = BaseSearcher.DIRECTION_DESC |
|
126 | 126 | elif search_val == 'newfirst' and search_type == 'commit': |
|
127 | 127 | sort_field = 'date' |
|
128 | 128 | direction = BaseSearcher.DIRECTION_DESC |
|
129 | 129 | elif search_val == 'oldfirst' and search_type == 'commit': |
|
130 | 130 | sort_field = 'date' |
|
131 | 131 | direction = BaseSearcher.DIRECTION_ASC |
|
132 | 132 | |
|
133 | 133 | return direction, sort_field |
|
134 | 134 | |
|
135 | 135 | |
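Examples of the sort-value parsing above, including the legacy commit-only aliases:

    BaseSearcher.get_sort('commit', 'desc:date')  # -> ('desc', 'date')
    BaseSearcher.get_sort('commit', 'newfirst')   # -> ('desc', 'date'), legacy alias
    BaseSearcher.get_sort('content', '')          # -> ('asc', None), the defaults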
|
136 | 136 | def search_config(config, prefix='search.'): |
|
137 | 137 | _config = {} |
|
138 | 138 | for key in config.keys(): |
|
139 | 139 | if key.startswith(prefix): |
|
140 | 140 | _config[key[len(prefix):]] = config[key] |
|
141 | 141 | return _config |
|
142 | 142 | |
|
143 | 143 | |
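search_config simply strips the prefix from matching keys and drops the rest; for example:

    search_config({'search.module': 'rhodecode.lib.index.whoosh',
                   'search.location': '/tmp/index',
                   'other.key': 'ignored'})
    # -> {'module': 'rhodecode.lib.index.whoosh', 'location': '/tmp/index'}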
|
144 | 144 | def searcher_from_config(config, prefix='search.'): |
|
145 | 145 | _config = search_config(config, prefix) |
|
146 | 146 | |
|
147 | 147 | if 'location' not in _config: |
|
148 | 148 | _config['location'] = default_location |
|
149 | 149 | if 'es_version' not in _config: |
|
150 | # use old legacy ES version set to 2 | |
|
150 | # use an old legacy ES version set to 2 | |
|
151 | 151 | _config['es_version'] = '2' |
|
152 | 152 | |
|
153 | 153 | imported = importlib.import_module(_config.get('module', default_searcher)) |
|
154 | 154 | searcher = imported.Searcher(config=_config) |
|
155 | 155 | return searcher |
@@ -1,278 +1,173 b'' | |||
|
1 | 1 | |
|
2 | 2 | # Copyright (C) 2010-2020 RhodeCode GmbH |
|
3 | 3 | # |
|
4 | 4 | # This program is free software: you can redistribute it and/or modify |
|
5 | 5 | # it under the terms of the GNU Affero General Public License, version 3 |
|
6 | 6 | # (only), as published by the Free Software Foundation. |
|
7 | 7 | # |
|
8 | 8 | # This program is distributed in the hope that it will be useful, |
|
9 | 9 | # but WITHOUT ANY WARRANTY; without even the implied warranty of |
|
10 | 10 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the |
|
11 | 11 | # GNU General Public License for more details. |
|
12 | 12 | # |
|
13 | 13 | # You should have received a copy of the GNU Affero General Public License |
|
14 | 14 | # along with this program. If not, see <http://www.gnu.org/licenses/>. |
|
15 | 15 | # |
|
16 | 16 | # This program is dual-licensed. If you wish to learn more about the |
|
17 | 17 | # RhodeCode Enterprise Edition, including its added features, Support services, |
|
18 | 18 | # and proprietary license terms, please see https://rhodecode.com/licenses/ |
|
19 | 19 | |
|
20 | import collections | |
|
21 | ||
|
22 | 20 | import sqlalchemy |
|
23 | 21 | from sqlalchemy import UnicodeText |
|
24 | from sqlalchemy.ext.mutable import Mutable | |
|
22 | from sqlalchemy.ext.mutable import Mutable, \ | |
|
23 | MutableList as MutationList, \ | |
|
24 | MutableDict as MutationDict | |
|
25 | 25 | |
|
26 | from rhodecode.lib.ext_json import json | 
|
27 | from rhodecode.lib.utils2 import safe_unicode | |
|
26 | from rhodecode.lib import ext_json | |
|
28 | 27 | |
|
29 | 28 | |
|
30 | 29 | class JsonRaw(str): |
|
31 | 30 | """ |
|
32 | 31 | Allows interacting with a JSON types field using a raw string. |
|
33 | 32 | |
|
34 | 33 | For example:: |
|
35 | 34 | db_instance = JsonTable() |
|
36 | 35 | db_instance.enabled = True |
|
37 | 36 | db_instance.json_data = JsonRaw('{"a": 4}') |
|
38 | 37 | |
|
39 | 38 | This will bypass serialization/checks, and allow storing |
|
40 | 39 | raw values |
|
41 | 40 | """ |
|
42 | 41 | pass |
|
43 | 42 | |
|
44 | 43 | |
|
45 | # Set this to the standard dict if Order is not required | |
|
46 | DictClass = collections.OrderedDict | |
|
47 | ||
|
48 | ||
|
49 | 44 | class JSONEncodedObj(sqlalchemy.types.TypeDecorator): |
|
50 | 45 | """ |
|
51 | 46 | Represents an immutable structure as a json-encoded string. |
|
52 | 47 | |
|
53 | 48 | If default is, for example, a dict, then a NULL value in the |
|
54 | 49 | database will be exposed as an empty dict. |
|
55 | 50 | """ |
|
56 | 51 | |
|
57 | 52 | impl = UnicodeText |
|
58 | 53 | safe = True |
|
59 | enforce_unicode = True | 
54 | enforce_str = True | |
|
60 | 55 | |
|
61 | 56 | def __init__(self, *args, **kwargs): |
|
62 | 57 | self.default = kwargs.pop('default', None) |
|
63 | 58 | self.safe = kwargs.pop('safe_json', self.safe) |
|
64 | self.enforce_unicode = kwargs.pop('enforce_unicode', self.enforce_unicode) | 
|
59 | self.enforce_str = kwargs.pop('enforce_str', self.enforce_str) | |
|
65 | 60 | self.dialect_map = kwargs.pop('dialect_map', {}) |
|
66 | 61 | super(JSONEncodedObj, self).__init__(*args, **kwargs) |
|
67 | 62 | |
|
68 | 63 | def load_dialect_impl(self, dialect): |
|
69 | 64 | if dialect.name in self.dialect_map: |
|
70 | 65 | return dialect.type_descriptor(self.dialect_map[dialect.name]) |
|
71 | 66 | return dialect.type_descriptor(self.impl) |
|
72 | 67 | |
|
73 | 68 | def process_bind_param(self, value, dialect): |
|
74 | 69 | if isinstance(value, JsonRaw): |
|
75 | 70 | value = value |
|
76 | 71 | elif value is not None: |
|
77 | value = json.dumps(value) | |
|
78 | if self.enforce_unicode: | |
|
79 | value = safe_unicode(value) | |
|
72 | if self.enforce_str: | |
|
73 | value = ext_json.str_json(value) | |
|
74 | else: | |
|
75 | value = ext_json.json.dumps(value) | |
|
80 | 76 | return value |
|
81 | 77 | |
|
82 | 78 | def process_result_value(self, value, dialect): |
|
83 | 79 | if self.default is not None and (not value or value == '""'): |
|
84 | 80 | return self.default() |
|
85 | 81 | |
|
86 | 82 | if value is not None: |
|
87 | 83 | try: |
|
88 |
value = json.loads(value |
|
|
89 |
except Exception |
|
|
84 | value = ext_json.json.loads(value) | |
|
85 | except Exception: | |
|
90 | 86 | if self.safe and self.default is not None: |
|
91 | 87 | return self.default() |
|
92 | 88 | else: |
|
93 | 89 | raise |
|
94 | 90 | return value |
|
95 | 91 | |
|
96 | 92 | |
|
97 | 93 | class MutationObj(Mutable): |
|
94 | ||
|
98 | 95 | @classmethod |
|
99 | 96 | def coerce(cls, key, value): |
|
100 | 97 | if isinstance(value, dict) and not isinstance(value, MutationDict): |
|
101 | 98 | return MutationDict.coerce(key, value) |
|
102 | 99 | if isinstance(value, list) and not isinstance(value, MutationList): |
|
103 | 100 | return MutationList.coerce(key, value) |
|
104 | 101 | return value |
|
105 | 102 | |
|
106 | 103 | def de_coerce(self): |
|
107 | 104 | return self |
|
108 | 105 | |
|
109 | 106 | @classmethod |
|
110 | 107 | def _listen_on_attribute(cls, attribute, coerce, parent_cls): |
|
111 | 108 | key = attribute.key |
|
112 | 109 | if parent_cls is not attribute.class_: |
|
113 | 110 | return |
|
114 | 111 | |
|
115 | 112 | # rely on "propagate" here |
|
116 | 113 | parent_cls = attribute.class_ |
|
117 | 114 | |
|
118 | 115 | def load(state, *args): |
|
119 | 116 | val = state.dict.get(key, None) |
|
120 | 117 | if coerce: |
|
121 | 118 | val = cls.coerce(key, val) |
|
122 | 119 | state.dict[key] = val |
|
123 | 120 | if isinstance(val, cls): |
|
124 | 121 | val._parents[state.obj()] = key |
|
125 | 122 | |
|
126 | 123 | def set(target, value, oldvalue, initiator): |
|
127 | 124 | if not isinstance(value, cls): |
|
128 | 125 | value = cls.coerce(key, value) |
|
129 | 126 | if isinstance(value, cls): |
|
130 | 127 | value._parents[target.obj()] = key |
|
131 | 128 | if isinstance(oldvalue, cls): |
|
132 | 129 | oldvalue._parents.pop(target.obj(), None) |
|
133 | 130 | return value |
|
134 | 131 | |
|
135 | 132 | def pickle(state, state_dict): |
|
136 | 133 | val = state.dict.get(key, None) |
|
137 | 134 | if isinstance(val, cls): |
|
138 | 135 | if 'ext.mutable.values' not in state_dict: |
|
139 | 136 | state_dict['ext.mutable.values'] = [] |
|
140 | 137 | state_dict['ext.mutable.values'].append(val) |
|
141 | 138 | |
|
142 | 139 | def unpickle(state, state_dict): |
|
143 | 140 | if 'ext.mutable.values' in state_dict: |
|
144 | 141 | for val in state_dict['ext.mutable.values']: |
|
145 | 142 | val._parents[state.obj()] = key |
|
146 | 143 | |
|
147 | 144 | sqlalchemy.event.listen(parent_cls, 'load', load, raw=True, |
|
148 | 145 | propagate=True) |
|
149 | 146 | sqlalchemy.event.listen(parent_cls, 'refresh', load, raw=True, |
|
150 | 147 | propagate=True) |
|
151 | 148 | sqlalchemy.event.listen(parent_cls, 'pickle', pickle, raw=True, |
|
152 | 149 | propagate=True) |
|
153 | 150 | sqlalchemy.event.listen(attribute, 'set', set, raw=True, retval=True, |
|
154 | 151 | propagate=True) |
|
155 | 152 | sqlalchemy.event.listen(parent_cls, 'unpickle', unpickle, raw=True, |
|
156 | 153 | propagate=True) |
|
157 | 154 | |
|
158 | 155 | |
|
159 | class MutationDict(MutationObj, DictClass): | |
|
160 | @classmethod | |
|
161 | def coerce(cls, key, value): | |
|
162 | """Convert plain dictionary to MutationDict""" | |
|
163 | self = MutationDict( | |
|
164 | (k, MutationObj.coerce(key, v)) for (k, v) in value.items()) | |
|
165 | self._key = key | |
|
166 | return self | |
|
167 | ||
|
168 | def de_coerce(self): | |
|
169 | return dict(self) | |
|
170 | ||
|
171 | def __setitem__(self, key, value): | |
|
172 | # Due to the way OrderedDict works, this is called during __init__. | |
|
173 | # At this time we don't have a key set, but what is more, the value | |
|
174 | # being set has already been coerced. So special case this and skip. | |
|
175 | if hasattr(self, '_key'): | |
|
176 | value = MutationObj.coerce(self._key, value) | |
|
177 | DictClass.__setitem__(self, key, value) | |
|
178 | self.changed() | |
|
179 | ||
|
180 | def __delitem__(self, key): | |
|
181 | DictClass.__delitem__(self, key) | |
|
182 | self.changed() | |
|
183 | ||
|
184 | def __setstate__(self, state): | |
|
185 | self.__dict__ = state | |
|
186 | ||
|
187 | def __reduce_ex__(self, proto): | |
|
188 | # support pickling of MutationDicts | |
|
189 | d = dict(self) | |
|
190 | return (self.__class__, (d,)) | |
|
191 | ||
|
192 | ||
|
193 | class MutationList(MutationObj, list): | |
|
194 | @classmethod | |
|
195 | def coerce(cls, key, value): | |
|
196 | """Convert plain list to MutationList""" | |
|
197 | self = MutationList((MutationObj.coerce(key, v) for v in value)) | |
|
198 | self._key = key | |
|
199 | return self | |
|
200 | ||
|
201 | def de_coerce(self): | |
|
202 | return list(self) | |
|
203 | ||
|
204 | def __setitem__(self, idx, value): | |
|
205 | list.__setitem__(self, idx, MutationObj.coerce(self._key, value)) | |
|
206 | self.changed() | |
|
207 | ||
|
208 | def __setslice__(self, start, stop, values): | |
|
209 | list.__setslice__(self, start, stop, | |
|
210 | (MutationObj.coerce(self._key, v) for v in values)) | |
|
211 | self.changed() | |
|
212 | ||
|
213 | def __delitem__(self, idx): | |
|
214 | list.__delitem__(self, idx) | |
|
215 | self.changed() | |
|
216 | ||
|
217 | def __delslice__(self, start, stop): | |
|
218 | list.__delslice__(self, start, stop) | |
|
219 | self.changed() | |
|
220 | ||
|
221 | def append(self, value): | |
|
222 | list.append(self, MutationObj.coerce(self._key, value)) | |
|
223 | self.changed() | |
|
224 | ||
|
225 | def insert(self, idx, value): | |
|
226 | list.insert(self, idx, MutationObj.coerce(self._key, value)) | |
|
227 | self.changed() | |
|
228 | ||
|
229 | def extend(self, values): | |
|
230 | list.extend(self, (MutationObj.coerce(self._key, v) for v in values)) | |
|
231 | self.changed() | |
|
232 | ||
|
233 | def pop(self, *args, **kw): | |
|
234 | value = list.pop(self, *args, **kw) | |
|
235 | self.changed() | |
|
236 | return value | |
|
237 | ||
|
238 | def remove(self, value): | |
|
239 | list.remove(self, value) | |
|
240 | self.changed() | |
|
241 | ||
|
242 | ||
|
243 | 156 | def JsonType(impl=None, **kwargs): |
|
244 | 157 | """ |
|
245 | 158 | Helper for using a mutation obj, it allows to use .with_variant easily. |
|
246 | 159 | example:: |
|
247 | 160 | |
|
248 | 161 | settings = Column('settings_json', |
|
249 | 162 | MutationObj.as_mutable( |
|
250 | 163 | JsonType(dialect_map=dict(mysql=UnicodeText(16384)))) |
|
251 | 164 | """ |
|
252 | 165 | |
|
253 | 166 | if impl == 'list': |
|
254 | 167 | return JSONEncodedObj(default=list, **kwargs) |
|
255 | 168 | elif impl == 'dict': |
|
256 |
return JSONEncodedObj(default= |
|
|
169 | return JSONEncodedObj(default=dict, **kwargs) | |
|
257 | 170 | else: |
|
258 | 171 | return JSONEncodedObj(**kwargs) |
|
259 | 172 | |
|
260 | 173 | |
|
261 | JSON = MutationObj.as_mutable(JsonType()) | |
|
262 | """ | |
|
263 | A type to encode/decode JSON on the fly | |
|
264 | ||
|
265 | sqltype is the string type for the underlying DB column:: | |
|
266 | ||
|
267 | Column(JSON) (defaults to UnicodeText) | |
|
268 | """ | |
|
269 | ||
|
270 | JSONDict = MutationObj.as_mutable(JsonType('dict')) | |
|
271 | """ | |
|
272 | A type to encode/decode JSON dictionaries on the fly | |
|
273 | """ | |
|
274 | ||
|
275 | JSONList = MutationObj.as_mutable(JsonType('list')) | |
|
276 | """ | |
|
277 | A type to encode/decode JSON lists` on the fly | |
|
278 | """ |