libs: major refactor for python3
super-admin
r5085:4eab4aa8 default
@@ -1,610 +1,606 b''

 # Copyright (C) 2010-2020 RhodeCode GmbH
 #
 # This program is free software: you can redistribute it and/or modify
 # it under the terms of the GNU Affero General Public License, version 3
 # (only), as published by the Free Software Foundation.
 #
 # This program is distributed in the hope that it will be useful,
 # but WITHOUT ANY WARRANTY; without even the implied warranty of
 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 # GNU General Public License for more details.
 #
 # You should have received a copy of the GNU Affero General Public License
 # along with this program. If not, see <http://www.gnu.org/licenses/>.
 #
 # This program is dual-licensed. If you wish to learn more about the
 # RhodeCode Enterprise Edition, including its added features, Support services,
 # and proprietary license terms, please see https://rhodecode.com/licenses/

 """
 The base Controller API
 Provides the BaseController class for subclassing. And usage in different
 controllers
 """

 import logging
 import socket
+import base64

 import markupsafe
 import ipaddress

+import paste.httpheaders
 from paste.auth.basic import AuthBasicAuthenticator
 from paste.httpexceptions import HTTPUnauthorized, HTTPForbidden, get_exception
-from paste.httpheaders import WWW_AUTHENTICATE, AUTHORIZATION

 import rhodecode
 from rhodecode.authentication.base import VCS_TYPE
 from rhodecode.lib import auth, utils2
 from rhodecode.lib import helpers as h
 from rhodecode.lib.auth import AuthUser, CookieStoreWrapper
 from rhodecode.lib.exceptions import UserCreationError
 from rhodecode.lib.utils import (password_changed, get_enabled_hook_classes)
-from rhodecode.lib.utils2 import (
-    str2bool, safe_unicode, AttributeDict, safe_int, sha1, aslist, safe_str)
+from rhodecode.lib.utils2 import AttributeDict
+from rhodecode.lib.str_utils import ascii_bytes, safe_int, safe_str
+from rhodecode.lib.type_utils import aslist, str2bool
+from rhodecode.lib.hash_utils import sha1
 from rhodecode.model.db import Repository, User, ChangesetComment, UserBookmark
 from rhodecode.model.notification import NotificationModel
 from rhodecode.model.settings import VcsSettingsModel, SettingsModel

 log = logging.getLogger(__name__)


 def _filter_proxy(ip):
     """
     Passed in IP addresses in HEADERS can be in a special format of multiple
     ips. Those comma separated IPs are passed from various proxies in the
     chain of request processing. The left-most being the original client.
     We only care about the first IP which came from the org. client.

     :param ip: ip string from headers
     """
     if ',' in ip:
         _ips = ip.split(',')
         _first_ip = _ips[0].strip()
         log.debug('Got multiple IPs %s, using %s', ','.join(_ips), _first_ip)
         return _first_ip
     return ip

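For illustration, a quick sketch of the proxy-chain handling above; the left-most entry of a comma-separated header is taken as the original client (all IP values here are made up):

    assert _filter_proxy('203.0.113.7, 10.0.0.1, 10.0.0.2') == '203.0.113.7'
    assert _filter_proxy('203.0.113.7') == '203.0.113.7'
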
 def _filter_port(ip):
     """
     Removes a port from ip, there are 4 main cases to handle here.
     - ipv4 eg. 127.0.0.1
     - ipv6 eg. ::1
     - ipv4+port eg. 127.0.0.1:8080
     - ipv6+port eg. [::1]:8080

     :param ip:
     """
     def is_ipv6(ip_addr):
         if hasattr(socket, 'inet_pton'):
             try:
                 socket.inet_pton(socket.AF_INET6, ip_addr)
             except socket.error:
                 return False
         else:
             # fallback to ipaddress
             try:
                 ipaddress.IPv6Address(safe_str(ip_addr))
             except Exception:
                 return False
         return True

     if ':' not in ip:  # must be ipv4 pure ip
         return ip

     if '[' in ip and ']' in ip:  # ipv6 with port
         return ip.split(']')[0][1:].lower()

     # must be ipv6 or ipv4 with port
     if is_ipv6(ip):
         return ip
     else:
         ip, _port = ip.split(':')[:2]  # means ipv4+port
         return ip

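The four cases from the docstring, as a quick sanity check (values are illustrative):

    assert _filter_port('127.0.0.1') == '127.0.0.1'        # ipv4
    assert _filter_port('::1') == '::1'                    # ipv6
    assert _filter_port('127.0.0.1:8080') == '127.0.0.1'   # ipv4+port
    assert _filter_port('[::1]:8080') == '::1'             # ipv6+port
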
 def get_ip_addr(environ):
     proxy_key = 'HTTP_X_REAL_IP'
     proxy_key2 = 'HTTP_X_FORWARDED_FOR'
     def_key = 'REMOTE_ADDR'
-    _filters = lambda x: _filter_port(_filter_proxy(x))
+
+    def ip_filters(ip_):
+        return _filter_port(_filter_proxy(ip_))

     ip = environ.get(proxy_key)
     if ip:
-        return _filters(ip)
+        return ip_filters(ip)

     ip = environ.get(proxy_key2)
     if ip:
-        return _filters(ip)
+        return ip_filters(ip)

     ip = environ.get(def_key, '0.0.0.0')
-    return _filters(ip)
+    return ip_filters(ip)

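A minimal sketch of the header precedence in get_ip_addr, using a hand-built WSGI environ (all values hypothetical): X-Real-IP wins over X-Forwarded-For, which wins over REMOTE_ADDR:

    environ = {
        'REMOTE_ADDR': '10.0.0.2',
        'HTTP_X_FORWARDED_FOR': '203.0.113.7, 10.0.0.1',
    }
    assert get_ip_addr(environ) == '203.0.113.7'
    assert get_ip_addr({}) == '0.0.0.0'  # falls back to the default
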
 def get_server_ip_addr(environ, log_errors=True):
     hostname = environ.get('SERVER_NAME')
     try:
         return socket.gethostbyname(hostname)
     except Exception as e:
         if log_errors:
             # in some cases this lookup is not possible, and we don't want to
             # make it an exception in logs
             log.exception('Could not retrieve server ip address: %s', e)
         return hostname


 def get_server_port(environ):
     return environ.get('SERVER_PORT')


-def get_access_path(environ):
-    path = environ.get('PATH_INFO')
-    org_req = environ.get('pylons.original_request')
-    if org_req:
-        path = org_req.environ.get('PATH_INFO')
-    return path
-
-
 def get_user_agent(environ):
     return environ.get('HTTP_USER_AGENT')


 def vcs_operation_context(
         environ, repo_name, username, action, scm, check_locking=True,
         is_shadow_repo=False, check_branch_perms=False, detect_force_push=False):
     """
     Generate the context for a vcs operation, e.g. push or pull.

     This context is passed over the layers so that hooks triggered by the
     vcs operation know details like the user, the user's IP address etc.

     :param check_locking: Allows to switch off the computation of the locking
         data. This serves mainly the need of the simplevcs middleware to be
         able to disable this for certain operations.

     """
     # Tri-state value: False: unlock, None: nothing, True: lock
     make_lock = None
     locked_by = [None, None, None]
     is_anonymous = username == User.DEFAULT_USER
     user = User.get_by_username(username)
     if not is_anonymous and check_locking:
         log.debug('Checking locking on repository "%s"', repo_name)
         repo = Repository.get_by_repo_name(repo_name)
         make_lock, __, locked_by = repo.get_locking_state(
             action, user.user_id)
     user_id = user.user_id
     settings_model = VcsSettingsModel(repo=repo_name)
     ui_settings = settings_model.get_ui_settings()

     # NOTE(marcink): This should be also in sync with
     # rhodecode/apps/ssh_support/lib/backends/base.py:update_environment scm_data
     store = [x for x in ui_settings if x.key == '/']
     repo_store = ''
     if store:
         repo_store = store[0].value

     scm_data = {
         'ip': get_ip_addr(environ),
         'username': username,
         'user_id': user_id,
         'action': action,
         'repository': repo_name,
         'scm': scm,
         'config': rhodecode.CONFIG['__file__'],
         'repo_store': repo_store,
         'make_lock': make_lock,
         'locked_by': locked_by,
         'server_url': utils2.get_server_url(environ),
         'user_agent': get_user_agent(environ),
         'hooks': get_enabled_hook_classes(ui_settings),
         'is_shadow_repo': is_shadow_repo,
         'detect_force_push': detect_force_push,
         'check_branch_perms': check_branch_perms,
     }
     return scm_data

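A hedged usage sketch; the repo and user names are hypothetical and assume existing database records:

    scm_data = vcs_operation_context(
        environ, repo_name='some-repo', username='admin', action='push',
        scm='git', check_locking=False)
    # scm_data['ip'], scm_data['hooks'], scm_data['make_lock'] etc. are what
    # the vcs hooks receive further down the stack
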
 class BasicAuth(AuthBasicAuthenticator):

     def __init__(self, realm, authfunc, registry, auth_http_code=None,
                  initial_call_detection=False, acl_repo_name=None, rc_realm=''):
+        super(BasicAuth, self).__init__(realm=realm, authfunc=authfunc)
         self.realm = realm
         self.rc_realm = rc_realm
         self.initial_call = initial_call_detection
         self.authfunc = authfunc
         self.registry = registry
         self.acl_repo_name = acl_repo_name
         self._rc_auth_http_code = auth_http_code

-    def _get_response_from_code(self, http_code):
+    def _get_response_from_code(self, http_code, fallback):
         try:
             return get_exception(safe_int(http_code))
         except Exception:
-            log.exception('Failed to fetch response for code %s', http_code)
-            return HTTPForbidden
+            log.exception('Failed to fetch response class for code %s, using fallback: %s', http_code, fallback)
+            return fallback

     def get_rc_realm(self):
         return safe_str(self.rc_realm)

     def build_authentication(self):
-        head = WWW_AUTHENTICATE.tuples('Basic realm="%s"' % self.realm)
+        header = [('WWW-Authenticate', f'Basic realm="{self.realm}"')]
+
+        # NOTE: the initial_call detection seems to be not working/not needed with latest Mercurial
+        # investigate if we still need it.
         if self._rc_auth_http_code and not self.initial_call:
             # return alternative HTTP code if alternative http return code
             # is specified in RhodeCode config, but ONLY if it's not the
             # FIRST call
-            custom_response_klass = self._get_response_from_code(
-                self._rc_auth_http_code)
-            return custom_response_klass(headers=head)
-        return HTTPUnauthorized(headers=head)
+            custom_response_klass = self._get_response_from_code(self._rc_auth_http_code, fallback=HTTPUnauthorized)
+            log.debug('Using custom response class: %s', custom_response_klass)
+            return custom_response_klass(headers=header)
+        return HTTPUnauthorized(headers=header)

     def authenticate(self, environ):
-        authorization = AUTHORIZATION(environ)
+        authorization = paste.httpheaders.AUTHORIZATION(environ)
         if not authorization:
             return self.build_authentication()
-        (authmeth, auth) = authorization.split(' ', 1)
-        if 'basic' != authmeth.lower():
+        (auth_meth, auth_creds_b64) = authorization.split(' ', 1)
+        if 'basic' != auth_meth.lower():
             return self.build_authentication()
-        auth = auth.strip().decode('base64')
-        _parts = auth.split(':', 1)
+
+        credentials = safe_str(base64.b64decode(auth_creds_b64.strip()))
+        _parts = credentials.split(':', 1)
         if len(_parts) == 2:
             username, password = _parts
             auth_data = self.authfunc(
                 username, password, environ, VCS_TYPE,
                 registry=self.registry, acl_repo_name=self.acl_repo_name)
             if auth_data:
                 return {'username': username, 'auth_data': auth_data}
             if username and password:
                 # we mark that we actually executed authentication once, at
                 # that point we can use the alternative auth code
                 self.initial_call = False

         return self.build_authentication()

     __call__ = authenticate

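The credential decoding is the python3-visible part of this change: the old `str.decode('base64')` codec is gone, so the header payload now goes through `base64.b64decode` plus `safe_str`. A small sketch (credentials are made up):

    raw = base64.b64encode(b'user:secret')  # what a client sends after 'Basic '
    assert safe_str(base64.b64decode(raw)) == 'user:secret'
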
 def calculate_version_hash(config):
     return sha1(
-        config.get('beaker.session.secret', '') +
-        rhodecode.__version__)[:8]
+        config.get(b'beaker.session.secret', b'') + ascii_bytes(rhodecode.__version__)
+    )[:8]

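Under python3 the sha1 input must be bytes, hence the `ascii_bytes` wrapper and the bytes config lookup. A hedged sketch of the call, assuming `hash_utils.sha1` returns a hex digest string:

    ver_hash = calculate_version_hash({b'beaker.session.secret': b's3cr3t'})
    assert len(ver_hash) == 8  # truncated hex digest
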
 def get_current_lang(request):
-    # NOTE(marcink): remove after pyramid move
-    try:
-        return translation.get_lang()[0]
-    except:
-        pass
-
     return getattr(request, '_LOCALE_', request.locale_name)


 def attach_context_attributes(context, request, user_id=None, is_api=None):
     """
     Attach variables into template context called `c`.
     """
     config = request.registry.settings

     rc_config = SettingsModel().get_all_settings(cache=True, from_request=False)
     context.rc_config = rc_config
     context.rhodecode_version = rhodecode.__version__
     context.rhodecode_edition = config.get('rhodecode.edition')
     context.rhodecode_edition_id = config.get('rhodecode.edition_id')
     # unique secret + version does not leak the version but keep consistency
     context.rhodecode_version_hash = calculate_version_hash(config)

     # Default language set for the incoming request
     context.language = get_current_lang(request)

     # Visual options
     context.visual = AttributeDict({})

     # DB stored Visual Items
     context.visual.show_public_icon = str2bool(
         rc_config.get('rhodecode_show_public_icon'))
     context.visual.show_private_icon = str2bool(
         rc_config.get('rhodecode_show_private_icon'))
     context.visual.stylify_metatags = str2bool(
         rc_config.get('rhodecode_stylify_metatags'))
     context.visual.dashboard_items = safe_int(
         rc_config.get('rhodecode_dashboard_items', 100))
     context.visual.admin_grid_items = safe_int(
         rc_config.get('rhodecode_admin_grid_items', 100))
     context.visual.show_revision_number = str2bool(
         rc_config.get('rhodecode_show_revision_number', True))
     context.visual.show_sha_length = safe_int(
         rc_config.get('rhodecode_show_sha_length', 100))
     context.visual.repository_fields = str2bool(
         rc_config.get('rhodecode_repository_fields'))
     context.visual.show_version = str2bool(
         rc_config.get('rhodecode_show_version'))
     context.visual.use_gravatar = str2bool(
         rc_config.get('rhodecode_use_gravatar'))
     context.visual.gravatar_url = rc_config.get('rhodecode_gravatar_url')
     context.visual.default_renderer = rc_config.get(
         'rhodecode_markup_renderer', 'rst')
     context.visual.comment_types = ChangesetComment.COMMENT_TYPES
     context.visual.rhodecode_support_url = \
         rc_config.get('rhodecode_support_url') or h.route_url('rhodecode_support')

     context.visual.affected_files_cut_off = 60

     context.pre_code = rc_config.get('rhodecode_pre_code')
     context.post_code = rc_config.get('rhodecode_post_code')
     context.rhodecode_name = rc_config.get('rhodecode_title')
     context.default_encodings = aslist(config.get('default_encoding'), sep=',')
     # if we have specified default_encoding in the request, it has more
     # priority
     if request.GET.get('default_encoding'):
         context.default_encodings.insert(0, request.GET.get('default_encoding'))
     context.clone_uri_tmpl = rc_config.get('rhodecode_clone_uri_tmpl')
     context.clone_uri_id_tmpl = rc_config.get('rhodecode_clone_uri_id_tmpl')
     context.clone_uri_ssh_tmpl = rc_config.get('rhodecode_clone_uri_ssh_tmpl')

     # INI stored
     context.labs_active = str2bool(
         config.get('labs_settings_active', 'false'))
     context.ssh_enabled = str2bool(
         config.get('ssh.generate_authorized_keyfile', 'false'))
     context.ssh_key_generator_enabled = str2bool(
         config.get('ssh.enable_ui_key_generator', 'true'))

     context.visual.allow_repo_location_change = str2bool(
         config.get('allow_repo_location_change', True))
     context.visual.allow_custom_hooks_settings = str2bool(
         config.get('allow_custom_hooks_settings', True))
     context.debug_style = str2bool(config.get('debug_style', False))

     context.rhodecode_instanceid = config.get('instance_id')

     context.visual.cut_off_limit_diff = safe_int(
-        config.get('cut_off_limit_diff'))
+        config.get('cut_off_limit_diff'), default=0)
     context.visual.cut_off_limit_file = safe_int(
-        config.get('cut_off_limit_file'))
+        config.get('cut_off_limit_file'), default=0)

     context.license = AttributeDict({})
     context.license.hide_license_info = str2bool(
         config.get('license.hide_license_info', False))

     # AppEnlight
     context.appenlight_enabled = config.get('appenlight', False)
     context.appenlight_api_public_key = config.get(
         'appenlight.api_public_key', '')
     context.appenlight_server_url = config.get('appenlight.server_url', '')

     diffmode = {
         "unified": "unified",
         "sideside": "sideside"
     }.get(request.GET.get('diffmode'))

     if is_api is not None:
         is_api = hasattr(request, 'rpc_user')
     session_attrs = {
         # defaults
         "clone_url_format": "http",
         "diffmode": "sideside",
         "license_fingerprint": request.session.get('license_fingerprint')
     }

     if not is_api:
         # don't access pyramid session for API calls
         if diffmode and diffmode != request.session.get('rc_user_session_attr.diffmode'):
             request.session['rc_user_session_attr.diffmode'] = diffmode

         # session settings per user

-        for k, v in request.session.items():
+        for k, v in list(request.session.items()):
             pref = 'rc_user_session_attr.'
             if k and k.startswith(pref):
                 k = k[len(pref):]
                 session_attrs[k] = v

     context.user_session_attrs = session_attrs

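The loop above strips the `rc_user_session_attr.` prefix, so a session key like `rc_user_session_attr.diffmode` surfaces as `session_attrs['diffmode']` and overrides the defaults. A sketch with hypothetical session content:

    # request.session == {'rc_user_session_attr.diffmode': 'unified', ...}
    # -> session_attrs == {'clone_url_format': 'http', 'diffmode': 'unified', ...}
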
     # JS template context
     context.template_context = {
         'repo_name': None,
         'repo_type': None,
         'repo_landing_commit': None,
         'rhodecode_user': {
             'username': None,
             'email': None,
             'notification_status': False
         },
         'session_attrs': session_attrs,
         'visual': {
             'default_renderer': None
         },
         'commit_data': {
             'commit_id': None
         },
         'pull_request_data': {'pull_request_id': None},
         'timeago': {
             'refresh_time': 120 * 1000,
             'cutoff_limit': 1000 * 60 * 60 * 24 * 7
         },
         'pyramid_dispatch': {

         },
         'extra': {'plugins': {}}
     }
     # END CONFIG VARS
     if is_api:
         csrf_token = None
     else:
         csrf_token = auth.get_csrf_token(session=request.session)

     context.csrf_token = csrf_token
-    context.backends = rhodecode.BACKENDS.keys()
+    context.backends = list(rhodecode.BACKENDS.keys())

     unread_count = 0
     user_bookmark_list = []
     if user_id:
         unread_count = NotificationModel().get_unread_cnt_for_user(user_id)
         user_bookmark_list = UserBookmark.get_bookmarks_for_user(user_id)
     context.unread_notifications = unread_count
     context.bookmark_items = user_bookmark_list

     # web case
     if hasattr(request, 'user'):
         context.auth_user = request.user
         context.rhodecode_user = request.user

     # api case
     if hasattr(request, 'rpc_user'):
         context.auth_user = request.rpc_user
         context.rhodecode_user = request.rpc_user

     # attach the whole call context to the request
     request.set_call_context(context)


 def get_auth_user(request):
     environ = request.environ
     session = request.session

     ip_addr = get_ip_addr(environ)

     # make sure that we update permissions each time we call controller
     _auth_token = (
         # ?auth_token=XXX
         request.GET.get('auth_token', '')
         # ?api_key=XXX !LEGACY
         or request.GET.get('api_key', '')
         # or headers....
         or request.headers.get('X-Rc-Auth-Token', '')
     )
     if not _auth_token and request.matchdict:
         url_auth_token = request.matchdict.get('_auth_token')
         _auth_token = url_auth_token
         if _auth_token:
             log.debug('Using URL extracted auth token `...%s`', _auth_token[-4:])

     if _auth_token:
         # when using API_KEY we assume user exists, and
         # doesn't need auth based on cookies.
         auth_user = AuthUser(api_key=_auth_token, ip_addr=ip_addr)
         authenticated = False
     else:
         cookie_store = CookieStoreWrapper(session.get('rhodecode_user'))
         try:
             auth_user = AuthUser(user_id=cookie_store.get('user_id', None),
                                  ip_addr=ip_addr)
         except UserCreationError as e:
             h.flash(e, 'error')
             # container auth or other auth functions that create users
             # on the fly can throw this exception signaling that there's
             # issue with user creation, explanation should be provided
             # in Exception itself. We then create a simple blank
             # AuthUser
             auth_user = AuthUser(ip_addr=ip_addr)

         # in case someone changes a password for user it triggers session
         # flush and forces a re-login
         if password_changed(auth_user, session):
             session.invalidate()
             cookie_store = CookieStoreWrapper(session.get('rhodecode_user'))
             auth_user = AuthUser(ip_addr=ip_addr)

         authenticated = cookie_store.get('is_authenticated')

     if not auth_user.is_authenticated and auth_user.is_user_object:
         # user is not authenticated and not empty
         auth_user.set_authenticated(authenticated)

     return auth_user, _auth_token


 def h_filter(s):
     """
     Custom filter for Mako templates. Mako by standard uses `markupsafe.escape`
     we wrap this with additional functionality that converts None to empty
     strings
     """
     if s is None:
         return markupsafe.Markup()
     return markupsafe.escape(s)

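A quick sketch of the None handling versus plain `markupsafe.escape`:

    assert h_filter(None) == markupsafe.Markup('')
    assert h_filter('<b>') == markupsafe.Markup('&lt;b&gt;')
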
 def add_events_routes(config):
     """
     Adds routing that can be used in events. Because some events are triggered
     outside of pyramid context, we need to bootstrap request with some
     routing registered
     """

     from rhodecode.apps._base import ADMIN_PREFIX

     config.add_route(name='home', pattern='/')
     config.add_route(name='main_page_repos_data', pattern='/_home_repos')
     config.add_route(name='main_page_repo_groups_data', pattern='/_home_repo_groups')

     config.add_route(name='login', pattern=ADMIN_PREFIX + '/login')
     config.add_route(name='logout', pattern=ADMIN_PREFIX + '/logout')
     config.add_route(name='repo_summary', pattern='/{repo_name}')
     config.add_route(name='repo_summary_explicit', pattern='/{repo_name}/summary')
     config.add_route(name='repo_group_home', pattern='/{repo_group_name}')

     config.add_route(name='pullrequest_show',
                      pattern='/{repo_name}/pull-request/{pull_request_id}')
     config.add_route(name='pull_requests_global',
                      pattern='/pull-request/{pull_request_id}')

     config.add_route(name='repo_commit',
                      pattern='/{repo_name}/changeset/{commit_id}')
     config.add_route(name='repo_files',
                      pattern='/{repo_name}/files/{commit_id}/{f_path}')

     config.add_route(name='hovercard_user',
                      pattern='/_hovercard/user/{user_id}')

     config.add_route(name='hovercard_user_group',
                      pattern='/_hovercard/user_group/{user_group_id}')

     config.add_route(name='hovercard_pull_request',
                      pattern='/_hovercard/pull_request/{pull_request_id}')

     config.add_route(name='hovercard_repo_commit',
                      pattern='/_hovercard/commit/{repo_name}/{commit_id}')


 def bootstrap_config(request, registry_name='RcTestRegistry'):
     import pyramid.testing
     registry = pyramid.testing.Registry(registry_name)

     config = pyramid.testing.setUp(registry=registry, request=request)

     # allow pyramid lookup in testing
     config.include('pyramid_mako')
     config.include('rhodecode.lib.rc_beaker')
     config.include('rhodecode.lib.rc_cache')
-
+    config.include('rhodecode.lib.rc_cache.archive_cache')
     add_events_routes(config)

     return config


 def bootstrap_request(**kwargs):
     """
     Returns a thin version of Request Object that is used in non-web context like testing/celery
     """

     import pyramid.testing
     from rhodecode.lib.request import ThinRequest as _ThinRequest

     class ThinRequest(_ThinRequest):
         application_url = kwargs.pop('application_url', 'http://example.com')
         host = kwargs.pop('host', 'example.com:80')
         domain = kwargs.pop('domain', 'example.com')

     class ThinSession(pyramid.testing.DummySession):
         def save(*arg, **kw):
             pass

     request = ThinRequest(**kwargs)
     request.session = ThinSession()

     return request

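Putting the two bootstrap helpers together, e.g. in a celery task or test (a sketch; in a real run the included modules still need their settings configured):

    request = bootstrap_request()            # ThinRequest with the defaults above
    config = bootstrap_config(request)       # pyramid.testing registry + event routes
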
@@ -1,248 +1,251 b''

 # Copyright (C) 2010-2020 RhodeCode GmbH
 #
 # This program is free software: you can redistribute it and/or modify
 # it under the terms of the GNU Affero General Public License, version 3
 # (only), as published by the Free Software Foundation.
 #
 # This program is distributed in the hope that it will be useful,
 # but WITHOUT ANY WARRANTY; without even the implied warranty of
 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 # GNU General Public License for more details.
 #
 # You should have received a copy of the GNU Affero General Public License
 # along with this program. If not, see <http://www.gnu.org/licenses/>.
 #
 # This program is dual-licensed. If you wish to learn more about the
 # RhodeCode Enterprise Edition, including its added features, Support services,
 # and proprietary license terms, please see https://rhodecode.com/licenses/

 """caching_query.py

 Represent functions and classes
 which allow the usage of Dogpile caching with SQLAlchemy.
 Introduces a query option called FromCache.

 .. versionchanged:: 1.4  the caching approach has been altered to work
    based on a session event.


 The three new concepts introduced here are:

 * ORMCache - an extension for an ORM :class:`.Session`
   retrieves results in/from dogpile.cache.
 * FromCache - a query option that establishes caching
   parameters on a Query
 * RelationshipCache - a variant of FromCache which is specific
   to a query invoked during a lazy load.

 The rest of what's here are standard SQLAlchemy and
 dogpile.cache constructs.

 """
 from dogpile.cache.api import NO_VALUE

 from sqlalchemy import event
 from sqlalchemy.orm import loading
 from sqlalchemy.orm.interfaces import UserDefinedOption


 DEFAULT_REGION = "sql_cache_short"


 class ORMCache:

     """An add-on for an ORM :class:`.Session` optionally loads full results
     from a dogpile cache region.

         cache = ORMCache(regions={})
         cache.listen_on_session(Session)

     """

     def __init__(self, regions):
         self.cache_regions = regions or self._get_region()
         self._statement_cache = {}

     @classmethod
     def _get_region(cls):
         from rhodecode.lib.rc_cache import region_meta
         return region_meta.dogpile_cache_regions

     def listen_on_session(self, session_factory):
         event.listen(session_factory, "do_orm_execute", self._do_orm_execute)

     def _do_orm_execute(self, orm_context):
-
         for opt in orm_context.user_defined_options:
             if isinstance(opt, RelationshipCache):
                 opt = opt._process_orm_context(orm_context)
                 if opt is None:
                     continue

             if isinstance(opt, FromCache):
                 dogpile_region = self.cache_regions[opt.region]

+                if dogpile_region.expiration_time <= 0:
+                    # don't cache 0 time expiration cache
+                    continue
+
                 if opt.cache_key:
                     our_cache_key = f'SQL_CACHE_{opt.cache_key}'
                 else:
                     our_cache_key = opt._generate_cache_key(
                         orm_context.statement, orm_context.parameters, self
                     )

                 if opt.ignore_expiration:
                     cached_value = dogpile_region.get(
                         our_cache_key,
                         expiration_time=opt.expiration_time,
                         ignore_expiration=opt.ignore_expiration,
                     )
                 else:

                     def createfunc():
                         return orm_context.invoke_statement().freeze()

                     cached_value = dogpile_region.get_or_create(
                         our_cache_key,
                         createfunc,
                         expiration_time=opt.expiration_time,
                     )

                 if cached_value is NO_VALUE:
                     # keyerror? this is bigger than a keyerror...
                     raise KeyError()

                 orm_result = loading.merge_frozen_result(
                     orm_context.session,
                     orm_context.statement,
                     cached_value,
                     load=False,
                 )
                 return orm_result()

             else:
                 return None

     def invalidate(self, statement, parameters, opt):
         """Invalidate the cache value represented by a statement."""

         statement = statement.__clause_element__()

         dogpile_region = self.cache_regions[opt.region]

         cache_key = opt._generate_cache_key(statement, parameters, self)

         dogpile_region.delete(cache_key)


 class FromCache(UserDefinedOption):
     """Specifies that a Query should load results from a cache."""

     propagate_to_loaders = False

     def __init__(
         self,
         region=DEFAULT_REGION,
         cache_key=None,
         expiration_time=None,
         ignore_expiration=False,
     ):
         """Construct a new FromCache.

         :param region: the cache region. Should be a
          region configured in the dictionary of dogpile
          regions.

         :param cache_key: optional. A string cache key
          that will serve as the key to the query. Use this
          if your query has a huge amount of parameters (such
          as when using in_()) which correspond more simply to
          some other identifier.

         """
         self.region = region
         self.cache_key = cache_key
         self.expiration_time = expiration_time
         self.ignore_expiration = ignore_expiration

     # this is not needed as of SQLAlchemy 1.4.28;
     # UserDefinedOption classes no longer participate in the SQL
     # compilation cache key
     def _gen_cache_key(self, anon_map, bindparams):
         return None

     def _generate_cache_key(self, statement, parameters, orm_cache):
         """generate a cache key with which to key the results of a statement.

         This leverages the use of the SQL compilation cache key which is
         repurposed as a SQL results key.

         """
         statement_cache_key = statement._generate_cache_key()

         key = statement_cache_key.to_offline_string(
             orm_cache._statement_cache, statement, parameters
         ) + repr(self.cache_key)
         # print("here's our key...%s" % key)
         return key

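A hedged wiring example for the pieces above, assuming a dogpile region named `sql_cache_short` is configured and `Session`/`User` come from the application:

    cache = ORMCache(regions=None)   # None falls back to rhodecode's region_meta regions
    cache.listen_on_session(Session)

    # subsequent identical statements are served from the cache region
    users = Session().query(User).options(
        FromCache(region='sql_cache_short', cache_key='all_users')).all()
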
 class RelationshipCache(FromCache):
     """Specifies that a Query as called within a "lazy load"
     should load results from a cache."""

     propagate_to_loaders = True

     def __init__(
         self,
         attribute,
         region=DEFAULT_REGION,
         cache_key=None,
         expiration_time=None,
         ignore_expiration=False,
     ):
         """Construct a new RelationshipCache.

         :param attribute: A Class.attribute which
          indicates a particular class relationship() whose
          lazy loader should be pulled from the cache.

         :param region: name of the cache region.

         :param cache_key: optional. A string cache key
          that will serve as the key to the query, bypassing
          the usual means of forming a key from the Query itself.

         """
         self.region = region
         self.cache_key = cache_key
         self.expiration_time = expiration_time
         self.ignore_expiration = ignore_expiration
         self._relationship_options = {
             (attribute.property.parent.class_, attribute.property.key): self
         }

     def _process_orm_context(self, orm_context):
         current_path = orm_context.loader_strategy_path

         if current_path:
             mapper, prop = current_path[-2:]
             key = prop.key

             for cls in mapper.class_.__mro__:
                 if (cls, key) in self._relationship_options:
                     relationship_option = self._relationship_options[
                         (cls, key)
                     ]
                     return relationship_option

     def and_(self, option):
         """Chain another RelationshipCache option to this one.

         While many RelationshipCache objects can be specified on a single
         Query separately, chaining them together allows for a more efficient
         lookup during load.

         """
         self._relationship_options.update(option._relationship_options)
         return self

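Chaining via `and_` merges the `_relationship_options` dicts so a single option covers several lazy-loaded relationships; a sketch with hypothetical mapped attributes:

    opt = RelationshipCache(User.user_emails).and_(
        RelationshipCache(User.user_ip_map))
    user = Session().query(User).options(opt).first()
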
@@ -1,372 +1,372 b''


 # Copyright (C) 2016-2020 RhodeCode GmbH
 #
 # This program is free software: you can redistribute it and/or modify
 # it under the terms of the GNU Affero General Public License, version 3
 # (only), as published by the Free Software Foundation.
 #
 # This program is distributed in the hope that it will be useful,
 # but WITHOUT ANY WARRANTY; without even the implied warranty of
 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 # GNU General Public License for more details.
 #
 # You should have received a copy of the GNU Affero General Public License
 # along with this program. If not, see <http://www.gnu.org/licenses/>.
 #
 # This program is dual-licensed. If you wish to learn more about the
 # RhodeCode Enterprise Edition, including its added features, Support services,
 # and proprietary license terms, please see https://rhodecode.com/licenses/

 import os
 import itsdangerous
 import logging
 import requests
 import datetime

 from dogpile.util.readwrite_lock import ReadWriteMutex
-from pyramid.threadlocal import get_current_registry

 import rhodecode.lib.helpers as h
 from rhodecode.lib.auth import HasRepoPermissionAny
 from rhodecode.lib.ext_json import json
 from rhodecode.model.db import User
 from rhodecode.lib.str_utils import ascii_str
 from rhodecode.lib.hash_utils import sha1_safe

 log = logging.getLogger(__name__)

 LOCK = ReadWriteMutex()

 USER_STATE_PUBLIC_KEYS = [
     'id', 'username', 'first_name', 'last_name',
     'icon_link', 'display_name', 'display_link']


 class ChannelstreamException(Exception):
     pass


 class ChannelstreamConnectionException(ChannelstreamException):
     pass


 class ChannelstreamPermissionException(ChannelstreamException):
     pass


 def get_channelstream_server_url(config, endpoint):
     return 'http://{}{}'.format(config['server'], endpoint)
60
59
61
60
62 def channelstream_request(config, payload, endpoint, raise_exc=True):
61 def channelstream_request(config, payload, endpoint, raise_exc=True):
63 signer = itsdangerous.TimestampSigner(config['secret'])
62 signer = itsdangerous.TimestampSigner(config['secret'])
64 sig_for_server = signer.sign(endpoint)
63 sig_for_server = signer.sign(endpoint)
65 secret_headers = {'x-channelstream-secret': sig_for_server,
64 secret_headers = {'x-channelstream-secret': sig_for_server,
66 'x-channelstream-endpoint': endpoint,
65 'x-channelstream-endpoint': endpoint,
67 'Content-Type': 'application/json'}
66 'Content-Type': 'application/json'}
68 req_url = get_channelstream_server_url(config, endpoint)
67 req_url = get_channelstream_server_url(config, endpoint)
69
68
70 log.debug('Sending a channelstream request to endpoint: `%s`', req_url)
69 log.debug('Sending a channelstream request to endpoint: `%s`', req_url)
71 response = None
70 response = None
72 try:
71 try:
73 response = requests.post(req_url, data=json.dumps(payload),
72 response = requests.post(req_url, data=json.dumps(payload),
74 headers=secret_headers).json()
73 headers=secret_headers).json()
75 except requests.ConnectionError:
74 except requests.ConnectionError:
76 log.exception('ConnectionError occurred for endpoint %s', req_url)
75 log.exception('ConnectionError occurred for endpoint %s', req_url)
77 if raise_exc:
76 if raise_exc:
78 raise ChannelstreamConnectionException(req_url)
77 raise ChannelstreamConnectionException(req_url)
79 except Exception:
78 except Exception:
80 log.exception('Exception related to Channelstream happened')
79 log.exception('Exception related to Channelstream happened')
81 if raise_exc:
80 if raise_exc:
82 raise ChannelstreamConnectionException()
81 raise ChannelstreamConnectionException()
83 log.debug('Got channelstream response: %s', response)
82 log.debug('Got channelstream response: %s', response)
84 return response
83 return response
85
84
86
85
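A rough illustration of the signed-request flow above; the server address, secret, and payload are made-up values:

    config = {'server': '127.0.0.1:8000', 'secret': 'channelstream-secret'}
    payload = [{'type': 'message', 'channel': 'broadcast',
                'message': {'text': 'hello'}}]
    # returns the decoded JSON reply, or None on failure when raise_exc=False
    reply = channelstream_request(config, payload, '/message', raise_exc=False)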
87 def get_user_data(user_id):
86 def get_user_data(user_id):
88 user = User.get(user_id)
87 user = User.get(user_id)
89 return {
88 return {
90 'id': user.user_id,
89 'id': user.user_id,
91 'username': user.username,
90 'username': user.username,
92 'first_name': user.first_name,
91 'first_name': user.first_name,
93 'last_name': user.last_name,
92 'last_name': user.last_name,
94 'icon_link': h.gravatar_url(user.email, 60),
93 'icon_link': h.gravatar_url(user.email, 60),
95 'display_name': h.person(user, 'username_or_name_or_email'),
94 'display_name': h.person(user, 'username_or_name_or_email'),
96 'display_link': h.link_to_user(user),
95 'display_link': h.link_to_user(user),
97 'notifications': user.user_data.get('notification_status', True)
96 'notifications': user.user_data.get('notification_status', True)
98 }
97 }
99
98
100
99
101 def broadcast_validator(channel_name):
100 def broadcast_validator(channel_name):
102 """ checks if user can access the broadcast channel """
101 """ checks if user can access the broadcast channel """
103 if channel_name == 'broadcast':
102 if channel_name == 'broadcast':
104 return True
103 return True
105
104
106
105
107 def repo_validator(channel_name):
106 def repo_validator(channel_name):
108 """ checks if user can access the broadcast channel """
107 """ checks if user can access the broadcast channel """
109 channel_prefix = '/repo$'
108 channel_prefix = '/repo$'
110 if channel_name.startswith(channel_prefix):
109 if channel_name.startswith(channel_prefix):
111 elements = channel_name[len(channel_prefix):].split('$')
110 elements = channel_name[len(channel_prefix):].split('$')
112 repo_name = elements[0]
111 repo_name = elements[0]
113 can_access = HasRepoPermissionAny(
112 can_access = HasRepoPermissionAny(
114 'repository.read',
113 'repository.read',
115 'repository.write',
114 'repository.write',
116 'repository.admin')(repo_name)
115 'repository.admin')(repo_name)
117 log.debug(
116 log.debug(
118 'permission check for %s channel resulted in %s',
117 'permission check for %s channel resulted in %s',
119 repo_name, can_access)
118 repo_name, can_access)
120 if can_access:
119 if can_access:
121 return True
120 return True
122 return False
121 return False
123
122
124
123
125 def check_channel_permissions(channels, plugin_validators, should_raise=True):
124 def check_channel_permissions(channels, plugin_validators, should_raise=True):
126 valid_channels = []
125 valid_channels = []
127
126
128 validators = [broadcast_validator, repo_validator]
127 validators = [broadcast_validator, repo_validator]
129 if plugin_validators:
128 if plugin_validators:
130 validators.extend(plugin_validators)
129 validators.extend(plugin_validators)
131 for channel_name in channels:
130 for channel_name in channels:
132 is_valid = False
131 is_valid = False
133 for validator in validators:
132 for validator in validators:
134 if validator(channel_name):
133 if validator(channel_name):
135 is_valid = True
134 is_valid = True
136 break
135 break
137 if is_valid:
136 if is_valid:
138 valid_channels.append(channel_name)
137 valid_channels.append(channel_name)
139 else:
138 else:
140 if should_raise:
139 if should_raise:
141 raise ChannelstreamPermissionException()
140 raise ChannelstreamPermissionException()
142 return valid_channels
141 return valid_channels
143
142
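A quick sketch of the filtering behaviour (channel names are illustrative; with should_raise=False, invalid channels are dropped instead of raising):

    requested = ['broadcast', '/repo$acme/backend$/pr/7', '/private$x']
    allowed = check_channel_permissions(requested, plugin_validators=None,
                                        should_raise=False)
    # keeps 'broadcast' and, permissions permitting, the '/repo$...$' channel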
144
143
145 def get_channels_info(config, channels):
144 def get_channels_info(config, channels):
146 payload = {'channels': channels}
145 payload = {'channels': channels}
147 # gather persistence info
146 # gather persistence info
148 return channelstream_request(config, payload, '/info')
147 return channelstream_request(config, payload, '/info')
149
148
150
149
151 def parse_channels_info(info_result, include_channel_info=None):
150 def parse_channels_info(info_result, include_channel_info=None):
152 """
151 """
153 Returns data that contains only secure information that can be
152 Returns data that contains only secure information that can be
154 presented to clients
153 presented to clients
155 """
154 """
156 include_channel_info = include_channel_info or []
155 include_channel_info = include_channel_info or []
157
156
158 user_state_dict = {}
157 user_state_dict = {}
159 for userinfo in info_result['users']:
158 for userinfo in info_result['users']:
160 user_state_dict[userinfo['user']] = {
159 user_state_dict[userinfo['user']] = {
161 k: v for k, v in list(userinfo['state'].items())
160 k: v for k, v in list(userinfo['state'].items())
162 if k in USER_STATE_PUBLIC_KEYS
161 if k in USER_STATE_PUBLIC_KEYS
163 }
162 }
164
163
165 channels_info = {}
164 channels_info = {}
166
165
167 for c_name, c_info in list(info_result['channels'].items()):
166 for c_name, c_info in list(info_result['channels'].items()):
168 if c_name not in include_channel_info:
167 if c_name not in include_channel_info:
169 continue
168 continue
170 connected_list = []
169 connected_list = []
171 for username in c_info['users']:
170 for username in c_info['users']:
172 connected_list.append({
171 connected_list.append({
173 'user': username,
172 'user': username,
174 'state': user_state_dict[username]
173 'state': user_state_dict[username]
175 })
174 })
176 channels_info[c_name] = {'users': connected_list,
175 channels_info[c_name] = {'users': connected_list,
177 'history': c_info['history']}
176 'history': c_info['history']}
178
177
179 return channels_info
178 return channels_info
180
179
181
180
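The expected shape of info_result, sketched from the code above with made-up values; only USER_STATE_PUBLIC_KEYS survive in the returned state dicts:

    info_result = {
        'users': [{'user': 'bob',
                   'state': {'id': 2, 'username': 'bob', 'secret': 'hidden'}}],
        'channels': {'broadcast': {'users': ['bob'], 'history': []}},
    }
    info = parse_channels_info(info_result, include_channel_info=['broadcast'])
    # info['broadcast']['users'][0]['state'] == {'id': 2, 'username': 'bob'}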
182 def log_filepath(history_location, channel_name):
181 def log_filepath(history_location, channel_name):
183
182
184 channel_hash = ascii_str(sha1_safe(channel_name))
183 channel_hash = ascii_str(sha1_safe(channel_name))
185 filename = f'{channel_hash}.log'
184 filename = f'{channel_hash}.log'
186 filepath = os.path.join(history_location, filename)
185 filepath = os.path.join(history_location, filename)
187 return filepath
186 return filepath
188
187
189
188
190 def read_history(history_location, channel_name):
189 def read_history(history_location, channel_name):
191 filepath = log_filepath(history_location, channel_name)
190 filepath = log_filepath(history_location, channel_name)
192 if not os.path.exists(filepath):
191 if not os.path.exists(filepath):
193 return []
192 return []
194 history_lines_limit = -100  # tail: keep only the last 100 entries
193 history_lines_limit = -100  # tail: keep only the last 100 entries
195 history = []
194 history = []
196 with open(filepath, 'rb') as f:
195 with open(filepath, 'rb') as f:
197 for line in f.readlines()[history_lines_limit:]:
196 for line in f.readlines()[history_lines_limit:]:
198 try:
197 try:
199 history.append(json.loads(line))
198 history.append(json.loads(line))
200 except Exception:
199 except Exception:
201 log.exception('Failed to load history')
200 log.exception('Failed to load history')
202 return history
201 return history
203
202
204
203
205 def update_history_from_logs(config, channels, payload):
204 def update_history_from_logs(config, channels, payload):
206 history_location = config.get('history.location')
205 history_location = config.get('history.location')
207 for channel in channels:
206 for channel in channels:
208 history = read_history(history_location, channel)
207 history = read_history(history_location, channel)
209 payload['channels_info'][channel]['history'] = history
208 payload['channels_info'][channel]['history'] = history
210
209
211
210
212 def write_history(config, message):
211 def write_history(config, message):
213 """ writes a message to a base64encoded filename """
212 """ writes a message to a base64encoded filename """
214 history_location = config.get('history.location')
213 history_location = config.get('history.location')
215 if not os.path.exists(history_location):
214 if not os.path.exists(history_location):
216 return
215 return
217 try:
216 try:
218 LOCK.acquire_write_lock()
217 LOCK.acquire_write_lock()
219 filepath = log_filepath(history_location, message['channel'])
218 filepath = log_filepath(history_location, message['channel'])
220 json_message = json.dumps(message)
219 json_message = json.dumps(message)
221 with open(filepath, 'ab') as f:
220 with open(filepath, 'ab') as f:
222 f.write(json_message)
221 f.write(json_message.encode('utf8'))  # file is opened in binary mode under py3
223 f.write('\n')
222 f.write(b'\n')
224 finally:
223 finally:
225 LOCK.release_write_lock()
224 LOCK.release_write_lock()
226
225
227
226
228 def get_connection_validators(registry):
227 def get_connection_validators(registry):
229 validators = []
228 validators = []
230 for k, config in list(registry.rhodecode_plugins.items()):
229 for k, config in list(registry.rhodecode_plugins.items()):
231 validator = config.get('channelstream', {}).get('connect_validator')
230 validator = config.get('channelstream', {}).get('connect_validator')
232 if validator:
231 if validator:
233 validators.append(validator)
232 validators.append(validator)
234 return validators
233 return validators
235
234
236
235
237 def get_channelstream_config(registry=None):
236 def get_channelstream_config(registry=None):
238 if not registry:
237 if not registry:
238 from pyramid.threadlocal import get_current_registry
239 registry = get_current_registry()
239 registry = get_current_registry()
240
240
241 rhodecode_plugins = getattr(registry, 'rhodecode_plugins', {})
241 rhodecode_plugins = getattr(registry, 'rhodecode_plugins', {})
242 channelstream_config = rhodecode_plugins.get('channelstream', {})
242 channelstream_config = rhodecode_plugins.get('channelstream', {})
243 return channelstream_config
243 return channelstream_config
244
244
245
245
246 def post_message(channel, message, username, registry=None):
246 def post_message(channel, message, username, registry=None):
247 channelstream_config = get_channelstream_config(registry)
247 channelstream_config = get_channelstream_config(registry)
248 if not channelstream_config.get('enabled'):
248 if not channelstream_config.get('enabled'):
249 return
249 return
250
250
251 message_obj = message
251 message_obj = message
252 if isinstance(message, str):
252 if isinstance(message, str):
253 message_obj = {
253 message_obj = {
254 'message': message,
254 'message': message,
255 'level': 'success',
255 'level': 'success',
256 'topic': '/notifications'
256 'topic': '/notifications'
257 }
257 }
258
258
259 log.debug('Channelstream: sending notification to channel %s', channel)
259 log.debug('Channelstream: sending notification to channel %s', channel)
260 payload = {
260 payload = {
261 'type': 'message',
261 'type': 'message',
262 'timestamp': datetime.datetime.utcnow(),
262 'timestamp': datetime.datetime.utcnow(),
263 'user': 'system',
263 'user': 'system',
264 'exclude_users': [username],
264 'exclude_users': [username],
265 'channel': channel,
265 'channel': channel,
266 'message': message_obj
266 'message': message_obj
267 }
267 }
268
268
269 try:
269 try:
270 return channelstream_request(
270 return channelstream_request(
271 channelstream_config, [payload], '/message',
271 channelstream_config, [payload], '/message',
272 raise_exc=False)
272 raise_exc=False)
273 except ChannelstreamException:
273 except ChannelstreamException:
274 log.exception('Failed to send channelstream data')
274 log.exception('Failed to send channelstream data')
275 raise
275 raise
276
276
277
277
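For example (channel and username are placeholders; passing a plain string wraps it into the default notification payload shown above):

    post_message('/repo$acme/backend$', 'Repository re-indexed', 'admin')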
278 def _reload_link(label):
278 def _reload_link(label):
279 return (
279 return (
280 '<a onclick="window.location.reload()">'
280 '<a onclick="window.location.reload()">'
281 '<strong>{}</strong>'
281 '<strong>{}</strong>'
282 '</a>'.format(label)
282 '</a>'.format(label)
283 )
283 )
284
284
285
285
286 def pr_channel(pull_request):
286 def pr_channel(pull_request):
287 repo_name = pull_request.target_repo.repo_name
287 repo_name = pull_request.target_repo.repo_name
288 pull_request_id = pull_request.pull_request_id
288 pull_request_id = pull_request.pull_request_id
289 channel = '/repo${}$/pr/{}'.format(repo_name, pull_request_id)
289 channel = '/repo${}$/pr/{}'.format(repo_name, pull_request_id)
290 log.debug('Getting pull-request channelstream broadcast channel: %s', channel)
290 log.debug('Getting pull-request channelstream broadcast channel: %s', channel)
291 return channel
291 return channel
292
292
293
293
294 def comment_channel(repo_name, commit_obj=None, pull_request_obj=None):
294 def comment_channel(repo_name, commit_obj=None, pull_request_obj=None):
295 channel = None
295 channel = None
296 if commit_obj:
296 if commit_obj:
297 channel = '/repo${}$/commit/{}'.format(
297 channel = '/repo${}$/commit/{}'.format(
298 repo_name, commit_obj.raw_id
298 repo_name, commit_obj.raw_id
299 )
299 )
300 elif pull_request_obj:
300 elif pull_request_obj:
301 channel = '/repo${}$/pr/{}'.format(
301 channel = '/repo${}$/pr/{}'.format(
302 repo_name, pull_request_obj.pull_request_id
302 repo_name, pull_request_obj.pull_request_id
303 )
303 )
304 log.debug('Getting comment channelstream broadcast channel: %s', channel)
304 log.debug('Getting comment channelstream broadcast channel: %s', channel)
305
305
306 return channel
306 return channel
307
307
308
308
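The resulting channel strings look like this (repo name and objects are placeholders):

    comment_channel('acme/backend', commit_obj=commit)
    # -> '/repo$acme/backend$/commit/<raw_id>'
    comment_channel('acme/backend', pull_request_obj=pull_request)
    # -> '/repo$acme/backend$/pr/<pull_request_id>'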
309 def pr_update_channelstream_push(request, pr_broadcast_channel, user, msg, **kwargs):
309 def pr_update_channelstream_push(request, pr_broadcast_channel, user, msg, **kwargs):
310 """
310 """
311 Channel push on pull request update
311 Channel push on pull request update
312 """
312 """
313 if not pr_broadcast_channel:
313 if not pr_broadcast_channel:
314 return
314 return
315
315
316 _ = request.translate
316 _ = request.translate
317
317
318 message = '{} {}'.format(
318 message = '{} {}'.format(
319 msg,
319 msg,
320 _reload_link(_(' Reload page to load changes')))
320 _reload_link(_(' Reload page to load changes')))
321
321
322 message_obj = {
322 message_obj = {
323 'message': message,
323 'message': message,
324 'level': 'success',
324 'level': 'success',
325 'topic': '/notifications'
325 'topic': '/notifications'
326 }
326 }
327
327
328 post_message(
328 post_message(
329 pr_broadcast_channel, message_obj, user.username,
329 pr_broadcast_channel, message_obj, user.username,
330 registry=request.registry)
330 registry=request.registry)
331
331
332
332
333 def comment_channelstream_push(request, comment_broadcast_channel, user, msg, **kwargs):
333 def comment_channelstream_push(request, comment_broadcast_channel, user, msg, **kwargs):
334 """
334 """
335 Channelstream push on comment action, on commit, or pull-request
335 Channelstream push on comment action, on commit, or pull-request
336 """
336 """
337 if not comment_broadcast_channel:
337 if not comment_broadcast_channel:
338 return
338 return
339
339
340 _ = request.translate
340 _ = request.translate
341
341
342 comment_data = kwargs.pop('comment_data', {})
342 comment_data = kwargs.pop('comment_data', {})
343 user_data = kwargs.pop('user_data', {})
343 user_data = kwargs.pop('user_data', {})
344 comment_id = list(comment_data.keys())[0] if comment_data else ''
344 comment_id = list(comment_data.keys())[0] if comment_data else ''
345
345
346 message = '<strong>{}</strong> {} #{}'.format(
346 message = '<strong>{}</strong> {} #{}'.format(
347 user.username,
347 user.username,
348 msg,
348 msg,
349 comment_id,
349 comment_id,
350 )
350 )
351
351
352 message_obj = {
352 message_obj = {
353 'message': message,
353 'message': message,
354 'level': 'success',
354 'level': 'success',
355 'topic': '/notifications'
355 'topic': '/notifications'
356 }
356 }
357
357
358 post_message(
358 post_message(
359 comment_broadcast_channel, message_obj, user.username,
359 comment_broadcast_channel, message_obj, user.username,
360 registry=request.registry)
360 registry=request.registry)
361
361
362 message_obj = {
362 message_obj = {
363 'message': None,
363 'message': None,
364 'user': user.username,
364 'user': user.username,
365 'comment_id': comment_id,
365 'comment_id': comment_id,
366 'comment_data': comment_data,
366 'comment_data': comment_data,
367 'user_data': user_data,
367 'user_data': user_data,
368 'topic': '/comment'
368 'topic': '/comment'
369 }
369 }
370 post_message(
370 post_message(
371 comment_broadcast_channel, message_obj, user.username,
371 comment_broadcast_channel, message_obj, user.username,
372 registry=request.registry)
372 registry=request.registry)
@@ -1,797 +1,819 @@
1
1
2
2
3 # Copyright (C) 2011-2020 RhodeCode GmbH
3 # Copyright (C) 2011-2020 RhodeCode GmbH
4 #
4 #
5 # This program is free software: you can redistribute it and/or modify
5 # This program is free software: you can redistribute it and/or modify
6 # it under the terms of the GNU Affero General Public License, version 3
6 # it under the terms of the GNU Affero General Public License, version 3
7 # (only), as published by the Free Software Foundation.
7 # (only), as published by the Free Software Foundation.
8 #
8 #
9 # This program is distributed in the hope that it will be useful,
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 # GNU General Public License for more details.
12 # GNU General Public License for more details.
13 #
13 #
14 # You should have received a copy of the GNU Affero General Public License
14 # You should have received a copy of the GNU Affero General Public License
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 #
16 #
17 # This program is dual-licensed. If you wish to learn more about the
17 # This program is dual-licensed. If you wish to learn more about the
18 # RhodeCode Enterprise Edition, including its added features, Support services,
18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20
20
21 import logging
21 import logging
22 import difflib
22 import difflib
23 from itertools import groupby
23 import itertools
24
24
25 from pygments import lex
25 from pygments import lex
26 from pygments.formatters.html import _get_ttype_class as pygment_token_class
26 from pygments.formatters.html import _get_ttype_class as pygment_token_class
27 from pygments.lexers.special import TextLexer, Token
27 from pygments.lexers.special import TextLexer, Token
28 from pygments.lexers import get_lexer_by_name
28 from pygments.lexers import get_lexer_by_name
29
29
30 from rhodecode.lib.helpers import (
30 from rhodecode.lib.helpers import (
31 get_lexer_for_filenode, html_escape, get_custom_lexer)
31 get_lexer_for_filenode, html_escape, get_custom_lexer)
32 from rhodecode.lib.utils2 import AttributeDict, StrictAttributeDict, safe_unicode
32 from rhodecode.lib.str_utils import safe_str
33 from rhodecode.lib.utils2 import AttributeDict, StrictAttributeDict
33 from rhodecode.lib.vcs.nodes import FileNode
34 from rhodecode.lib.vcs.nodes import FileNode
34 from rhodecode.lib.vcs.exceptions import VCSError, NodeDoesNotExistError
35 from rhodecode.lib.vcs.exceptions import NodeDoesNotExistError
35 from rhodecode.lib.diff_match_patch import diff_match_patch
36 from rhodecode.lib.diff_match_patch import diff_match_patch
36 from rhodecode.lib.diffs import LimitedDiffContainer, DEL_FILENODE, BIN_FILENODE
37 from rhodecode.lib.diffs import LimitedDiffContainer, DEL_FILENODE, BIN_FILENODE
37
38
38
39
39 plain_text_lexer = get_lexer_by_name(
40 plain_text_lexer = get_lexer_by_name(
40 'text', stripall=False, stripnl=False, ensurenl=False)
41 'text', stripall=False, stripnl=False, ensurenl=False)
41
42
42
43
43 log = logging.getLogger(__name__)
44 log = logging.getLogger(__name__)
44
45
45
46
46 def filenode_as_lines_tokens(filenode, lexer=None):
47 def filenode_as_lines_tokens(filenode, lexer=None):
47 org_lexer = lexer
48 org_lexer = lexer
48 lexer = lexer or get_lexer_for_filenode(filenode)
49 lexer = lexer or get_lexer_for_filenode(filenode)
49 log.debug('Generating file node pygment tokens for %s, %s, org_lexer:%s',
50 log.debug('Generating file node pygment tokens for %s, file=`%s`, org_lexer:%s',
50 lexer, filenode, org_lexer)
51 lexer, filenode, org_lexer)
51 content = filenode.content
52 content = filenode.str_content
52 tokens = tokenize_string(content, lexer)
53 tokens = tokenize_string(content, lexer)
53 lines = split_token_stream(tokens, content)
54 lines = split_token_stream(tokens, content)
54 rv = list(lines)
55 rv = list(lines)
55 return rv
56 return rv
56
57
57
58
58 def tokenize_string(content, lexer):
59 def tokenize_string(content, lexer):
59 """
60 """
60 Use pygments to tokenize some content based on a lexer
61 Use pygments to tokenize some content based on a lexer
61 ensuring all original newlines and whitespace are preserved
62 ensuring all original newlines and whitespace are preserved
62 """
63 """
63
64
64 lexer.stripall = False
65 lexer.stripall = False
65 lexer.stripnl = False
66 lexer.stripnl = False
66 lexer.ensurenl = False
67 lexer.ensurenl = False
67
68
69 # pygments needs to operate on str
70 str_content = safe_str(content)
71
68 if isinstance(lexer, TextLexer):
72 if isinstance(lexer, TextLexer):
69 lexed = [(Token.Text, content)]
73 # convert content to str here as well, since pygments does the same
74 # while tokenizing whenever no real lexer exists for the file type
75 lexed = [(Token.Text, str_content)]
70 else:
76 else:
71 lexed = lex(content, lexer)
77 lexed = lex(str_content, lexer)
72
78
73 for token_type, token_text in lexed:
79 for token_type, token_text in lexed:
74 yield pygment_token_class(token_type), token_text
80 yield pygment_token_class(token_type), token_text
75
81
76
82
77 def split_token_stream(tokens, content):
83 def split_token_stream(tokens, content):
78 """
84 """
79 Take a list of (TokenType, text) tuples and split them by a string
85 Take a stream of (TokenType, text) tuples and split it on newlines,
86 yielding one token list per line of content
80
81 split_token_stream([(TEXT, 'some\ntext'), (TEXT, 'more\n')])
87 split_token_stream([(TEXT, 'some\ntext'), (TEXT, 'more\n')], content)
82 [(TEXT, 'some'), (TEXT, 'text'),
88 [[(TEXT, 'some')], [(TEXT, 'text'), (TEXT, 'more')],
83 (TEXT, 'more'), (TEXT, 'text')]
89 [(TEXT, '')]]
84 """
90 """
85
91
86 token_buffer = []
92 token_buffer = []
93
87 for token_class, token_text in tokens:
94 for token_class, token_text in tokens:
95
96 # token_text, should be str
88 parts = token_text.split('\n')
97 parts = token_text.split('\n')
89 for part in parts[:-1]:
98 for part in parts[:-1]:
90 token_buffer.append((token_class, part))
99 token_buffer.append((token_class, part))
91 yield token_buffer
100 yield token_buffer
92 token_buffer = []
101 token_buffer = []
93
102
94 token_buffer.append((token_class, parts[-1]))
103 token_buffer.append((token_class, parts[-1]))
95
104
96 if token_buffer:
105 if token_buffer:
97 yield token_buffer
106 yield token_buffer
98 elif content:
107 elif content:
99 # this is a special case, we have the content, but tokenization didn't produce
108 # this is a special case, we have the content, but tokenization didn't produce
100 # any results. THis can happen if know file extensions like .css have some bogus
109 # any results. This can happen if known file extensions like .css have some bogus
101 # unicode content without any newline characters
110 # unicode content without any newline characters
102 yield [(pygment_token_class(Token.Text), content)]
111 yield [(pygment_token_class(Token.Text), content)]
103
112
104
113
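A quick sketch of the two helpers together, using the module's plain_text_lexer (whose token class renders as the empty string):

    content = 'some\ntext'
    tokens = tokenize_string(content, plain_text_lexer)
    lines = list(split_token_stream(tokens, content))
    # lines == [[('', 'some')], [('', 'text')]]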
105 def filenode_as_annotated_lines_tokens(filenode):
114 def filenode_as_annotated_lines_tokens(filenode):
106 """
115 """
107 Take a file node and return a list of annotations => lines, if no annotation
116 Take a file node and return a list of annotation => lines pairs; if no
108 is found, it will be None.
117 annotation is found for a group, the annotation is None.
109
118
110 eg:
119 eg:
111
120
112 [
121 [
113 (annotation1, [
122 (annotation1, [
114 (1, line1_tokens_list),
123 (1, line1_tokens_list),
115 (2, line2_tokens_list),
124 (2, line2_tokens_list),
116 ]),
125 ]),
117 (annotation2, [
126 (annotation2, [
118 (3, line1_tokens_list),
127 (3, line1_tokens_list),
119 ]),
128 ]),
120 (None, [
129 (None, [
121 (4, line1_tokens_list),
130 (4, line1_tokens_list),
122 ]),
131 ]),
123 (annotation1, [
132 (annotation1, [
124 (5, line1_tokens_list),
133 (5, line1_tokens_list),
125 (6, line2_tokens_list),
134 (6, line2_tokens_list),
126 ])
135 ])
127 ]
136 ]
128 """
137 """
129
138
130 commit_cache = {} # cache commit_getter lookups
139 commit_cache = {} # cache commit_getter lookups
131
140
132 def _get_annotation(commit_id, commit_getter):
141 def _get_annotation(commit_id, commit_getter):
133 if commit_id not in commit_cache:
142 if commit_id not in commit_cache:
134 commit_cache[commit_id] = commit_getter()
143 commit_cache[commit_id] = commit_getter()
135 return commit_cache[commit_id]
144 return commit_cache[commit_id]
136
145
137 annotation_lookup = {
146 annotation_lookup = {
138 line_no: _get_annotation(commit_id, commit_getter)
147 line_no: _get_annotation(commit_id, commit_getter)
139 for line_no, commit_id, commit_getter, line_content
148 for line_no, commit_id, commit_getter, line_content
140 in filenode.annotate
149 in filenode.annotate
141 }
150 }
142
151
143 annotations_lines = ((annotation_lookup.get(line_no), line_no, tokens)
152 annotations_lines = ((annotation_lookup.get(line_no), line_no, tokens)
144 for line_no, tokens
153 for line_no, tokens
145 in enumerate(filenode_as_lines_tokens(filenode), 1))
154 in enumerate(filenode_as_lines_tokens(filenode), 1))
146
155
147 grouped_annotations_lines = groupby(annotations_lines, lambda x: x[0])
156 grouped_annotations_lines = itertools.groupby(annotations_lines, lambda x: x[0])
148
157
149 for annotation, group in grouped_annotations_lines:
158 for annotation, group in grouped_annotations_lines:
150 yield (
159 yield (
151 annotation, [(line_no, tokens)
160 annotation, [(line_no, tokens)
152 for (_, line_no, tokens) in group]
161 for (_, line_no, tokens) in group]
153 )
162 )
154
163
155
164
156 def render_tokenstream(tokenstream):
165 def render_tokenstream(tokenstream):
157 result = []
166 result = []
158 for token_class, token_ops_texts in rollup_tokenstream(tokenstream):
167 for token_class, token_ops_texts in rollup_tokenstream(tokenstream):
159
168
160 if token_class:
169 if token_class:
161 result.append('<span class="%s">' % token_class)
170 result.append(f'<span class="{token_class}">')
162 else:
171 else:
163 result.append('<span>')
172 result.append('<span>')
164
173
165 for op_tag, token_text in token_ops_texts:
174 for op_tag, token_text in token_ops_texts:
166
175
167 if op_tag:
176 if op_tag:
168 result.append('<%s>' % op_tag)
177 result.append(f'<{op_tag}>')
169
178
170 # NOTE(marcink): in some cases of mixed encodings, we might run into
179 # NOTE(marcink): in some cases of mixed encodings, we might run into
171 # troubles in the html_escape, in this case we say unicode force on token_text
180 # troubles in the html_escape, in this case we say unicode force on token_text
172 # that would ensure "correct" data even with the cost of rendered
181 # that would ensure "correct" data even with the cost of rendered
173 try:
182 try:
174 escaped_text = html_escape(token_text)
183 escaped_text = html_escape(token_text)
175 except TypeError:
184 except TypeError:
176 escaped_text = html_escape(safe_unicode(token_text))
185 escaped_text = html_escape(safe_str(token_text))
177
186
178 # TODO: dan: investigate showing hidden characters like space/nl/tab
187 # TODO: dan: investigate showing hidden characters like space/nl/tab
179 # escaped_text = escaped_text.replace(' ', '<sp> </sp>')
188 # escaped_text = escaped_text.replace(' ', '<sp> </sp>')
180 # escaped_text = escaped_text.replace('\n', '<nl>\n</nl>')
189 # escaped_text = escaped_text.replace('\n', '<nl>\n</nl>')
181 # escaped_text = escaped_text.replace('\t', '<tab>\t</tab>')
190 # escaped_text = escaped_text.replace('\t', '<tab>\t</tab>')
182
191
183 result.append(escaped_text)
192 result.append(escaped_text)
184
193
185 if op_tag:
194 if op_tag:
186 result.append('</%s>' % op_tag)
195 result.append(f'</{op_tag}>')
187
196
188 result.append('</span>')
197 result.append('</span>')
189
198
190 html = ''.join(result)
199 html = ''.join(result)
191 return html
200 return html
192
201
193
202
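For instance, consecutive tokens of one class roll up into a single span:

    html = render_tokenstream([('k', 'def'), ('k', ' foo'), ('', ':')])
    # html == '<span class="k">def foo</span><span>:</span>'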
194 def rollup_tokenstream(tokenstream):
203 def rollup_tokenstream(tokenstream):
195 """
204 """
196 Group a token stream of the format:
205 Group a token stream of the format:
197
206
198 ('class', 'op', 'text')
207 ('class', 'op', 'text')
199 or
208 or
200 ('class', 'text')
209 ('class', 'text')
201
210
202 into
211 into
203
212
204 [('class1',
213 [('class1',
205 [('op1', 'text'),
214 [('op1', 'text'),
206 ('op2', 'text')]),
215 ('op2', 'text')]),
207 ('class2',
216 ('class2',
208 [('op3', 'text')])]
217 [('op3', 'text')])]
209
218
210 This is used to get the minimal tags necessary when
219 This is used to get the minimal tags necessary when
211 rendering to html eg for a token stream ie.
220 rendering to html, e.g. for a token stream such as:
212
221
213 <span class="A"><ins>he</ins>llo</span>
222 <span class="A"><ins>he</ins>llo</span>
214 vs
223 vs
215 <span class="A"><ins>he</ins></span><span class="A">llo</span>
224 <span class="A"><ins>he</ins></span><span class="A">llo</span>
216
225
217 If a 2 tuple is passed in, the output op will be an empty string.
226 If a 2 tuple is passed in, the output op will be an empty string.
218
227
219 eg:
228 eg:
220
229
221 >>> rollup_tokenstream([('classA', '', 'h'),
230 >>> rollup_tokenstream([('classA', '', 'h'),
222 ('classA', 'del', 'ell'),
231 ('classA', 'del', 'ell'),
223 ('classA', '', 'o'),
232 ('classA', '', 'o'),
224 ('classB', '', ' '),
233 ('classB', '', ' '),
225 ('classA', '', 'the'),
234 ('classA', '', 'the'),
226 ('classA', '', 're'),
235 ('classA', '', 're'),
227 ])
236 ])
228
237
229 [('classA', [('', 'h'), ('del', 'ell'), ('', 'o')],
238 [('classA', [('', 'h'), ('del', 'ell'), ('', 'o')]),
230 ('classB', [('', ' ')],
239 ('classB', [('', ' ')]),
231 ('classA', [('', 'there')]]
240 ('classA', [('', 'there')])]
232
241
233 """
242 """
234 if tokenstream and len(tokenstream[0]) == 2:
243 if tokenstream and len(tokenstream[0]) == 2:
235 tokenstream = ((t[0], '', t[1]) for t in tokenstream)
244 tokenstream = ((t[0], '', t[1]) for t in tokenstream)
236
245
237 result = []
246 result = []
238 for token_class, op_list in groupby(tokenstream, lambda t: t[0]):
247 for token_class, op_list in itertools.groupby(tokenstream, lambda t: t[0]):
239 ops = []
248 ops = []
240 for token_op, token_text_list in groupby(op_list, lambda o: o[1]):
249 for token_op, token_text_list in itertools.groupby(op_list, lambda o: o[1]):
241 text_buffer = []
250 text_buffer = []
242 for t_class, t_op, t_text in token_text_list:
251 for t_class, t_op, t_text in token_text_list:
243 text_buffer.append(t_text)
252 text_buffer.append(t_text)
253
244 ops.append((token_op, ''.join(text_buffer)))
254 ops.append((token_op, ''.join(text_buffer)))
245 result.append((token_class, ops))
255 result.append((token_class, ops))
246 return result
256 return result
247
257
248
258
249 def tokens_diff(old_tokens, new_tokens, use_diff_match_patch=True):
259 def tokens_diff(old_tokens, new_tokens, use_diff_match_patch=True):
250 """
260 """
251 Converts a list of (token_class, token_text) tuples to a list of
261 Converts a list of (token_class, token_text) tuples to a list of
252 (token_class, token_op, token_text) tuples where token_op is one of
262 (token_class, token_op, token_text) tuples where token_op is one of
253 ('ins', 'del', '')
263 ('ins', 'del', '')
254
264
255 :param old_tokens: list of (token_class, token_text) tuples of old line
265 :param old_tokens: list of (token_class, token_text) tuples of old line
256 :param new_tokens: list of (token_class, token_text) tuples of new line
266 :param new_tokens: list of (token_class, token_text) tuples of new line
257 :param use_diff_match_patch: boolean, will use google's diff match patch
267 :param use_diff_match_patch: boolean, will use google's diff match patch
258 library which has options to 'smooth' out the character by character
268 library which has options to 'smooth' out the character by character
259 differences making nicer ins/del blocks
269 differences making nicer ins/del blocks
260 """
270 """
261
271
262 old_tokens_result = []
272 old_tokens_result = []
263 new_tokens_result = []
273 new_tokens_result = []
264
274
265 similarity = difflib.SequenceMatcher(None,
275 def int_convert(val):
276 if isinstance(val, int):  # token text can occasionally be an int; coerce to str for safe joining
277 return str(val)
278 return val
279
280 similarity = difflib.SequenceMatcher(
281 None,
266 ''.join(token_text for token_class, token_text in old_tokens),
282 ''.join(token_text for token_class, token_text in old_tokens),
267 ''.join(token_text for token_class, token_text in new_tokens)
283 ''.join(token_text for token_class, token_text in new_tokens)
268 ).ratio()
284 ).ratio()
269
285
270 if similarity < 0.6: # return, the blocks are too different
286 if similarity < 0.6: # return, the blocks are too different
271 for token_class, token_text in old_tokens:
287 for token_class, token_text in old_tokens:
272 old_tokens_result.append((token_class, '', token_text))
288 old_tokens_result.append((token_class, '', token_text))
273 for token_class, token_text in new_tokens:
289 for token_class, token_text in new_tokens:
274 new_tokens_result.append((token_class, '', token_text))
290 new_tokens_result.append((token_class, '', token_text))
275 return old_tokens_result, new_tokens_result, similarity
291 return old_tokens_result, new_tokens_result, similarity
276
292
277 token_sequence_matcher = difflib.SequenceMatcher(None,
293 token_sequence_matcher = difflib.SequenceMatcher(
294 None,
278 [x[1] for x in old_tokens],
295 [x[1] for x in old_tokens],
279 [x[1] for x in new_tokens])
296 [x[1] for x in new_tokens])
280
297
281 for tag, o1, o2, n1, n2 in token_sequence_matcher.get_opcodes():
298 for tag, o1, o2, n1, n2 in token_sequence_matcher.get_opcodes():
282 # check the differences by token block types first to give a more
299 # check the differences by token block types first to give a
283 # nicer "block" level replacement vs character diffs
300 # nicer "block" level replacement vs character diffs
284
301
285 if tag == 'equal':
302 if tag == 'equal':
286 for token_class, token_text in old_tokens[o1:o2]:
303 for token_class, token_text in old_tokens[o1:o2]:
287 old_tokens_result.append((token_class, '', token_text))
304 old_tokens_result.append((token_class, '', token_text))
288 for token_class, token_text in new_tokens[n1:n2]:
305 for token_class, token_text in new_tokens[n1:n2]:
289 new_tokens_result.append((token_class, '', token_text))
306 new_tokens_result.append((token_class, '', token_text))
290 elif tag == 'delete':
307 elif tag == 'delete':
291 for token_class, token_text in old_tokens[o1:o2]:
308 for token_class, token_text in old_tokens[o1:o2]:
292 old_tokens_result.append((token_class, 'del', token_text))
309 old_tokens_result.append((token_class, 'del', int_convert(token_text)))
293 elif tag == 'insert':
310 elif tag == 'insert':
294 for token_class, token_text in new_tokens[n1:n2]:
311 for token_class, token_text in new_tokens[n1:n2]:
295 new_tokens_result.append((token_class, 'ins', token_text))
312 new_tokens_result.append((token_class, 'ins', int_convert(token_text)))
296 elif tag == 'replace':
313 elif tag == 'replace':
297 # if same type token blocks must be replaced, do a diff on the
314 # if same type token blocks must be replaced, do a diff on the
298 # characters in the token blocks to show individual changes
315 # characters in the token blocks to show individual changes
299
316
300 old_char_tokens = []
317 old_char_tokens = []
301 new_char_tokens = []
318 new_char_tokens = []
302 for token_class, token_text in old_tokens[o1:o2]:
319 for token_class, token_text in old_tokens[o1:o2]:
303 for char in token_text:
320 for char in token_text:
304 old_char_tokens.append((token_class, char))
321 old_char_tokens.append((token_class, char))
305
322
306 for token_class, token_text in new_tokens[n1:n2]:
323 for token_class, token_text in new_tokens[n1:n2]:
307 for char in token_text:
324 for char in token_text:
308 new_char_tokens.append((token_class, char))
325 new_char_tokens.append((token_class, char))
309
326
310 old_string = ''.join([token_text for
327 old_string = ''.join([token_text for
311 token_class, token_text in old_char_tokens])
328 token_class, token_text in old_char_tokens])
312 new_string = ''.join([token_text for
329 new_string = ''.join([token_text for
313 token_class, token_text in new_char_tokens])
330 token_class, token_text in new_char_tokens])
314
331
315 char_sequence = difflib.SequenceMatcher(
332 char_sequence = difflib.SequenceMatcher(
316 None, old_string, new_string)
333 None, old_string, new_string)
317 copcodes = char_sequence.get_opcodes()
334 copcodes = char_sequence.get_opcodes()
318 obuffer, nbuffer = [], []
335 obuffer, nbuffer = [], []
319
336
320 if use_diff_match_patch:
337 if use_diff_match_patch:
321 dmp = diff_match_patch()
338 dmp = diff_match_patch()
322 dmp.Diff_EditCost = 11 # TODO: dan: extract this to a setting
339 dmp.Diff_EditCost = 11 # TODO: dan: extract this to a setting
323 reps = dmp.diff_main(old_string, new_string)
340 reps = dmp.diff_main(old_string, new_string)
324 dmp.diff_cleanupEfficiency(reps)
341 dmp.diff_cleanupEfficiency(reps)
325
342
326 a, b = 0, 0
343 a, b = 0, 0
327 for op, rep in reps:
344 for op, rep in reps:
328 l = len(rep)
345 l = len(rep)
329 if op == 0:
346 if op == 0:
330 for i, c in enumerate(rep):
347 for i, c in enumerate(rep):
331 obuffer.append((old_char_tokens[a+i][0], '', c))
348 obuffer.append((old_char_tokens[a+i][0], '', c))
332 nbuffer.append((new_char_tokens[b+i][0], '', c))
349 nbuffer.append((new_char_tokens[b+i][0], '', c))
333 a += l
350 a += l
334 b += l
351 b += l
335 elif op == -1:
352 elif op == -1:
336 for i, c in enumerate(rep):
353 for i, c in enumerate(rep):
337 obuffer.append((old_char_tokens[a+i][0], 'del', c))
354 obuffer.append((old_char_tokens[a+i][0], 'del', int_convert(c)))
338 a += l
355 a += l
339 elif op == 1:
356 elif op == 1:
340 for i, c in enumerate(rep):
357 for i, c in enumerate(rep):
341 nbuffer.append((new_char_tokens[b+i][0], 'ins', c))
358 nbuffer.append((new_char_tokens[b+i][0], 'ins', int_convert(c)))
342 b += l
359 b += l
343 else:
360 else:
344 for ctag, co1, co2, cn1, cn2 in copcodes:
361 for ctag, co1, co2, cn1, cn2 in copcodes:
345 if ctag == 'equal':
362 if ctag == 'equal':
346 for token_class, token_text in old_char_tokens[co1:co2]:
363 for token_class, token_text in old_char_tokens[co1:co2]:
347 obuffer.append((token_class, '', token_text))
364 obuffer.append((token_class, '', token_text))
348 for token_class, token_text in new_char_tokens[cn1:cn2]:
365 for token_class, token_text in new_char_tokens[cn1:cn2]:
349 nbuffer.append((token_class, '', token_text))
366 nbuffer.append((token_class, '', token_text))
350 elif ctag == 'delete':
367 elif ctag == 'delete':
351 for token_class, token_text in old_char_tokens[co1:co2]:
368 for token_class, token_text in old_char_tokens[co1:co2]:
352 obuffer.append((token_class, 'del', token_text))
369 obuffer.append((token_class, 'del', int_convert(token_text)))
353 elif ctag == 'insert':
370 elif ctag == 'insert':
354 for token_class, token_text in new_char_tokens[cn1:cn2]:
371 for token_class, token_text in new_char_tokens[cn1:cn2]:
355 nbuffer.append((token_class, 'ins', token_text))
372 nbuffer.append((token_class, 'ins', int_convert(token_text)))
356 elif ctag == 'replace':
373 elif ctag == 'replace':
357 for token_class, token_text in old_char_tokens[co1:co2]:
374 for token_class, token_text in old_char_tokens[co1:co2]:
358 obuffer.append((token_class, 'del', token_text))
375 obuffer.append((token_class, 'del', int_convert(token_text)))
359 for token_class, token_text in new_char_tokens[cn1:cn2]:
376 for token_class, token_text in new_char_tokens[cn1:cn2]:
360 nbuffer.append((token_class, 'ins', token_text))
377 nbuffer.append((token_class, 'ins', int_convert(token_text)))
361
378
362 old_tokens_result.extend(obuffer)
379 old_tokens_result.extend(obuffer)
363 new_tokens_result.extend(nbuffer)
380 new_tokens_result.extend(nbuffer)
364
381
365 return old_tokens_result, new_tokens_result, similarity
382 return old_tokens_result, new_tokens_result, similarity
366
383
367
384
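A small before/after illustration (empty token classes for brevity):

    old_line = [('', 'x = 1')]
    new_line = [('', 'x = 2')]
    old_res, new_res, similarity = tokens_diff(
        old_line, new_line, use_diff_match_patch=False)
    # old_res marks the '1' with op 'del'; new_res marks the '2' with op 'ins'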
368 def diffset_node_getter(commit):
385 def diffset_node_getter(commit):
369 def get_node(fname):
386 def get_diff_node(file_name):
387
370 try:
388 try:
371 return commit.get_node(fname)
389 return commit.get_node(file_name, pre_load=['size', 'flags', 'data'])
372 except NodeDoesNotExistError:
390 except NodeDoesNotExistError:
373 return None
391 return None
374
392
375 return get_node
393 return get_diff_node
376
394
377
395
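Usage is a two-step closure (the commit object is assumed):

    get_node = diffset_node_getter(commit)
    node = get_node('setup.py')  # FileNode, or None when absent in the commit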
378 class DiffSet(object):
396 class DiffSet(object):
379 """
397 """
380 An object for parsing the diff result from diffs.DiffProcessor and
398 An object for parsing the diff result from diffs.DiffProcessor and
381 adding highlighting, side by side/unified renderings and line diffs
399 adding highlighting, side by side/unified renderings and line diffs
382 """
400 """
383
401
384 HL_REAL = 'REAL' # highlights using original file, slow
402 HL_REAL = 'REAL' # highlights using original file, slow
385 HL_FAST = 'FAST' # highlights using just the line, fast but not correct
403 HL_FAST = 'FAST' # highlights using just the line, fast but not correct
386 # in the case of multiline code
404 # in the case of multiline code
387 HL_NONE = 'NONE' # no highlighting, fastest
405 HL_NONE = 'NONE' # no highlighting, fastest
388
406
389 def __init__(self, highlight_mode=HL_REAL, repo_name=None,
407 def __init__(self, highlight_mode=HL_REAL, repo_name=None,
390 source_repo_name=None,
408 source_repo_name=None,
391 source_node_getter=lambda filename: None,
409 source_node_getter=lambda filename: None,
392 target_repo_name=None,
410 target_repo_name=None,
393 target_node_getter=lambda filename: None,
411 target_node_getter=lambda filename: None,
394 source_nodes=None, target_nodes=None,
412 source_nodes=None, target_nodes=None,
395 # files over this size will use fast highlighting
413 # files over this size will use fast highlighting
396 max_file_size_limit=150 * 1024,
414 max_file_size_limit=150 * 1024,
397 ):
415 ):
398
416
399 self.highlight_mode = highlight_mode
417 self.highlight_mode = highlight_mode
400 self.highlighted_filenodes = {
418 self.highlighted_filenodes = {
401 'before': {},
419 'before': {},
402 'after': {}
420 'after': {}
403 }
421 }
404 self.source_node_getter = source_node_getter
422 self.source_node_getter = source_node_getter
405 self.target_node_getter = target_node_getter
423 self.target_node_getter = target_node_getter
406 self.source_nodes = source_nodes or {}
424 self.source_nodes = source_nodes or {}
407 self.target_nodes = target_nodes or {}
425 self.target_nodes = target_nodes or {}
408 self.repo_name = repo_name
426 self.repo_name = repo_name
409 self.target_repo_name = target_repo_name or repo_name
427 self.target_repo_name = target_repo_name or repo_name
410 self.source_repo_name = source_repo_name or repo_name
428 self.source_repo_name = source_repo_name or repo_name
411 self.max_file_size_limit = max_file_size_limit
429 self.max_file_size_limit = max_file_size_limit
412
430
413 def render_patchset(self, patchset, source_ref=None, target_ref=None):
431 def render_patchset(self, patchset, source_ref=None, target_ref=None):
414 diffset = AttributeDict(dict(
432 diffset = AttributeDict(dict(
415 lines_added=0,
433 lines_added=0,
416 lines_deleted=0,
434 lines_deleted=0,
417 changed_files=0,
435 changed_files=0,
418 files=[],
436 files=[],
419 file_stats={},
437 file_stats={},
420 limited_diff=isinstance(patchset, LimitedDiffContainer),
438 limited_diff=isinstance(patchset, LimitedDiffContainer),
421 repo_name=self.repo_name,
439 repo_name=self.repo_name,
422 target_repo_name=self.target_repo_name,
440 target_repo_name=self.target_repo_name,
423 source_repo_name=self.source_repo_name,
441 source_repo_name=self.source_repo_name,
424 source_ref=source_ref,
442 source_ref=source_ref,
425 target_ref=target_ref,
443 target_ref=target_ref,
426 ))
444 ))
427 for patch in patchset:
445 for patch in patchset:
428 diffset.file_stats[patch['filename']] = patch['stats']
446 diffset.file_stats[patch['filename']] = patch['stats']
429 filediff = self.render_patch(patch)
447 filediff = self.render_patch(patch)
430 filediff.diffset = StrictAttributeDict(dict(
448 filediff.diffset = StrictAttributeDict(dict(
431 source_ref=diffset.source_ref,
449 source_ref=diffset.source_ref,
432 target_ref=diffset.target_ref,
450 target_ref=diffset.target_ref,
433 repo_name=diffset.repo_name,
451 repo_name=diffset.repo_name,
434 source_repo_name=diffset.source_repo_name,
452 source_repo_name=diffset.source_repo_name,
435 target_repo_name=diffset.target_repo_name,
453 target_repo_name=diffset.target_repo_name,
436 ))
454 ))
437 diffset.files.append(filediff)
455 diffset.files.append(filediff)
438 diffset.changed_files += 1
456 diffset.changed_files += 1
439 if not patch['stats']['binary']:
457 if not patch['stats']['binary']:
440 diffset.lines_added += patch['stats']['added']
458 diffset.lines_added += patch['stats']['added']
441 diffset.lines_deleted += patch['stats']['deleted']
459 diffset.lines_deleted += patch['stats']['deleted']
442
460
443 return diffset
461 return diffset
444
462
445 _lexer_cache = {}
463 _lexer_cache = {}
446
464
447 def _get_lexer_for_filename(self, filename, filenode=None):
465 def _get_lexer_for_filename(self, filename, filenode=None):
448 # cached because we might need to call it twice for source/target
466 # cached because we might need to call it twice for source/target
449 if filename not in self._lexer_cache:
467 if filename not in self._lexer_cache:
450 if filenode:
468 if filenode:
451 lexer = filenode.lexer
469 lexer = filenode.lexer
452 extension = filenode.extension
470 extension = filenode.extension
453 else:
471 else:
454 lexer = FileNode.get_lexer(filename=filename)
472 lexer = FileNode.get_lexer(filename=filename)
455 extension = filename.split('.')[-1]
473 extension = filename.split('.')[-1]
456
474
457 lexer = get_custom_lexer(extension) or lexer
475 lexer = get_custom_lexer(extension) or lexer
458 self._lexer_cache[filename] = lexer
476 self._lexer_cache[filename] = lexer
459 return self._lexer_cache[filename]
477 return self._lexer_cache[filename]
460
478
461 def render_patch(self, patch):
479 def render_patch(self, patch):
462 log.debug('rendering diff for %r', patch['filename'])
480 log.debug('rendering diff for %r', patch['filename'])
463
481
464 source_filename = patch['original_filename']
482 source_filename = patch['original_filename']
465 target_filename = patch['filename']
483 target_filename = patch['filename']
466
484
467 source_lexer = plain_text_lexer
485 source_lexer = plain_text_lexer
468 target_lexer = plain_text_lexer
486 target_lexer = plain_text_lexer
469
487
470 if not patch['stats']['binary']:
488 if not patch['stats']['binary']:
471 node_hl_mode = self.HL_NONE if patch['chunks'] == [] else None
489 node_hl_mode = self.HL_NONE if patch['chunks'] == [] else None
472 hl_mode = node_hl_mode or self.highlight_mode
490 hl_mode = node_hl_mode or self.highlight_mode
473
491
474 if hl_mode == self.HL_REAL:
492 if hl_mode == self.HL_REAL:
475 if (source_filename and patch['operation'] in ('D', 'M')
                        and source_filename not in self.source_nodes):
                    self.source_nodes[source_filename] = (
                        self.source_node_getter(source_filename))

                if (target_filename and patch['operation'] in ('A', 'M')
                        and target_filename not in self.target_nodes):
                    self.target_nodes[target_filename] = (
                        self.target_node_getter(target_filename))

            elif hl_mode == self.HL_FAST:
                source_lexer = self._get_lexer_for_filename(source_filename)
                target_lexer = self._get_lexer_for_filename(target_filename)

        source_file = self.source_nodes.get(source_filename, source_filename)
        target_file = self.target_nodes.get(target_filename, target_filename)
        raw_id_uid = ''
        if self.source_nodes.get(source_filename):
            raw_id_uid = self.source_nodes[source_filename].commit.raw_id

        if not raw_id_uid and self.target_nodes.get(target_filename):
            # in case this is a new file we only have it in target
            raw_id_uid = self.target_nodes[target_filename].commit.raw_id

        source_filenode, target_filenode = None, None

        # TODO: dan: FileNode.lexer works on the content of the file - which
        # can be slow - issue #4289 explains a lexer clean up - which once
        # done can allow caching a lexer for a filenode to avoid the file lookup
        if isinstance(source_file, FileNode):
            source_filenode = source_file
            #source_lexer = source_file.lexer
            source_lexer = self._get_lexer_for_filename(source_filename)
            source_file.lexer = source_lexer

        if isinstance(target_file, FileNode):
            target_filenode = target_file
            #target_lexer = target_file.lexer
            target_lexer = self._get_lexer_for_filename(target_filename)
            target_file.lexer = target_lexer

        source_file_path, target_file_path = None, None

        if source_filename != '/dev/null':
            source_file_path = source_filename
        if target_filename != '/dev/null':
            target_file_path = target_filename

        source_file_type = source_lexer.name
        target_file_type = target_lexer.name

        filediff = AttributeDict({
            'source_file_path': source_file_path,
            'target_file_path': target_file_path,
            'source_filenode': source_filenode,
            'target_filenode': target_filenode,
            'source_file_type': target_file_type,
            'target_file_type': source_file_type,
            'patch': {'filename': patch['filename'], 'stats': patch['stats']},
            'operation': patch['operation'],
            'source_mode': patch['stats']['old_mode'],
            'target_mode': patch['stats']['new_mode'],
            'limited_diff': patch['is_limited_diff'],
            'hunks': [],
            'hunk_ops': None,
            'diffset': self,
            'raw_id': raw_id_uid,
        })

        file_chunks = patch['chunks'][1:]
        for i, hunk in enumerate(file_chunks, 1):
            hunkbit = self.parse_hunk(hunk, source_file, target_file)
            hunkbit.source_file_path = source_file_path
            hunkbit.target_file_path = target_file_path
            hunkbit.index = i
            filediff.hunks.append(hunkbit)

        # Simulate hunk on OPS type line which doesn't really contain any diff
        # this allows commenting on those
        if not file_chunks:
            actions = []
-            for op_id, op_text in filediff.patch['stats']['ops'].items():
+            for op_id, op_text in list(filediff.patch['stats']['ops'].items()):
                if op_id == DEL_FILENODE:
                    actions.append('file was removed')
                elif op_id == BIN_FILENODE:
                    actions.append('binary diff hidden')
                else:
-                    actions.append(safe_unicode(op_text))
+                    actions.append(safe_str(op_text))
            action_line = 'NO CONTENT: ' + \
                          (', '.join(actions) or 'UNDEFINED_ACTION')

            hunk_ops = {'source_length': 0, 'source_start': 0,
                        'lines': [
                            {'new_lineno': 0, 'old_lineno': 1,
                             'action': 'unmod-no-hl', 'line': action_line}
                        ],
                        'section_header': '', 'target_start': 1, 'target_length': 1}

            hunkbit = self.parse_hunk(hunk_ops, source_file, target_file)
            hunkbit.source_file_path = source_file_path
            hunkbit.target_file_path = target_file_path
            filediff.hunk_ops = hunkbit
        return filediff

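    # Illustrative sketch (not part of the changeset): the list(...) wrapper
    # added around dict.items() above is the usual Python 3 guard - .items()
    # now returns a live view, and mutating the dict while iterating a view
    # raises RuntimeError. Iterating a snapshot is always safe:
    #
    #     ops = {1: 'del', 2: 'bin'}
    #     for op_id, op_text in list(ops.items()):
    #         ops.setdefault(op_id + 10, op_text)   # safe on the snapshot
    #     sorted(ops)   # -> [1, 2, 11, 12]
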
    def parse_hunk(self, hunk, source_file, target_file):
        result = AttributeDict(dict(
            source_start=hunk['source_start'],
            source_length=hunk['source_length'],
            target_start=hunk['target_start'],
            target_length=hunk['target_length'],
            section_header=hunk['section_header'],
            lines=[],
        ))
        before, after = [], []

        for line in hunk['lines']:
+
            if line['action'] in ['unmod', 'unmod-no-hl']:
                no_hl = line['action'] == 'unmod-no-hl'
-                result.lines.extend(
-                    self.parse_lines(before, after, source_file, target_file, no_hl=no_hl))
+                parsed_lines = self.parse_lines(before, after, source_file, target_file, no_hl=no_hl)
+                result.lines.extend(parsed_lines)
                after.append(line)
                before.append(line)
            elif line['action'] == 'add':
                after.append(line)
            elif line['action'] == 'del':
                before.append(line)
            elif line['action'] == 'old-no-nl':
                before.append(line)
+                #line['line'] = safe_str(line['line'])
            elif line['action'] == 'new-no-nl':
+                #line['line'] = safe_str(line['line'])
                after.append(line)

        all_actions = [x['action'] for x in after] + [x['action'] for x in before]
        no_hl = {x for x in all_actions} == {'unmod-no-hl'}
-        result.lines.extend(
-            self.parse_lines(before, after, source_file, target_file, no_hl=no_hl))
+        parsed_no_hl_lines = self.parse_lines(before, after, source_file, target_file, no_hl=no_hl)
+        result.lines.extend(parsed_no_hl_lines)

-        # NOTE(marcink): we must keep list() call here so we can cache the result...
+        # NOTE(marcink): we must keep list() call here, so we can cache the result...
        result.unified = list(self.as_unified(result.lines))
        result.sideside = result.lines

        return result

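    # Illustrative sketch (not part of the changeset): the shape of the hunk
    # mapping parse_hunk() consumes, with made-up values; the field names are
    # the ones read above:
    #
    #     example_hunk = {
    #         'source_start': 10, 'source_length': 2,
    #         'target_start': 10, 'target_length': 3,
    #         'section_header': '@@ -10,2 +10,3 @@ def foo():',
    #         'lines': [
    #             {'action': 'unmod', 'old_lineno': 10, 'new_lineno': 10, 'line': 'a = 1'},
    #             {'action': 'del', 'old_lineno': 11, 'new_lineno': '', 'line': 'b = 2'},
    #             {'action': 'add', 'old_lineno': '', 'new_lineno': 11, 'line': 'b = 3'},
    #         ],
    #     }
    #
    # parse_hunk(example_hunk, source_file, target_file) returns an
    # AttributeDict whose .lines holds the side-by-side rows and whose
    # .unified is the materialized output of as_unified() further below.
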
    def parse_lines(self, before_lines, after_lines, source_file, target_file,
                    no_hl=False):
        # TODO: dan: investigate doing the diff comparison and fast highlighting
        # on the entire before and after buffered block lines rather than by
        # line, this means we can get better 'fast' highlighting if the context
        # allows it - eg.
        # line 4: """
        # line 5: this gets highlighted as a string
        # line 6: """

        lines = []

        before_newline = AttributeDict()
        after_newline = AttributeDict()
        if before_lines and before_lines[-1]['action'] == 'old-no-nl':
            before_newline_line = before_lines.pop(-1)
            before_newline.content = '\n {}'.format(
                render_tokenstream(
-                    [(x[0], '', x[1])
+                    [(x[0], '', safe_str(x[1]))
                     for x in [('nonl', before_newline_line['line'])]]))

        if after_lines and after_lines[-1]['action'] == 'new-no-nl':
            after_newline_line = after_lines.pop(-1)
            after_newline.content = '\n {}'.format(
                render_tokenstream(
-                    [(x[0], '', x[1])
+                    [(x[0], '', safe_str(x[1]))
                     for x in [('nonl', after_newline_line['line'])]]))

        while before_lines or after_lines:
            before, after = None, None
            before_tokens, after_tokens = None, None

            if before_lines:
                before = before_lines.pop(0)
            if after_lines:
                after = after_lines.pop(0)

            original = AttributeDict()
            modified = AttributeDict()

            if before:
                if before['action'] == 'old-no-nl':
-                    before_tokens = [('nonl', before['line'])]
+                    before_tokens = [('nonl', safe_str(before['line']))]
                else:
                    before_tokens = self.get_line_tokens(
                        line_text=before['line'], line_number=before['old_lineno'],
                        input_file=source_file, no_hl=no_hl, source='before')
                original.lineno = before['old_lineno']
                original.content = before['line']
                original.action = self.action_to_op(before['action'])

                original.get_comment_args = (
                    source_file, 'o', before['old_lineno'])

            if after:
                if after['action'] == 'new-no-nl':
-                    after_tokens = [('nonl', after['line'])]
+                    after_tokens = [('nonl', safe_str(after['line']))]
                else:
                    after_tokens = self.get_line_tokens(
                        line_text=after['line'], line_number=after['new_lineno'],
                        input_file=target_file, no_hl=no_hl, source='after')
                modified.lineno = after['new_lineno']
                modified.content = after['line']
                modified.action = self.action_to_op(after['action'])

                modified.get_comment_args = (target_file, 'n', after['new_lineno'])

            # diff the lines
            if before_tokens and after_tokens:
                o_tokens, m_tokens, similarity = tokens_diff(
                    before_tokens, after_tokens)
                original.content = render_tokenstream(o_tokens)
                modified.content = render_tokenstream(m_tokens)
            elif before_tokens:
                original.content = render_tokenstream(
                    [(x[0], '', x[1]) for x in before_tokens])
            elif after_tokens:
                modified.content = render_tokenstream(
                    [(x[0], '', x[1]) for x in after_tokens])

            if not before_lines and before_newline:
                original.content += before_newline.content
                before_newline = None
            if not after_lines and after_newline:
                modified.content += after_newline.content
                after_newline = None

            lines.append(AttributeDict({
                'original': original,
                'modified': modified,
            }))

        return lines

    def get_line_tokens(self, line_text, line_number, input_file=None, no_hl=False, source=''):
        filenode = None
        filename = None

        if isinstance(input_file, str):
            filename = input_file
        elif isinstance(input_file, FileNode):
            filenode = input_file
-            filename = input_file.unicode_path
+            filename = input_file.str_path

        hl_mode = self.HL_NONE if no_hl else self.highlight_mode
        if hl_mode == self.HL_REAL and filenode:
            lexer = self._get_lexer_for_filename(filename)
-            file_size_allowed = input_file.size < self.max_file_size_limit
+            file_size_allowed = filenode.size < self.max_file_size_limit
            if line_number and file_size_allowed:
-                return self.get_tokenized_filenode_line(input_file, line_number, lexer, source)
+                return self.get_tokenized_filenode_line(filenode, line_number, lexer, source)

        if hl_mode in (self.HL_REAL, self.HL_FAST) and filename:
            lexer = self._get_lexer_for_filename(filename)
            return list(tokenize_string(line_text, lexer))

        return list(tokenize_string(line_text, plain_text_lexer))

    def get_tokenized_filenode_line(self, filenode, line_number, lexer=None, source=''):
+        name_hash = hash(filenode)

-        def tokenize(_filenode):
-            self.highlighted_filenodes[source][filenode] = filenode_as_lines_tokens(filenode, lexer)
+        hl_node_code = self.highlighted_filenodes[source]

-        if filenode not in self.highlighted_filenodes[source]:
-            tokenize(filenode)
+        if name_hash not in hl_node_code:
+            hl_node_code[name_hash] = filenode_as_lines_tokens(filenode, lexer)

        try:
-            return self.highlighted_filenodes[source][filenode][line_number - 1]
+            return hl_node_code[name_hash][line_number - 1]
        except Exception:
-            log.exception('diff rendering error')
+            log.exception('diff rendering error on L:%s and file=%s', line_number - 1, filenode.name)
            return [('', 'L{}: rhodecode diff rendering error'.format(line_number))]

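    # Illustrative sketch (not part of the changeset): the refactor above keys
    # the token cache by hash(filenode) per 'before'/'after' source, so equal
    # nodes reuse a single tokenization. The same pattern in isolation, with a
    # plain string standing in for a FileNode:
    #
    #     cache = {}
    #     def cached_tokens(filenode, tokenize):
    #         name_hash = hash(filenode)
    #         if name_hash not in cache:
    #             cache[name_hash] = tokenize(filenode)   # computed once per node
    #         return cache[name_hash]
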
    def action_to_op(self, action):
        return {
            'add': '+',
            'del': '-',
            'unmod': ' ',
            'unmod-no-hl': ' ',
            'old-no-nl': ' ',
            'new-no-nl': ' ',
        }.get(action, action)

    def as_unified(self, lines):
        """
        Return a generator that yields the lines of a diff in unified order
        """
        def generator():
            buf = []
            for line in lines:

                if buf and not line.original or line.original.action == ' ':
                    for b in buf:
                        yield b
                    buf = []

                if line.original:
                    if line.original.action == ' ':
                        yield (line.original.lineno, line.modified.lineno,
                               line.original.action, line.original.content,
                               line.original.get_comment_args)
                        continue

                    if line.original.action == '-':
                        yield (line.original.lineno, None,
                               line.original.action, line.original.content,
                               line.original.get_comment_args)

                    if line.modified.action == '+':
                        buf.append((
                            None, line.modified.lineno,
                            line.modified.action, line.modified.content,
                            line.modified.get_comment_args))
                        continue

                if line.modified:
                    yield (None, line.modified.lineno,
                           line.modified.action, line.modified.content,
                           line.modified.get_comment_args)

            for b in buf:
                yield b

        return generator()
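
    # Illustrative sketch (not part of the changeset): every row yielded by
    # as_unified() is a 5-tuple of (old_lineno, new_lineno, action, content,
    # get_comment_args), with buffered '+' rows flushed after the '-' rows of
    # the same block. A consumer might render it like so:
    #
    #     for old_no, new_no, action, content, _args in hunkbit.unified:
    #         print('{:>4} {:>4} {} {}'.format(old_no or '', new_no or '',
    #                                          action, content))
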
@@ -1,679 +1,688 b''
"""
Database creation and setup module for RhodeCode Enterprise. Used for creating
the database as well as for migration operations
"""

import os
import sys
import time
import uuid
import logging
import getpass
from os.path import dirname as dn, join as jn

from sqlalchemy.engine import create_engine

from rhodecode import __dbversion__
from rhodecode.model import init_model
from rhodecode.model.user import UserModel
from rhodecode.model.db import (
    User, Permission, RhodeCodeUi, RhodeCodeSetting, UserToPerm,
    DbMigrateVersion, RepoGroup, UserRepoGroupToPerm, CacheKey, Repository)
from rhodecode.model.meta import Session, Base
from rhodecode.model.permission import PermissionModel
from rhodecode.model.repo import RepoModel
from rhodecode.model.repo_group import RepoGroupModel
from rhodecode.model.settings import SettingsModel


log = logging.getLogger(__name__)

def notify(msg):
    """
    Notification helper for migration messages
    """
    ml = len(msg) + (4 * 2)
-    print(('\n%s\n*** %s ***\n%s' % ('*' * ml, msg, '*' * ml)).upper())
+    print((('\n%s\n*** %s ***\n%s' % ('*' * ml, msg, '*' * ml)).upper()))

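# Illustrative usage (not part of the changeset): notify() prints an
# upper-cased, starred banner sized to the message, e.g.
#
#     notify('performing upgrade step 4')
#
#     *********************************
#     *** PERFORMING UPGRADE STEP 4 ***
#     *********************************
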
class DbManage(object):

    def __init__(self, log_sql, dbconf, root, tests=False,
-                 SESSION=None, cli_args=None):
+                 SESSION=None, cli_args=None, enc_key=b''):
+
        self.dbname = dbconf.split('/')[-1]
        self.tests = tests
        self.root = root
        self.dburi = dbconf
        self.log_sql = log_sql
        self.cli_args = cli_args or {}
+        self.sa = None
+        self.engine = None
+        self.enc_key = enc_key
+        # sets .sa .engine
        self.init_db(SESSION=SESSION)
+
        self.ask_ok = self.get_ask_ok_func(self.cli_args.get('force_ask'))

    def db_exists(self):
        if not self.sa:
            self.init_db()
        try:
            self.sa.query(RhodeCodeUi)\
                .filter(RhodeCodeUi.ui_key == '/')\
                .scalar()
            return True
        except Exception:
            return False
        finally:
            self.sa.rollback()

    def get_ask_ok_func(self, param):
        if param not in [None]:
            # return a lambda whose return value defaults to param
            return lambda *args, **kwargs: param
        else:
            from rhodecode.lib.utils import ask_ok
            return ask_ok

    def init_db(self, SESSION=None):
+
        if SESSION:
            self.sa = SESSION
+            self.engine = SESSION.bind
        else:
            # init new sessions
            engine = create_engine(self.dburi, echo=self.log_sql)
-            init_model(engine)
+            init_model(engine, encryption_key=self.enc_key)
            self.sa = Session()
+            self.engine = engine

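    # Illustrative sketch (not part of the changeset): constructing DbManage
    # with the new enc_key parameter; all values below are placeholders.
    #
    #     dbmanage = DbManage(
    #         log_sql=False,
    #         dbconf='sqlite:///rhodecode.db',
    #         root='/var/opt/rhodecode_data',
    #         cli_args={'force_ask': True},
    #         enc_key=b'secret-encryption-key',  # forwarded to init_model()
    #     )
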
    def create_tables(self, override=False):
        """
        Create an auth database
        """

        log.info("Existing database with the same name is going to be destroyed.")
        log.info("Setup command will run DROP ALL command on that database.")
+        engine = self.engine
+
        if self.tests:
            destroy = True
        else:
            destroy = self.ask_ok('Are you sure that you want to destroy the old database? [y/n]')
        if not destroy:
            log.info('db tables bootstrap: Nothing done.')
            sys.exit(0)
        if destroy:
-            Base.metadata.drop_all()
+            Base.metadata.drop_all(bind=engine)

        checkfirst = not override
-        Base.metadata.create_all(checkfirst=checkfirst)
+        Base.metadata.create_all(bind=engine, checkfirst=checkfirst)
        log.info('Created tables for %s', self.dbname)

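    # Note: parameterless Base.metadata.drop_all()/create_all() depend on the
    # legacy implicitly-bound metadata, which newer SQLAlchemy releases drop;
    # passing the engine explicitly keeps both calls working:
    #
    #     Base.metadata.drop_all(bind=engine)
    #     Base.metadata.create_all(bind=engine, checkfirst=True)
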
    def set_db_version(self):
        ver = DbMigrateVersion()
        ver.version = __dbversion__
        ver.repository_id = 'rhodecode_db_migrations'
        ver.repository_path = 'versions'
        self.sa.add(ver)
        log.info('db version set to: %s', __dbversion__)

    def run_post_migration_tasks(self):
        """
        Run various tasks after the migrations have been performed
        """
        # delete cache keys on each upgrade
        total = CacheKey.query().count()
        log.info("Deleting (%s) cache keys now...", total)
        CacheKey.delete_all_cache()

    def upgrade(self, version=None):
        """
        Upgrades the database schema to the given revision, following
        all the steps needed to perform the upgrade
        """

        from rhodecode.lib.dbmigrate.migrate.versioning import api
-        from rhodecode.lib.dbmigrate.migrate.exceptions import \
-            DatabaseNotControlledError
+        from rhodecode.lib.dbmigrate.migrate.exceptions import DatabaseNotControlledError

        if 'sqlite' in self.dburi:
            print(
                '********************** WARNING **********************\n'
                'Make sure your version of sqlite is at least 3.7.X. \n'
                'Earlier versions are known to fail on some migrations\n'
                '*****************************************************\n')

        upgrade = self.ask_ok(
            'You are about to perform a database upgrade. Make '
            'sure you have backed up your database. '
            'Continue ? [y/n]')
        if not upgrade:
            log.info('No upgrade performed')
            sys.exit(0)

        repository_path = jn(dn(dn(dn(os.path.realpath(__file__)))),
                             'rhodecode/lib/dbmigrate')
        db_uri = self.dburi

        if version:
            DbMigrateVersion.set_version(version)

        try:
            curr_version = api.db_version(db_uri, repository_path)
-            msg = ('Found current database db_uri under version '
-                   'control with version {}'.format(curr_version))
+            msg = (f'Found current database db_uri under version '
+                   f'control with version {curr_version}')

        except (RuntimeError, DatabaseNotControlledError):
            curr_version = 1
-            msg = ('Current database is not under version control. Setting '
-                   'as version %s' % curr_version)
+            msg = f'Current database is not under version control. ' \
+                  f'Setting as version {curr_version}'
            api.version_control(db_uri, repository_path, curr_version)

        notify(msg)

        if curr_version == __dbversion__:
            log.info('This database is already at the newest version')
            sys.exit(0)

-        upgrade_steps = range(curr_version + 1, __dbversion__ + 1)
-        notify('attempting to upgrade database from '
-               'version %s to version %s' % (curr_version, __dbversion__))
+        upgrade_steps = list(range(curr_version + 1, __dbversion__ + 1))
+        notify(f'attempting to upgrade database from '
+               f'version {curr_version} to version {__dbversion__}')

        # CALL THE PROPER ORDER OF STEPS TO PERFORM FULL UPGRADE
        _step = None
        for step in upgrade_steps:
-            notify('performing upgrade step %s' % step)
+            notify(f'performing upgrade step {step}')
            time.sleep(0.5)

            api.upgrade(db_uri, repository_path, step)
            self.sa.rollback()
-            notify('schema upgrade for step %s completed' % (step,))
+            notify(f'schema upgrade for step {step} completed')

            _step = step

        self.run_post_migration_tasks()
-        notify('upgrade to version %s successful' % _step)
+        notify(f'upgrade to version {step} successful')

    def fix_repo_paths(self):
        """
        Fixes an old RhodeCode version path into the new one without a '*'
        """

        paths = self.sa.query(RhodeCodeUi)\
            .filter(RhodeCodeUi.ui_key == '/')\
            .scalar()

        paths.ui_value = paths.ui_value.replace('*', '')

        try:
            self.sa.add(paths)
            self.sa.commit()
        except Exception:
            self.sa.rollback()
            raise

    def fix_default_user(self):
        """
        Fixes an old default user with some 'nicer' default values,
        used mostly for anonymous access
        """
        def_user = self.sa.query(User)\
            .filter(User.username == User.DEFAULT_USER)\
            .one()

        def_user.name = 'Anonymous'
        def_user.lastname = 'User'
        def_user.email = User.DEFAULT_USER_EMAIL

        try:
            self.sa.add(def_user)
            self.sa.commit()
        except Exception:
            self.sa.rollback()
            raise

    def fix_settings(self):
        """
        Fixes rhodecode settings and adds ga_code key for google analytics
        """

        hgsettings3 = RhodeCodeSetting('ga_code', '')

        try:
            self.sa.add(hgsettings3)
            self.sa.commit()
        except Exception:
            self.sa.rollback()
            raise

    def create_admin_and_prompt(self):

        # defaults
        defaults = self.cli_args
        username = defaults.get('username')
        password = defaults.get('password')
        email = defaults.get('email')

        if username is None:
            username = input('Specify admin username:')  # plain input(); wrapping it in eval() would execute user input
        if password is None:
            password = self._get_admin_password()
            if not password:
                # second try
                password = self._get_admin_password()
            if not password:
                sys.exit()
        if email is None:
            email = input('Specify admin email:')
        api_key = self.cli_args.get('api_key')
        self.create_user(username, password, email, True,
                         strict_creation_check=False,
                         api_key=api_key)

    def _get_admin_password(self):
        password = getpass.getpass('Specify admin password '
                                   '(min 6 chars):')
        confirm = getpass.getpass('Confirm password:')

        if password != confirm:
            log.error('passwords mismatch')
            return False
        if len(password) < 6:
            log.error('password is too short - use at least 6 characters')
            return False

        return password

    def create_test_admin_and_users(self):
        log.info('creating admin and regular test users')
        from rhodecode.tests import TEST_USER_ADMIN_LOGIN, \
            TEST_USER_ADMIN_PASS, TEST_USER_ADMIN_EMAIL, \
            TEST_USER_REGULAR_LOGIN, TEST_USER_REGULAR_PASS, \
            TEST_USER_REGULAR_EMAIL, TEST_USER_REGULAR2_LOGIN, \
            TEST_USER_REGULAR2_PASS, TEST_USER_REGULAR2_EMAIL

        self.create_user(TEST_USER_ADMIN_LOGIN, TEST_USER_ADMIN_PASS,
                         TEST_USER_ADMIN_EMAIL, True, api_key=True)

        self.create_user(TEST_USER_REGULAR_LOGIN, TEST_USER_REGULAR_PASS,
                         TEST_USER_REGULAR_EMAIL, False, api_key=True)

        self.create_user(TEST_USER_REGULAR2_LOGIN, TEST_USER_REGULAR2_PASS,
                         TEST_USER_REGULAR2_EMAIL, False, api_key=True)

    def create_ui_settings(self, repo_store_path):
        """
        Creates ui settings, fills out hooks
        and disables dotencode
        """
        settings_model = SettingsModel(sa=self.sa)
        from rhodecode.lib.vcs.backends.hg import largefiles_store
        from rhodecode.lib.vcs.backends.git import lfs_store

        # Build HOOKS
        hooks = [
            (RhodeCodeUi.HOOK_REPO_SIZE, 'python:vcsserver.hooks.repo_size'),

            # HG
            (RhodeCodeUi.HOOK_PRE_PULL, 'python:vcsserver.hooks.pre_pull'),
            (RhodeCodeUi.HOOK_PULL, 'python:vcsserver.hooks.log_pull_action'),
            (RhodeCodeUi.HOOK_PRE_PUSH, 'python:vcsserver.hooks.pre_push'),
            (RhodeCodeUi.HOOK_PRETX_PUSH, 'python:vcsserver.hooks.pre_push'),
            (RhodeCodeUi.HOOK_PUSH, 'python:vcsserver.hooks.log_push_action'),
            (RhodeCodeUi.HOOK_PUSH_KEY, 'python:vcsserver.hooks.key_push'),

        ]

        for key, value in hooks:
            hook_obj = settings_model.get_ui_by_key(key)
            hooks2 = hook_obj if hook_obj else RhodeCodeUi()
            hooks2.ui_section = 'hooks'
            hooks2.ui_key = key
            hooks2.ui_value = value
            self.sa.add(hooks2)

        # enable largefiles
        largefiles = RhodeCodeUi()
        largefiles.ui_section = 'extensions'
        largefiles.ui_key = 'largefiles'
        largefiles.ui_value = ''
        self.sa.add(largefiles)

        # set default largefiles cache dir, defaults to
        # /repo_store_location/.cache/largefiles
        largefiles = RhodeCodeUi()
        largefiles.ui_section = 'largefiles'
        largefiles.ui_key = 'usercache'
        largefiles.ui_value = largefiles_store(repo_store_path)

        self.sa.add(largefiles)

        # set default lfs cache dir, defaults to
        # /repo_store_location/.cache/lfs_store
        lfsstore = RhodeCodeUi()
        lfsstore.ui_section = 'vcs_git_lfs'
        lfsstore.ui_key = 'store_location'
        lfsstore.ui_value = lfs_store(repo_store_path)

        self.sa.add(lfsstore)

        # register the hgsubversion extension, disabled by default
        hgsubversion = RhodeCodeUi()
        hgsubversion.ui_section = 'extensions'
        hgsubversion.ui_key = 'hgsubversion'
        hgsubversion.ui_value = ''
        hgsubversion.ui_active = False
        self.sa.add(hgsubversion)

        # register the hgevolve extension, disabled by default
        hgevolve = RhodeCodeUi()
        hgevolve.ui_section = 'extensions'
        hgevolve.ui_key = 'evolve'
        hgevolve.ui_value = ''
        hgevolve.ui_active = False
        self.sa.add(hgevolve)

        hgevolve = RhodeCodeUi()
        hgevolve.ui_section = 'experimental'
        hgevolve.ui_key = 'evolution'
        hgevolve.ui_value = ''
        hgevolve.ui_active = False
        self.sa.add(hgevolve)

        hgevolve = RhodeCodeUi()
        hgevolve.ui_section = 'experimental'
        hgevolve.ui_key = 'evolution.exchange'
        hgevolve.ui_value = ''
        hgevolve.ui_active = False
        self.sa.add(hgevolve)

        hgevolve = RhodeCodeUi()
        hgevolve.ui_section = 'extensions'
        hgevolve.ui_key = 'topic'
        hgevolve.ui_value = ''
        hgevolve.ui_active = False
        self.sa.add(hgevolve)

        # register the hggit extension, disabled by default
        hggit = RhodeCodeUi()
        hggit.ui_section = 'extensions'
        hggit.ui_key = 'hggit'
        hggit.ui_value = ''
        hggit.ui_active = False
        self.sa.add(hggit)

        # set svn branch defaults
        branches = ["/branches/*", "/trunk"]
        tags = ["/tags/*"]

        for branch in branches:
            settings_model.create_ui_section_value(
                RhodeCodeUi.SVN_BRANCH_ID, branch)

        for tag in tags:
            settings_model.create_ui_section_value(RhodeCodeUi.SVN_TAG_ID, tag)

    def create_auth_plugin_options(self, skip_existing=False):
        """
        Create default auth plugin settings, and make them active

        :param skip_existing: skip settings that already exist in the database
        """
        defaults = [
            ('auth_plugins',
             'egg:rhodecode-enterprise-ce#token,egg:rhodecode-enterprise-ce#rhodecode',
             'list'),

            ('auth_authtoken_enabled',
             'True',
             'bool'),

            ('auth_rhodecode_enabled',
             'True',
             'bool'),
        ]
        for k, v, t in defaults:
            if (skip_existing and
                    SettingsModel().get_setting_by_name(k) is not None):
                log.debug('Skipping option %s', k)
                continue
            setting = RhodeCodeSetting(k, v, t)
            self.sa.add(setting)

    def create_default_options(self, skip_existing=False):
        """Creates default settings"""

        for k, v, t in [
            ('default_repo_enable_locking', False, 'bool'),
            ('default_repo_enable_downloads', False, 'bool'),
            ('default_repo_enable_statistics', False, 'bool'),
            ('default_repo_private', False, 'bool'),
            ('default_repo_type', 'hg', 'unicode')]:

            if (skip_existing and
                    SettingsModel().get_setting_by_name(k) is not None):
                log.debug('Skipping option %s', k)
                continue
            setting = RhodeCodeSetting(k, v, t)
            self.sa.add(setting)

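    # Illustrative sketch (not part of the changeset): settings are persisted
    # as (key, value, type) triples, with the type tag ('bool', 'int',
    # 'unicode', 'list') controlling coercion on read. A hypothetical reader
    # could look like:
    #
    #     def coerce_setting(value, type_):
    #         return {
    #             'bool': lambda v: str(v) == 'True',
    #             'int': int,
    #             'list': lambda v: [x.strip() for x in str(v).split(',')],
    #         }.get(type_, str)(value)
    #
    #     coerce_setting('True', 'bool')   # -> True
    #     coerce_setting('25', 'int')      # -> 25
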
    def fixup_groups(self):
        def_usr = User.get_default_user()
        for g in RepoGroup.query().all():
            g.group_name = g.get_new_name(g.name)
            self.sa.add(g)
            # get default perm
            default = UserRepoGroupToPerm.query()\
                .filter(UserRepoGroupToPerm.group == g)\
                .filter(UserRepoGroupToPerm.user == def_usr)\
                .scalar()

            if default is None:
                log.debug('missing default permission for group %s adding', g)
                perm_obj = RepoGroupModel()._create_default_perms(g)
                self.sa.add(perm_obj)

    def reset_permissions(self, username):
        """
        Resets permissions to the default state. Useful when old systems
        had bad permissions that must be cleaned up.

        :param username:
        """
        default_user = User.get_by_username(username)
        if not default_user:
            return

        u2p = UserToPerm.query()\
            .filter(UserToPerm.user == default_user).all()
        fixed = False
        if len(u2p) != len(Permission.DEFAULT_USER_PERMISSIONS):
            for p in u2p:
                Session().delete(p)
            fixed = True
            self.populate_default_permissions()
        return fixed

    def config_prompt(self, test_repo_path='', retries=3):
        defaults = self.cli_args
        _path = defaults.get('repos_location')
        if retries == 3:
            log.info('Setting up repositories config')

        if _path is not None:
            path = _path
        elif not self.tests and not test_repo_path:
            path = input(
                'Enter a valid absolute path to store repositories. '
                'All repositories in that path will be added automatically:'
            )
        else:
            path = test_repo_path
        path_ok = True

        # check proper dir
        if not os.path.isdir(path):
            path_ok = False
            log.error('Given path %s is not a valid directory', path)

        elif not os.path.isabs(path):
            path_ok = False
            log.error('Given path %s is not an absolute path', path)

        # check if path is at least readable.
        if not os.access(path, os.R_OK):
            path_ok = False
            log.error('Given path %s is not readable', path)

        # check write access, warn user about non writeable paths
        elif not os.access(path, os.W_OK) and path_ok:
            log.warning('No write permission to given path %s', path)

-            q = ('Given path %s is not writeable, do you want to '
-                 'continue with read only mode ? [y/n]' % (path,))
+            q = (f'Given path {path} is not writeable, do you want to '
+                 f'continue with read only mode ? [y/n]')
            if not self.ask_ok(q):
                log.error('Canceled by user')
                sys.exit(-1)

        if retries == 0:
            sys.exit('max retries reached')
        if not path_ok:
            retries -= 1
            return self.config_prompt(test_repo_path, retries)

        real_path = os.path.normpath(os.path.realpath(path))

        if real_path != os.path.normpath(path):
-            q = ('Path looks like a symlink, RhodeCode Enterprise will store '
-                 'given path as %s ? [y/n]') % (real_path,)
+            q = (f'Path looks like a symlink, RhodeCode Enterprise will store '
+                 f'given path as {real_path} ? [y/n]')
            if not self.ask_ok(q):
                log.error('Canceled by user')
                sys.exit(-1)

        return real_path

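    # Illustrative sketch (not part of the changeset): os.path.realpath()
    # resolves symlinks, which is why the prompt above warns when the stored
    # path would differ from what the user typed. POSIX-only example:
    #
    #     import os, tempfile
    #     base = tempfile.mkdtemp()
    #     target = os.path.join(base, 'repos')
    #     os.mkdir(target)
    #     link = os.path.join(base, 'repos-link')
    #     os.symlink(target, link)
    #     assert os.path.realpath(link) == os.path.realpath(target)
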
568 def create_settings(self, path):
577 def create_settings(self, path):
569
578
570 self.create_ui_settings(path)
579 self.create_ui_settings(path)
571
580
572 ui_config = [
581 ui_config = [
573 ('web', 'push_ssl', 'False'),
582 ('web', 'push_ssl', 'False'),
574 ('web', 'allow_archive', 'gz zip bz2'),
583 ('web', 'allow_archive', 'gz zip bz2'),
575 ('web', 'allow_push', '*'),
584 ('web', 'allow_push', '*'),
576 ('web', 'baseurl', '/'),
585 ('web', 'baseurl', '/'),
577 ('paths', '/', path),
586 ('paths', '/', path),
578 ('phases', 'publish', 'True')
587 ('phases', 'publish', 'True')
579 ]
588 ]
580 for section, key, value in ui_config:
589 for section, key, value in ui_config:
581 ui_conf = RhodeCodeUi()
590 ui_conf = RhodeCodeUi()
582 setattr(ui_conf, 'ui_section', section)
591 setattr(ui_conf, 'ui_section', section)
583 setattr(ui_conf, 'ui_key', key)
592 setattr(ui_conf, 'ui_key', key)
584 setattr(ui_conf, 'ui_value', value)
593 setattr(ui_conf, 'ui_value', value)
585 self.sa.add(ui_conf)
594 self.sa.add(ui_conf)
586
595
587 # rhodecode app settings
596 # rhodecode app settings
588 settings = [
597 settings = [
589 ('realm', 'RhodeCode', 'unicode'),
598 ('realm', 'RhodeCode', 'unicode'),
590 ('title', '', 'unicode'),
599 ('title', '', 'unicode'),
591 ('pre_code', '', 'unicode'),
600 ('pre_code', '', 'unicode'),
592 ('post_code', '', 'unicode'),
601 ('post_code', '', 'unicode'),
593
602
594 # Visual
603 # Visual
595 ('show_public_icon', True, 'bool'),
604 ('show_public_icon', True, 'bool'),
596 ('show_private_icon', True, 'bool'),
605 ('show_private_icon', True, 'bool'),
597 ('stylify_metatags', True, 'bool'),
606 ('stylify_metatags', True, 'bool'),
598 ('dashboard_items', 100, 'int'),
607 ('dashboard_items', 100, 'int'),
599 ('admin_grid_items', 25, 'int'),
608 ('admin_grid_items', 25, 'int'),
600
609
601 ('markup_renderer', 'markdown', 'unicode'),
610 ('markup_renderer', 'markdown', 'unicode'),
602
611
603 ('repository_fields', True, 'bool'),
612 ('repository_fields', True, 'bool'),
604 ('show_version', True, 'bool'),
613 ('show_version', True, 'bool'),
605 ('show_revision_number', True, 'bool'),
614 ('show_revision_number', True, 'bool'),
606 ('show_sha_length', 12, 'int'),
615 ('show_sha_length', 12, 'int'),
607
616
608 ('use_gravatar', False, 'bool'),
617 ('use_gravatar', False, 'bool'),
609 ('gravatar_url', User.DEFAULT_GRAVATAR_URL, 'unicode'),
618 ('gravatar_url', User.DEFAULT_GRAVATAR_URL, 'unicode'),
610
619
611 ('clone_uri_tmpl', Repository.DEFAULT_CLONE_URI, 'unicode'),
620 ('clone_uri_tmpl', Repository.DEFAULT_CLONE_URI, 'unicode'),
612 ('clone_uri_id_tmpl', Repository.DEFAULT_CLONE_URI_ID, 'unicode'),
621 ('clone_uri_id_tmpl', Repository.DEFAULT_CLONE_URI_ID, 'unicode'),
613 ('clone_uri_ssh_tmpl', Repository.DEFAULT_CLONE_URI_SSH, 'unicode'),
622 ('clone_uri_ssh_tmpl', Repository.DEFAULT_CLONE_URI_SSH, 'unicode'),
614 ('support_url', '', 'unicode'),
623 ('support_url', '', 'unicode'),
615 ('update_url', RhodeCodeSetting.DEFAULT_UPDATE_URL, 'unicode'),
624 ('update_url', RhodeCodeSetting.DEFAULT_UPDATE_URL, 'unicode'),
616
625
617 # VCS Settings
626 # VCS Settings
618 ('pr_merge_enabled', True, 'bool'),
627 ('pr_merge_enabled', True, 'bool'),
619 ('use_outdated_comments', True, 'bool'),
628 ('use_outdated_comments', True, 'bool'),
620 ('diff_cache', True, 'bool'),
629 ('diff_cache', True, 'bool'),
621 ]
630 ]
622
631
623 for key, val, type_ in settings:
632 for key, val, type_ in settings:
624 sett = RhodeCodeSetting(key, val, type_)
633 sett = RhodeCodeSetting(key, val, type_)
625 self.sa.add(sett)
634 self.sa.add(sett)
626
635
627 self.create_auth_plugin_options()
636 self.create_auth_plugin_options()
628 self.create_default_options()
637 self.create_default_options()
629
638
630 log.info('created ui config')
639 log.info('created ui config')
631
640
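The (key, value, type) triples above are stored as strings and coerced back to their declared type on read. A minimal sketch of that round-trip, assuming a simple cast table keyed by the type name (illustrative only, not the actual RhodeCodeSetting implementation):

# illustrative cast table; 'unicode' is kept for backward compat and is a
# plain str on python3 (assumption: mirrors the type names used above)
_CASTS = {
    'bool': lambda v: str(v) in ('True', 'true', '1'),
    'int': int,
    'unicode': str,
}

def coerce_setting(raw_value, type_):
    # look up the cast for the stored type name, defaulting to str
    return _CASTS.get(type_, str)(raw_value)

assert coerce_setting('25', 'int') == 25
assert coerce_setting('True', 'bool') is True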
632 def create_user(self, username, password, email='', admin=False,
641 def create_user(self, username, password, email='', admin=False,
633 strict_creation_check=True, api_key=None):
642 strict_creation_check=True, api_key=None):
634 log.info('creating user `%s`', username)
643 log.info('creating user `%s`', username)
635 user = UserModel().create_or_update(
644 user = UserModel().create_or_update(
636 username, password, email, firstname='RhodeCode', lastname='Admin',
645 username, password, email, firstname='RhodeCode', lastname='Admin',
637 active=True, admin=admin, extern_type="rhodecode",
646 active=True, admin=admin, extern_type="rhodecode",
638 strict_creation_check=strict_creation_check)
647 strict_creation_check=strict_creation_check)
639
648
640 if api_key:
649 if api_key:
641 log.info('setting a new default auth token for user `%s`', username)
650 log.info('setting a new default auth token for user `%s`', username)
642 UserModel().add_auth_token(
651 UserModel().add_auth_token(
643 user=user, lifetime_minutes=-1,
652 user=user, lifetime_minutes=-1,
644 role=UserModel.auth_token_role.ROLE_ALL,
653 role=UserModel.auth_token_role.ROLE_ALL,
645 description='BUILTIN TOKEN')
654 description='BUILTIN TOKEN')
646
655
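A hedged usage sketch of the helper above; `db_manage` stands in for whatever object defines create_user (the name is an assumption, not the documented API):

# create an initial admin; a truthy api_key triggers the never-expiring
# default auth token path above (lifetime_minutes=-1)
db_manage.create_user(
    'admin', 'secret-password', email='admin@example.com',
    admin=True, api_key=True)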
647 def create_default_user(self):
656 def create_default_user(self):
648 log.info('creating default user')
657 log.info('creating default user')
649 # create default user for handling default permissions.
658 # create default user for handling default permissions.
650 user = UserModel().create_or_update(username=User.DEFAULT_USER,
659 user = UserModel().create_or_update(username=User.DEFAULT_USER,
651 password=str(uuid.uuid1())[:20],
660 password=str(uuid.uuid1())[:20],
652 email=User.DEFAULT_USER_EMAIL,
661 email=User.DEFAULT_USER_EMAIL,
653 firstname='Anonymous',
662 firstname='Anonymous',
654 lastname='User',
663 lastname='User',
655 strict_creation_check=False)
664 strict_creation_check=False)
656 # based on configuration options, activate/deactivate this user, which
665 # based on configuration options, activate/deactivate this user, which
657 # controls anonymous access
666 # controls anonymous access
658 if self.cli_args.get('public_access') is False:
667 if self.cli_args.get('public_access') is False:
659 log.info('Public access disabled')
668 log.info('Public access disabled')
660 user.active = False
669 user.active = False
661 Session().add(user)
670 Session().add(user)
662 Session().commit()
671 Session().commit()
663
672
664 def create_permissions(self):
673 def create_permissions(self):
665 """
674 """
666 Creates all permissions defined in the system
675 Creates all permissions defined in the system
667 """
676 """
668 # module.(access|create|change|delete)_[name]
677 # module.(access|create|change|delete)_[name]
669 # module.(none|read|write|admin)
678 # module.(none|read|write|admin)
670 log.info('creating permissions')
679 log.info('creating permissions')
671 PermissionModel(self.sa).create_permissions()
680 PermissionModel(self.sa).create_permissions()
672
681
673 def populate_default_permissions(self):
682 def populate_default_permissions(self):
674 """
683 """
675 Populate default permissions. It will create only the default
684 Populate default permissions. It will create only the default
676 permissions that are missing, and not alter already defined ones
685 permissions that are missing, and not alter already defined ones
677 """
686 """
678 log.info('creating default user permissions')
687 log.info('creating default user permissions')
679 PermissionModel(self.sa).create_default_user_permissions(user=User.DEFAULT_USER)
688 PermissionModel(self.sa).create_default_user_permissions(user=User.DEFAULT_USER)
@@ -1,182 +1,202 b''
1
1
2 # Copyright (C) 2010-2020 RhodeCode GmbH
2 # Copyright (C) 2010-2020 RhodeCode GmbH
3 #
3 #
4 # This program is free software: you can redistribute it and/or modify
4 # This program is free software: you can redistribute it and/or modify
5 # it under the terms of the GNU Affero General Public License, version 3
5 # it under the terms of the GNU Affero General Public License, version 3
6 # (only), as published by the Free Software Foundation.
6 # (only), as published by the Free Software Foundation.
7 #
7 #
8 # This program is distributed in the hope that it will be useful,
8 # This program is distributed in the hope that it will be useful,
9 # but WITHOUT ANY WARRANTY; without even the implied warranty of
9 # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
10 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 # GNU General Public License for more details.
11 # GNU General Public License for more details.
12 #
12 #
13 # You should have received a copy of the GNU Affero General Public License
13 # You should have received a copy of the GNU Affero General Public License
14 # along with this program. If not, see <http://www.gnu.org/licenses/>.
14 # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 #
15 #
16 # This program is dual-licensed. If you wish to learn more about the
16 # This program is dual-licensed. If you wish to learn more about the
17 # RhodeCode Enterprise Edition, including its added features, Support services,
17 # RhodeCode Enterprise Edition, including its added features, Support services,
18 # and proprietary license terms, please see https://rhodecode.com/licenses/
18 # and proprietary license terms, please see https://rhodecode.com/licenses/
19
19
20 """
20 """
21 Set of custom exceptions used in RhodeCode
21 Set of custom exceptions used in RhodeCode
22 """
22 """
23
23
24 from webob.exc import HTTPClientError
24 from webob.exc import HTTPClientError
25 from pyramid.httpexceptions import HTTPBadGateway
25 from pyramid.httpexceptions import HTTPBadGateway
26
26
27
27
28 class LdapUsernameError(Exception):
28 class LdapUsernameError(Exception):
29 pass
29 pass
30
30
31
31
32 class LdapPasswordError(Exception):
32 class LdapPasswordError(Exception):
33 pass
33 pass
34
34
35
35
36 class LdapConnectionError(Exception):
36 class LdapConnectionError(Exception):
37 pass
37 pass
38
38
39
39
40 class LdapImportError(Exception):
40 class LdapImportError(Exception):
41 pass
41 pass
42
42
43
43
44 class DefaultUserException(Exception):
44 class DefaultUserException(Exception):
45 pass
45 pass
46
46
47
47
48 class UserOwnsReposException(Exception):
48 class UserOwnsReposException(Exception):
49 pass
49 pass
50
50
51
51
52 class UserOwnsRepoGroupsException(Exception):
52 class UserOwnsRepoGroupsException(Exception):
53 pass
53 pass
54
54
55
55
56 class UserOwnsUserGroupsException(Exception):
56 class UserOwnsUserGroupsException(Exception):
57 pass
57 pass
58
58
59
59
60 class UserOwnsPullRequestsException(Exception):
60 class UserOwnsPullRequestsException(Exception):
61 pass
61 pass
62
62
63
63
64 class UserOwnsArtifactsException(Exception):
64 class UserOwnsArtifactsException(Exception):
65 pass
65 pass
66
66
67
67
68 class UserGroupAssignedException(Exception):
68 class UserGroupAssignedException(Exception):
69 pass
69 pass
70
70
71
71
72 class StatusChangeOnClosedPullRequestError(Exception):
72 class StatusChangeOnClosedPullRequestError(Exception):
73 pass
73 pass
74
74
75
75
76 class AttachedForksError(Exception):
76 class AttachedForksError(Exception):
77 pass
77 pass
78
78
79
79
80 class AttachedPullRequestsError(Exception):
80 class AttachedPullRequestsError(Exception):
81 pass
81 pass
82
82
83
83
84 class RepoGroupAssignmentError(Exception):
84 class RepoGroupAssignmentError(Exception):
85 pass
85 pass
86
86
87
87
88 class NonRelativePathError(Exception):
88 class NonRelativePathError(Exception):
89 pass
89 pass
90
90
91
91
92 class HTTPRequirementError(HTTPClientError):
92 class HTTPRequirementError(HTTPClientError):
93 title = explanation = 'Repository Requirement Missing'
93 title = explanation = 'Repository Requirement Missing'
94 reason = None
94 reason = None
95
95
96 def __init__(self, message, *args, **kwargs):
96 def __init__(self, message, *args, **kwargs):
97 self.title = self.explanation = message
97 self.title = self.explanation = message
98 super(HTTPRequirementError, self).__init__(*args, **kwargs)
98 super(HTTPRequirementError, self).__init__(*args, **kwargs)
99 self.args = (message, )
99 self.args = (message, )
100
100
101
101
102 class HTTPLockedRC(HTTPClientError):
102 class HTTPLockedRC(HTTPClientError):
103 """
103 """
104 Special exception for locked repositories in RhodeCode; the return code
104 Special exception for locked repositories in RhodeCode; the return code
105 can be overwritten by the _code keyword argument passed into the constructor
105 can be overwritten by the _code keyword argument passed into the constructor
106 """
106 """
107 code = 423
107 code = 423
108 title = explanation = 'Repository Locked'
108 title = explanation = 'Repository Locked'
109 reason = None
109 reason = None
110
110
111 def __init__(self, message, *args, **kwargs):
111 def __init__(self, message, *args, **kwargs):
112 from rhodecode import CONFIG
112 import rhodecode
113 from rhodecode.lib.utils2 import safe_int
113
114 _code = CONFIG.get('lock_ret_code')
114 self.code = rhodecode.ConfigGet().get_int('lock_ret_code', missing=self.code)
115 self.code = safe_int(_code, self.code)
115
116 self.title = self.explanation = message
116 self.title = self.explanation = message
117 super(HTTPLockedRC, self).__init__(*args, **kwargs)
117 super(HTTPLockedRC, self).__init__(*args, **kwargs)
118 self.args = (message, )
118 self.args = (message, )
119
119
120
120
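The constructor above reads lock_ret_code from the application config, so operators can remap the default 423 response. A self-contained sketch of that override pattern, using a plain dict in place of rhodecode.ConfigGet() (the real lookup API differs):

class LockedError(Exception):
    code = 423  # HTTP 423 Locked by default

    def __init__(self, message, config=None):
        config = config or {}
        # configuration may remap the status, e.g. to 200 for lenient clients
        self.code = int(config.get('lock_ret_code', self.code))
        super().__init__(message)

err = LockedError('repo locked', config={'lock_ret_code': '200'})
assert err.code == 200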
121 class HTTPBranchProtected(HTTPClientError):
121 class HTTPBranchProtected(HTTPClientError):
122 """
122 """
123 Special exception indicating that a branch is protected in RhodeCode; the
123 Special exception indicating that a branch is protected in RhodeCode; the
124 return code can be overwritten by the _code keyword argument passed into the constructor
124 return code can be overwritten by the _code keyword argument passed into the constructor
125 """
125 """
126 code = 403
126 code = 403
127 title = explanation = 'Branch Protected'
127 title = explanation = 'Branch Protected'
128 reason = None
128 reason = None
129
129
130 def __init__(self, message, *args, **kwargs):
130 def __init__(self, message, *args, **kwargs):
131 self.title = self.explanation = message
131 self.title = self.explanation = message
132 super(HTTPBranchProtected, self).__init__(*args, **kwargs)
132 super(HTTPBranchProtected, self).__init__(*args, **kwargs)
133 self.args = (message, )
133 self.args = (message, )
134
134
135
135
136 class IMCCommitError(Exception):
136 class IMCCommitError(Exception):
137 pass
137 pass
138
138
139
139
140 class UserCreationError(Exception):
140 class UserCreationError(Exception):
141 pass
141 pass
142
142
143
143
144 class NotAllowedToCreateUserError(Exception):
144 class NotAllowedToCreateUserError(Exception):
145 pass
145 pass
146
146
147
147
148 class RepositoryCreationError(Exception):
148 class RepositoryCreationError(Exception):
149 pass
149 pass
150
150
151
151
152 class VCSServerUnavailable(HTTPBadGateway):
152 class VCSServerUnavailable(HTTPBadGateway):
153 """ HTTP Exception class for VCS Server errors """
153 """ HTTP Exception class for VCS Server errors """
154 code = 502
154 code = 502
155 title = 'VCS Server Error'
155 title = 'VCS Server Error'
156 causes = [
156 causes = [
157 'VCS Server is not running',
157 'VCS Server is not running',
158 'Incorrect vcs.server=host:port',
158 'Incorrect vcs.server=host:port',
159 'Incorrect vcs.server.protocol',
159 'Incorrect vcs.server.protocol',
160 ]
160 ]
161
161
162 def __init__(self, message=''):
162 def __init__(self, message=''):
163 self.explanation = 'Could not connect to VCS Server'
163 self.explanation = 'Could not connect to VCS Server'
164 if message:
164 if message:
165 self.explanation += ': ' + message
165 self.explanation += ': ' + message
166 super(VCSServerUnavailable, self).__init__()
166 super(VCSServerUnavailable, self).__init__()
167
167
168
168
169 class ArtifactMetadataDuplicate(ValueError):
169 class ArtifactMetadataDuplicate(ValueError):
170
170
171 def __init__(self, *args, **kwargs):
171 def __init__(self, *args, **kwargs):
172 self.err_section = kwargs.pop('err_section', None)
172 self.err_section = kwargs.pop('err_section', None)
173 self.err_key = kwargs.pop('err_key', None)
173 self.err_key = kwargs.pop('err_key', None)
174 super(ArtifactMetadataDuplicate, self).__init__(*args, **kwargs)
174 super(ArtifactMetadataDuplicate, self).__init__(*args, **kwargs)
175
175
176
176
177 class ArtifactMetadataBadValueType(ValueError):
177 class ArtifactMetadataBadValueType(ValueError):
178 pass
178 pass
179
179
180
180
181 class CommentVersionMismatch(ValueError):
181 class CommentVersionMismatch(ValueError):
182 pass
182 pass
183
184
185 class SignatureVerificationError(ValueError):
186 pass
187
188
189 def signature_verification_error(msg):
190 details = """
191 Encryption signature verification failed.
192 Please check the value of your secret key and/or the stored encrypted value.
193 Secret key stored inside .ini file:
194 `rhodecode.encrypted_values.secret` or defaults to
195 `beaker.session.secret`
196
197 The stored values were probably encrypted using a different secret than the one currently set in the .ini file
198 """
199
200 final_msg = f'{msg}\n{details}'
201 return SignatureVerificationError(final_msg)
202
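A short usage sketch of the factory above; the message text is illustrative only:

err = signature_verification_error('HMAC mismatch for stored value')
assert isinstance(err, SignatureVerificationError)
# raising `err` surfaces the message plus the remediation details above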
@@ -1,444 +1,443 b''
1 # Copyright (c) Django Software Foundation and individual contributors.
1 # Copyright (c) Django Software Foundation and individual contributors.
2 # All rights reserved.
2 # All rights reserved.
3 #
3 #
4 # Redistribution and use in source and binary forms, with or without modification,
4 # Redistribution and use in source and binary forms, with or without modification,
5 # are permitted provided that the following conditions are met:
5 # are permitted provided that the following conditions are met:
6 #
6 #
7 # 1. Redistributions of source code must retain the above copyright notice,
7 # 1. Redistributions of source code must retain the above copyright notice,
8 # this list of conditions and the following disclaimer.
8 # this list of conditions and the following disclaimer.
9 #
9 #
10 # 2. Redistributions in binary form must reproduce the above copyright
10 # 2. Redistributions in binary form must reproduce the above copyright
11 # notice, this list of conditions and the following disclaimer in the
11 # notice, this list of conditions and the following disclaimer in the
12 # documentation and/or other materials provided with the distribution.
12 # documentation and/or other materials provided with the distribution.
13 #
13 #
14 # 3. Neither the name of Django nor the names of its contributors may be used
14 # 3. Neither the name of Django nor the names of its contributors may be used
15 # to endorse or promote products derived from this software without
15 # to endorse or promote products derived from this software without
16 # specific prior written permission.
16 # specific prior written permission.
17 #
17 #
18 # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
18 # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
19 # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
19 # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
20 # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
20 # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21 # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
21 # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
22 # ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
22 # ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
23 # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
23 # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
24 # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
24 # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
25 # ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
25 # ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
26 # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
26 # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
27 # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
27 # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28
28
29 """
29 """
30 For definitions of the different versions of RSS, see:
30 For definitions of the different versions of RSS, see:
31 http://web.archive.org/web/20110718035220/http://diveintomark.org/archives/2004/02/04/incompatible-rss
31 http://web.archive.org/web/20110718035220/http://diveintomark.org/archives/2004/02/04/incompatible-rss
32 """
32 """
33
33
34
34
35 import datetime
35 import datetime
36 import io
36 import io
37
37
38 import pytz
39 from six.moves.urllib import parse as urlparse
38 from six.moves.urllib import parse as urlparse
40
39
41 from rhodecode.lib.feedgenerator import datetime_safe
40 from rhodecode.lib.feedgenerator import datetime_safe
42 from rhodecode.lib.feedgenerator.utils import SimplerXMLGenerator, iri_to_uri, force_text
41 from rhodecode.lib.feedgenerator.utils import SimplerXMLGenerator, iri_to_uri, force_text
43
42
44
43
45 #### The following code comes from ``django.utils.feedgenerator`` ####
44 #### The following code comes from ``django.utils.feedgenerator`` ####
46
45
47
46
48 def rfc2822_date(date):
47 def rfc2822_date(date):
49 # We can't use strftime() because it produces locale-dependent results, so
48 # We can't use strftime() because it produces locale-dependent results, so
50 # we have to map English month and day names manually
49 # we have to map English month and day names manually
51 months = ('Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun', 'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec',)
50 months = ('Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun', 'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec',)
52 days = ('Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun')
51 days = ('Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun')
53 # Support datetime objects older than 1900
52 # Support datetime objects older than 1900
54 date = datetime_safe.new_datetime(date)
53 date = datetime_safe.new_datetime(date)
55 # We do this ourselves to be timezone aware, email.Utils is not tz aware.
54 # We do this ourselves to be timezone aware, email.Utils is not tz aware.
56 dow = days[date.weekday()]
55 dow = days[date.weekday()]
57 month = months[date.month - 1]
56 month = months[date.month - 1]
58 time_str = date.strftime('%s, %%d %s %%Y %%H:%%M:%%S ' % (dow, month))
57 time_str = date.strftime('%s, %%d %s %%Y %%H:%%M:%%S ' % (dow, month))
59
58
60 offset = date.utcoffset()
59 offset = date.utcoffset()
61 # Historically, this function assumes that naive datetimes are in UTC.
60 # Historically, this function assumes that naive datetimes are in UTC.
62 if offset is None:
61 if offset is None:
63 return time_str + '-0000'
62 return time_str + '-0000'
64 else:
63 else:
65 timezone = (offset.days * 24 * 60) + (offset.seconds // 60)
64 timezone = (offset.days * 24 * 60) + (offset.seconds // 60)
66 hour, minute = divmod(timezone, 60)
65 hour, minute = divmod(timezone, 60)
67 return time_str + '%+03d%02d' % (hour, minute)
66 return time_str + '%+03d%02d' % (hour, minute)
68
67
69
68
70 def rfc3339_date(date):
69 def rfc3339_date(date):
71 # Support datetime objects older than 1900
70 # Support datetime objects older than 1900
72 date = datetime_safe.new_datetime(date)
71 date = datetime_safe.new_datetime(date)
73 time_str = date.strftime('%Y-%m-%dT%H:%M:%S')
72 time_str = date.strftime('%Y-%m-%dT%H:%M:%S')
74
73
75 offset = date.utcoffset()
74 offset = date.utcoffset()
76 # Historically, this function assumes that naive datetimes are in UTC.
75 # Historically, this function assumes that naive datetimes are in UTC.
77 if offset is None:
76 if offset is None:
78 return time_str + 'Z'
77 return time_str + 'Z'
79 else:
78 else:
80 timezone = (offset.days * 24 * 60) + (offset.seconds // 60)
79 timezone = (offset.days * 24 * 60) + (offset.seconds // 60)
81 hour, minute = divmod(timezone, 60)
80 hour, minute = divmod(timezone, 60)
82 return time_str + '%+03d:%02d' % (hour, minute)
81 return time_str + '%+03d:%02d' % (hour, minute)
83
82
84
83
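For reference, both formatters applied to the same naive datetime (treated as UTC); datetime is already imported at the top of this module:

d = datetime.datetime(2020, 1, 2, 3, 4, 5)  # naive, assumed UTC
rfc2822_date(d)   # -> 'Thu, 02 Jan 2020 03:04:05 -0000'
rfc3339_date(d)   # -> '2020-01-02T03:04:05Z'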
85 def get_tag_uri(url, date):
84 def get_tag_uri(url, date):
86 """
85 """
87 Creates a TagURI.
86 Creates a TagURI.
88
87
89 See http://web.archive.org/web/20110514113830/http://diveintomark.org/archives/2004/05/28/howto-atom-id
88 See http://web.archive.org/web/20110514113830/http://diveintomark.org/archives/2004/05/28/howto-atom-id
90 """
89 """
91 bits = urlparse.urlparse(url)  # `urlparse` is bound to the parse module above
90 bits = urlparse.urlparse(url)  # `urlparse` is bound to the parse module above
92 d = ''
91 d = ''
93 if date is not None:
92 if date is not None:
94 d = ',%s' % datetime_safe.new_datetime(date).strftime('%Y-%m-%d')
93 d = ',%s' % datetime_safe.new_datetime(date).strftime('%Y-%m-%d')
95 return 'tag:%s%s:%s/%s' % (bits.hostname, d, bits.path, bits.fragment)
94 return 'tag:%s%s:%s/%s' % (bits.hostname, d, bits.path, bits.fragment)
96
95
97
96
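An illustrative call with a made-up URL:

get_tag_uri('https://example.com/feed#top', datetime.datetime(2020, 1, 2))
# -> 'tag:example.com,2020-01-02:/feed/top'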
98 class SyndicationFeed(object):
97 class SyndicationFeed(object):
99 """Base class for all syndication feeds. Subclasses should provide write()"""
98 """Base class for all syndication feeds. Subclasses should provide write()"""
100
99
101 def __init__(self, title, link, description, language=None, author_email=None,
100 def __init__(self, title, link, description, language=None, author_email=None,
102 author_name=None, author_link=None, subtitle=None, categories=None,
101 author_name=None, author_link=None, subtitle=None, categories=None,
103 feed_url=None, feed_copyright=None, feed_guid=None, ttl=None, **kwargs):
102 feed_url=None, feed_copyright=None, feed_guid=None, ttl=None, **kwargs):
104 def to_unicode(s):
103 def to_unicode(s):
105 return force_text(s, strings_only=True)
104 return force_text(s, strings_only=True)
106 if categories:
105 if categories:
107 categories = [force_text(c) for c in categories]
106 categories = [force_text(c) for c in categories]
108 if ttl is not None:
107 if ttl is not None:
109 # Force ints to unicode
108 # Force ints to unicode
110 ttl = force_text(ttl)
109 ttl = force_text(ttl)
111 self.feed = {
110 self.feed = {
112 'title': to_unicode(title),
111 'title': to_unicode(title),
113 'link': iri_to_uri(link),
112 'link': iri_to_uri(link),
114 'description': to_unicode(description),
113 'description': to_unicode(description),
115 'language': to_unicode(language),
114 'language': to_unicode(language),
116 'author_email': to_unicode(author_email),
115 'author_email': to_unicode(author_email),
117 'author_name': to_unicode(author_name),
116 'author_name': to_unicode(author_name),
118 'author_link': iri_to_uri(author_link),
117 'author_link': iri_to_uri(author_link),
119 'subtitle': to_unicode(subtitle),
118 'subtitle': to_unicode(subtitle),
120 'categories': categories or (),
119 'categories': categories or (),
121 'feed_url': iri_to_uri(feed_url),
120 'feed_url': iri_to_uri(feed_url),
122 'feed_copyright': to_unicode(feed_copyright),
121 'feed_copyright': to_unicode(feed_copyright),
123 'id': feed_guid or link,
122 'id': feed_guid or link,
124 'ttl': ttl,
123 'ttl': ttl,
125 }
124 }
126 self.feed.update(kwargs)
125 self.feed.update(kwargs)
127 self.items = []
126 self.items = []
128
127
129 def add_item(self, title, link, description, author_email=None,
128 def add_item(self, title, link, description, author_email=None,
130 author_name=None, author_link=None, pubdate=None, comments=None,
129 author_name=None, author_link=None, pubdate=None, comments=None,
131 unique_id=None, unique_id_is_permalink=None, enclosure=None,
130 unique_id=None, unique_id_is_permalink=None, enclosure=None,
132 categories=(), item_copyright=None, ttl=None, updateddate=None,
131 categories=(), item_copyright=None, ttl=None, updateddate=None,
133 enclosures=None, **kwargs):
132 enclosures=None, **kwargs):
134 """
133 """
135 Adds an item to the feed. All args are expected to be Python Unicode
134 Adds an item to the feed. All args are expected to be Python Unicode
136 objects except pubdate and updateddate, which are datetime.datetime
135 objects except pubdate and updateddate, which are datetime.datetime
137 objects, and enclosures, which is an iterable of instances of the
136 objects, and enclosures, which is an iterable of instances of the
138 Enclosure class.
137 Enclosure class.
139 """
138 """
140 def to_unicode(s):
139 def to_unicode(s):
141 return force_text(s, strings_only=True)
140 return force_text(s, strings_only=True)
142 if categories:
141 if categories:
143 categories = [to_unicode(c) for c in categories]
142 categories = [to_unicode(c) for c in categories]
144 if ttl is not None:
143 if ttl is not None:
145 # Force ints to unicode
144 # Force ints to unicode
146 ttl = force_text(ttl)
145 ttl = force_text(ttl)
147 if enclosure is None:
146 if enclosure is None:
148 enclosures = [] if enclosures is None else enclosures
147 enclosures = [] if enclosures is None else enclosures
149
148
150 item = {
149 item = {
151 'title': to_unicode(title),
150 'title': to_unicode(title),
152 'link': iri_to_uri(link),
151 'link': iri_to_uri(link),
153 'description': to_unicode(description),
152 'description': to_unicode(description),
154 'author_email': to_unicode(author_email),
153 'author_email': to_unicode(author_email),
155 'author_name': to_unicode(author_name),
154 'author_name': to_unicode(author_name),
156 'author_link': iri_to_uri(author_link),
155 'author_link': iri_to_uri(author_link),
157 'pubdate': pubdate,
156 'pubdate': pubdate,
158 'updateddate': updateddate,
157 'updateddate': updateddate,
159 'comments': to_unicode(comments),
158 'comments': to_unicode(comments),
160 'unique_id': to_unicode(unique_id),
159 'unique_id': to_unicode(unique_id),
161 'unique_id_is_permalink': unique_id_is_permalink,
160 'unique_id_is_permalink': unique_id_is_permalink,
162 'enclosures': enclosures,
161 'enclosures': enclosures,
163 'categories': categories or (),
162 'categories': categories or (),
164 'item_copyright': to_unicode(item_copyright),
163 'item_copyright': to_unicode(item_copyright),
165 'ttl': ttl,
164 'ttl': ttl,
166 }
165 }
167 item.update(kwargs)
166 item.update(kwargs)
168 self.items.append(item)
167 self.items.append(item)
169
168
170 def num_items(self):
169 def num_items(self):
171 return len(self.items)
170 return len(self.items)
172
171
173 def root_attributes(self):
172 def root_attributes(self):
174 """
173 """
175 Return extra attributes to place on the root (i.e. feed/channel) element.
174 Return extra attributes to place on the root (i.e. feed/channel) element.
176 Called from write().
175 Called from write().
177 """
176 """
178 return {}
177 return {}
179
178
180 def add_root_elements(self, handler):
179 def add_root_elements(self, handler):
181 """
180 """
182 Add elements in the root (i.e. feed/channel) element. Called
181 Add elements in the root (i.e. feed/channel) element. Called
183 from write().
182 from write().
184 """
183 """
185 pass
184 pass
186
185
187 def item_attributes(self, item):
186 def item_attributes(self, item):
188 """
187 """
189 Return extra attributes to place on each item (i.e. item/entry) element.
188 Return extra attributes to place on each item (i.e. item/entry) element.
190 """
189 """
191 return {}
190 return {}
192
191
193 def add_item_elements(self, handler, item):
192 def add_item_elements(self, handler, item):
194 """
193 """
195 Add elements on each item (i.e. item/entry) element.
194 Add elements on each item (i.e. item/entry) element.
196 """
195 """
197 pass
196 pass
198
197
199 def write(self, outfile, encoding):
198 def write(self, outfile, encoding):
200 """
199 """
201 Outputs the feed in the given encoding to outfile, which is a file-like
200 Outputs the feed in the given encoding to outfile, which is a file-like
202 object. Subclasses should override this.
201 object. Subclasses should override this.
203 """
202 """
204 raise NotImplementedError('subclasses of SyndicationFeed must provide a write() method')
203 raise NotImplementedError('subclasses of SyndicationFeed must provide a write() method')
205
204
206 def writeString(self, encoding):
205 def writeString(self, encoding):
207 """
206 """
208 Returns the feed in the given encoding as a string.
207 Returns the feed in the given encoding as a string.
209 """
208 """
210 s = io.StringIO()
209 s = io.StringIO()
211 self.write(s, encoding)
210 self.write(s, encoding)
212 return s.getvalue()
211 return s.getvalue()
213
212
214 def latest_post_date(self):
213 def latest_post_date(self):
215 """
214 """
216 Returns the latest item's pubdate or updateddate. If no items
215 Returns the latest item's pubdate or updateddate. If no items
217 have either of these attributes this returns the current UTC date/time.
216 have either of these attributes this returns the current UTC date/time.
218 """
217 """
219 latest_date = None
218 latest_date = None
220 date_keys = ('updateddate', 'pubdate')
219 date_keys = ('updateddate', 'pubdate')
221
220
222 for item in self.items:
221 for item in self.items:
223 for date_key in date_keys:
222 for date_key in date_keys:
224 item_date = item.get(date_key)
223 item_date = item.get(date_key)
225 if item_date:
224 if item_date:
226 if latest_date is None or item_date > latest_date:
225 if latest_date is None or item_date > latest_date:
227 latest_date = item_date
226 latest_date = item_date
228
227
229 # datetime.now(tz=utc) is slower, as documented in django.utils.timezone.now
228 # datetime.now(tz=utc) is slower, as documented in django.utils.timezone.now
230 return latest_date or datetime.datetime.utcnow().replace(tzinfo=pytz.utc)
229 return latest_date or datetime.datetime.utcnow().replace(tzinfo=datetime.timezone.utc)
231
230
232
231
233 class Enclosure(object):
232 class Enclosure(object):
234 """Represents an RSS enclosure"""
233 """Represents an RSS enclosure"""
235 def __init__(self, url, length, mime_type):
234 def __init__(self, url, length, mime_type):
236 """All args are expected to be Python Unicode objects"""
235 """All args are expected to be Python Unicode objects"""
237 self.length, self.mime_type = length, mime_type
236 self.length, self.mime_type = length, mime_type
238 self.url = iri_to_uri(url)
237 self.url = iri_to_uri(url)
239
238
240
239
241 class RssFeed(SyndicationFeed):
240 class RssFeed(SyndicationFeed):
242 content_type = 'application/rss+xml; charset=utf-8'
241 content_type = 'application/rss+xml; charset=utf-8'
243
242
244 def write(self, outfile, encoding):
243 def write(self, outfile, encoding):
245 handler = SimplerXMLGenerator(outfile, encoding)
244 handler = SimplerXMLGenerator(outfile, encoding)
246 handler.startDocument()
245 handler.startDocument()
247 handler.startElement("rss", self.rss_attributes())
246 handler.startElement("rss", self.rss_attributes())
248 handler.startElement("channel", self.root_attributes())
247 handler.startElement("channel", self.root_attributes())
249 self.add_root_elements(handler)
248 self.add_root_elements(handler)
250 self.write_items(handler)
249 self.write_items(handler)
251 self.endChannelElement(handler)
250 self.endChannelElement(handler)
252 handler.endElement("rss")
251 handler.endElement("rss")
253
252
254 def rss_attributes(self):
253 def rss_attributes(self):
255 return {"version": self._version,
254 return {"version": self._version,
256 "xmlns:atom": "http://www.w3.org/2005/Atom"}
255 "xmlns:atom": "http://www.w3.org/2005/Atom"}
257
256
258 def write_items(self, handler):
257 def write_items(self, handler):
259 for item in self.items:
258 for item in self.items:
260 handler.startElement('item', self.item_attributes(item))
259 handler.startElement('item', self.item_attributes(item))
261 self.add_item_elements(handler, item)
260 self.add_item_elements(handler, item)
262 handler.endElement("item")
261 handler.endElement("item")
263
262
264 def add_root_elements(self, handler):
263 def add_root_elements(self, handler):
265 handler.addQuickElement("title", self.feed['title'])
264 handler.addQuickElement("title", self.feed['title'])
266 handler.addQuickElement("link", self.feed['link'])
265 handler.addQuickElement("link", self.feed['link'])
267 handler.addQuickElement("description", self.feed['description'])
266 handler.addQuickElement("description", self.feed['description'])
268 if self.feed['feed_url'] is not None:
267 if self.feed['feed_url'] is not None:
269 handler.addQuickElement("atom:link", None, {"rel": "self", "href": self.feed['feed_url']})
268 handler.addQuickElement("atom:link", None, {"rel": "self", "href": self.feed['feed_url']})
270 if self.feed['language'] is not None:
269 if self.feed['language'] is not None:
271 handler.addQuickElement("language", self.feed['language'])
270 handler.addQuickElement("language", self.feed['language'])
272 for cat in self.feed['categories']:
271 for cat in self.feed['categories']:
273 handler.addQuickElement("category", cat)
272 handler.addQuickElement("category", cat)
274 if self.feed['feed_copyright'] is not None:
273 if self.feed['feed_copyright'] is not None:
275 handler.addQuickElement("copyright", self.feed['feed_copyright'])
274 handler.addQuickElement("copyright", self.feed['feed_copyright'])
276 handler.addQuickElement("lastBuildDate", rfc2822_date(self.latest_post_date()))
275 handler.addQuickElement("lastBuildDate", rfc2822_date(self.latest_post_date()))
277 if self.feed['ttl'] is not None:
276 if self.feed['ttl'] is not None:
278 handler.addQuickElement("ttl", self.feed['ttl'])
277 handler.addQuickElement("ttl", self.feed['ttl'])
279
278
280 def endChannelElement(self, handler):
279 def endChannelElement(self, handler):
281 handler.endElement("channel")
280 handler.endElement("channel")
282
281
283
282
284 class RssUserland091Feed(RssFeed):
283 class RssUserland091Feed(RssFeed):
285 _version = "0.91"
284 _version = "0.91"
286
285
287 def add_item_elements(self, handler, item):
286 def add_item_elements(self, handler, item):
288 handler.addQuickElement("title", item['title'])
287 handler.addQuickElement("title", item['title'])
289 handler.addQuickElement("link", item['link'])
288 handler.addQuickElement("link", item['link'])
290 if item['description'] is not None:
289 if item['description'] is not None:
291 handler.addQuickElement("description", item['description'])
290 handler.addQuickElement("description", item['description'])
292
291
293
292
294 class Rss201rev2Feed(RssFeed):
293 class Rss201rev2Feed(RssFeed):
295 # Spec: http://blogs.law.harvard.edu/tech/rss
294 # Spec: http://blogs.law.harvard.edu/tech/rss
296 _version = "2.0"
295 _version = "2.0"
297
296
298 def add_item_elements(self, handler, item):
297 def add_item_elements(self, handler, item):
299 handler.addQuickElement("title", item['title'])
298 handler.addQuickElement("title", item['title'])
300 handler.addQuickElement("link", item['link'])
299 handler.addQuickElement("link", item['link'])
301 if item['description'] is not None:
300 if item['description'] is not None:
302 handler.addQuickElement("description", item['description'])
301 handler.addQuickElement("description", item['description'])
303
302
304 # Author information.
303 # Author information.
305 if item["author_name"] and item["author_email"]:
304 if item["author_name"] and item["author_email"]:
306 handler.addQuickElement("author", "%s (%s)" % (item['author_email'], item['author_name']))
305 handler.addQuickElement("author", "%s (%s)" % (item['author_email'], item['author_name']))
307 elif item["author_email"]:
306 elif item["author_email"]:
308 handler.addQuickElement("author", item["author_email"])
307 handler.addQuickElement("author", item["author_email"])
309 elif item["author_name"]:
308 elif item["author_name"]:
310 handler.addQuickElement(
309 handler.addQuickElement(
311 "dc:creator", item["author_name"], {"xmlns:dc": "http://purl.org/dc/elements/1.1/"}
310 "dc:creator", item["author_name"], {"xmlns:dc": "http://purl.org/dc/elements/1.1/"}
312 )
311 )
313
312
314 if item['pubdate'] is not None:
313 if item['pubdate'] is not None:
315 handler.addQuickElement("pubDate", rfc2822_date(item['pubdate']))
314 handler.addQuickElement("pubDate", rfc2822_date(item['pubdate']))
316 if item['comments'] is not None:
315 if item['comments'] is not None:
317 handler.addQuickElement("comments", item['comments'])
316 handler.addQuickElement("comments", item['comments'])
318 if item['unique_id'] is not None:
317 if item['unique_id'] is not None:
319 guid_attrs = {}
318 guid_attrs = {}
320 if isinstance(item.get('unique_id_is_permalink'), bool):
319 if isinstance(item.get('unique_id_is_permalink'), bool):
321 guid_attrs['isPermaLink'] = str(item['unique_id_is_permalink']).lower()
320 guid_attrs['isPermaLink'] = str(item['unique_id_is_permalink']).lower()
322 handler.addQuickElement("guid", item['unique_id'], guid_attrs)
321 handler.addQuickElement("guid", item['unique_id'], guid_attrs)
323 if item['ttl'] is not None:
322 if item['ttl'] is not None:
324 handler.addQuickElement("ttl", item['ttl'])
323 handler.addQuickElement("ttl", item['ttl'])
325
324
326 # Enclosure.
325 # Enclosure.
327 if item['enclosures']:
326 if item['enclosures']:
328 enclosures = list(item['enclosures'])
327 enclosures = list(item['enclosures'])
329 if len(enclosures) > 1:
328 if len(enclosures) > 1:
330 raise ValueError(
329 raise ValueError(
331 "RSS feed items may only have one enclosure, see "
330 "RSS feed items may only have one enclosure, see "
332 "http://www.rssboard.org/rss-profile#element-channel-item-enclosure"
331 "http://www.rssboard.org/rss-profile#element-channel-item-enclosure"
333 )
332 )
334 enclosure = enclosures[0]
333 enclosure = enclosures[0]
335 handler.addQuickElement('enclosure', '', {
334 handler.addQuickElement('enclosure', '', {
336 'url': enclosure.url,
335 'url': enclosure.url,
337 'length': enclosure.length,
336 'length': enclosure.length,
338 'type': enclosure.mime_type,
337 'type': enclosure.mime_type,
339 })
338 })
340
339
341 # Categories.
340 # Categories.
342 for cat in item['categories']:
341 for cat in item['categories']:
343 handler.addQuickElement("category", cat)
342 handler.addQuickElement("category", cat)
344
343
345
344
346 class Atom1Feed(SyndicationFeed):
345 class Atom1Feed(SyndicationFeed):
347 # Spec: https://tools.ietf.org/html/rfc4287
346 # Spec: https://tools.ietf.org/html/rfc4287
348 content_type = 'application/atom+xml; charset=utf-8'
347 content_type = 'application/atom+xml; charset=utf-8'
349 ns = "http://www.w3.org/2005/Atom"
348 ns = "http://www.w3.org/2005/Atom"
350
349
351 def write(self, outfile, encoding):
350 def write(self, outfile, encoding):
352 handler = SimplerXMLGenerator(outfile, encoding)
351 handler = SimplerXMLGenerator(outfile, encoding)
353 handler.startDocument()
352 handler.startDocument()
354 handler.startElement('feed', self.root_attributes())
353 handler.startElement('feed', self.root_attributes())
355 self.add_root_elements(handler)
354 self.add_root_elements(handler)
356 self.write_items(handler)
355 self.write_items(handler)
357 handler.endElement("feed")
356 handler.endElement("feed")
358
357
359 def root_attributes(self):
358 def root_attributes(self):
360 if self.feed['language'] is not None:
359 if self.feed['language'] is not None:
361 return {"xmlns": self.ns, "xml:lang": self.feed['language']}
360 return {"xmlns": self.ns, "xml:lang": self.feed['language']}
362 else:
361 else:
363 return {"xmlns": self.ns}
362 return {"xmlns": self.ns}
364
363
365 def add_root_elements(self, handler):
364 def add_root_elements(self, handler):
366 handler.addQuickElement("title", self.feed['title'])
365 handler.addQuickElement("title", self.feed['title'])
367 handler.addQuickElement("link", "", {"rel": "alternate", "href": self.feed['link']})
366 handler.addQuickElement("link", "", {"rel": "alternate", "href": self.feed['link']})
368 if self.feed['feed_url'] is not None:
367 if self.feed['feed_url'] is not None:
369 handler.addQuickElement("link", "", {"rel": "self", "href": self.feed['feed_url']})
368 handler.addQuickElement("link", "", {"rel": "self", "href": self.feed['feed_url']})
370 handler.addQuickElement("id", self.feed['id'])
369 handler.addQuickElement("id", self.feed['id'])
371 handler.addQuickElement("updated", rfc3339_date(self.latest_post_date()))
370 handler.addQuickElement("updated", rfc3339_date(self.latest_post_date()))
372 if self.feed['author_name'] is not None:
371 if self.feed['author_name'] is not None:
373 handler.startElement("author", {})
372 handler.startElement("author", {})
374 handler.addQuickElement("name", self.feed['author_name'])
373 handler.addQuickElement("name", self.feed['author_name'])
375 if self.feed['author_email'] is not None:
374 if self.feed['author_email'] is not None:
376 handler.addQuickElement("email", self.feed['author_email'])
375 handler.addQuickElement("email", self.feed['author_email'])
377 if self.feed['author_link'] is not None:
376 if self.feed['author_link'] is not None:
378 handler.addQuickElement("uri", self.feed['author_link'])
377 handler.addQuickElement("uri", self.feed['author_link'])
379 handler.endElement("author")
378 handler.endElement("author")
380 if self.feed['subtitle'] is not None:
379 if self.feed['subtitle'] is not None:
381 handler.addQuickElement("subtitle", self.feed['subtitle'])
380 handler.addQuickElement("subtitle", self.feed['subtitle'])
382 for cat in self.feed['categories']:
381 for cat in self.feed['categories']:
383 handler.addQuickElement("category", "", {"term": cat})
382 handler.addQuickElement("category", "", {"term": cat})
384 if self.feed['feed_copyright'] is not None:
383 if self.feed['feed_copyright'] is not None:
385 handler.addQuickElement("rights", self.feed['feed_copyright'])
384 handler.addQuickElement("rights", self.feed['feed_copyright'])
386
385
387 def write_items(self, handler):
386 def write_items(self, handler):
388 for item in self.items:
387 for item in self.items:
389 handler.startElement("entry", self.item_attributes(item))
388 handler.startElement("entry", self.item_attributes(item))
390 self.add_item_elements(handler, item)
389 self.add_item_elements(handler, item)
391 handler.endElement("entry")
390 handler.endElement("entry")
392
391
393 def add_item_elements(self, handler, item):
392 def add_item_elements(self, handler, item):
394 handler.addQuickElement("title", item['title'])
393 handler.addQuickElement("title", item['title'])
395 handler.addQuickElement("link", "", {"href": item['link'], "rel": "alternate"})
394 handler.addQuickElement("link", "", {"href": item['link'], "rel": "alternate"})
396
395
397 if item['pubdate'] is not None:
396 if item['pubdate'] is not None:
398 handler.addQuickElement('published', rfc3339_date(item['pubdate']))
397 handler.addQuickElement('published', rfc3339_date(item['pubdate']))
399
398
400 if item['updateddate'] is not None:
399 if item['updateddate'] is not None:
401 handler.addQuickElement('updated', rfc3339_date(item['updateddate']))
400 handler.addQuickElement('updated', rfc3339_date(item['updateddate']))
402
401
403 # Author information.
402 # Author information.
404 if item['author_name'] is not None:
403 if item['author_name'] is not None:
405 handler.startElement("author", {})
404 handler.startElement("author", {})
406 handler.addQuickElement("name", item['author_name'])
405 handler.addQuickElement("name", item['author_name'])
407 if item['author_email'] is not None:
406 if item['author_email'] is not None:
408 handler.addQuickElement("email", item['author_email'])
407 handler.addQuickElement("email", item['author_email'])
409 if item['author_link'] is not None:
408 if item['author_link'] is not None:
410 handler.addQuickElement("uri", item['author_link'])
409 handler.addQuickElement("uri", item['author_link'])
411 handler.endElement("author")
410 handler.endElement("author")
412
411
413 # Unique ID.
412 # Unique ID.
414 if item['unique_id'] is not None:
413 if item['unique_id'] is not None:
415 unique_id = item['unique_id']
414 unique_id = item['unique_id']
416 else:
415 else:
417 unique_id = get_tag_uri(item['link'], item['pubdate'])
416 unique_id = get_tag_uri(item['link'], item['pubdate'])
418 handler.addQuickElement("id", unique_id)
417 handler.addQuickElement("id", unique_id)
419
418
420 # Summary.
419 # Summary.
421 if item['description'] is not None:
420 if item['description'] is not None:
422 handler.addQuickElement("summary", item['description'], {"type": "html"})
421 handler.addQuickElement("summary", item['description'], {"type": "html"})
423
422
424 # Enclosures.
423 # Enclosures.
425 for enclosure in item['enclosures']:
424 for enclosure in item['enclosures']:
426 handler.addQuickElement('link', '', {
425 handler.addQuickElement('link', '', {
427 'rel': 'enclosure',
426 'rel': 'enclosure',
428 'href': enclosure.url,
427 'href': enclosure.url,
429 'length': enclosure.length,
428 'length': enclosure.length,
430 'type': enclosure.mime_type,
429 'type': enclosure.mime_type,
431 })
430 })
432
431
433 # Categories.
432 # Categories.
434 for cat in item['categories']:
433 for cat in item['categories']:
435 handler.addQuickElement("category", "", {"term": cat})
434 handler.addQuickElement("category", "", {"term": cat})
436
435
437 # Rights.
436 # Rights.
438 if item['item_copyright'] is not None:
437 if item['item_copyright'] is not None:
439 handler.addQuickElement("rights", item['item_copyright'])
438 handler.addQuickElement("rights", item['item_copyright'])
440
439
441
440
442 # This isolates the decision of what the system default is, so calling code can
441 # This isolates the decision of what the system default is, so calling code can
443 # do "feedgenerator.DefaultFeed" instead of "feedgenerator.Rss201rev2Feed".
442 # do "feedgenerator.DefaultFeed" instead of "feedgenerator.Rss201rev2Feed".
444 DefaultFeed = Rss201rev2Feed
\ No newline at end of file
443 DefaultFeed = Rss201rev2Feed
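To close the loop, a short end-to-end sketch of the generator API defined above; the titles and URLs are placeholders:

feed = DefaultFeed(
    title='Commits', link='https://example.com/repo',
    description='Latest commits')
feed.add_item(
    title='libs: major refactor for python3',
    link='https://example.com/repo/changeset/4eab4aa8',
    description='str/bytes cleanup',
    pubdate=datetime.datetime(2020, 1, 2, 3, 4, 5))
xml = feed.writeString('utf-8')  # serialized RSS 2.0 document as a str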
@@ -1,155 +1,155 b''
1
1
2
2
3 # Copyright (C) 2012-2020 RhodeCode GmbH
3 # Copyright (C) 2012-2020 RhodeCode GmbH
4 #
4 #
5 # This program is free software: you can redistribute it and/or modify
5 # This program is free software: you can redistribute it and/or modify
6 # it under the terms of the GNU Affero General Public License, version 3
6 # it under the terms of the GNU Affero General Public License, version 3
7 # (only), as published by the Free Software Foundation.
7 # (only), as published by the Free Software Foundation.
8 #
8 #
9 # This program is distributed in the hope that it will be useful,
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 # GNU General Public License for more details.
12 # GNU General Public License for more details.
13 #
13 #
14 # You should have received a copy of the GNU Affero General Public License
14 # You should have received a copy of the GNU Affero General Public License
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 #
16 #
17 # This program is dual-licensed. If you wish to learn more about the
17 # This program is dual-licensed. If you wish to learn more about the
18 # RhodeCode Enterprise Edition, including its added features, Support services,
18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20
20
21 """
21 """
22 Index schema for RhodeCode
22 Index schema for RhodeCode
23 """
23 """
24
24
25 import importlib
25 import importlib
26 import logging
26 import logging
27
27
28 from rhodecode.lib.index.search_utils import normalize_text_for_matching
28 from rhodecode.lib.index.search_utils import normalize_text_for_matching
29
29
30 log = logging.getLogger(__name__)
30 log = logging.getLogger(__name__)
31
31
32 # leave defaults for backward compat
32 # leave defaults for backward compat
33 default_searcher = 'rhodecode.lib.index.whoosh'
33 default_searcher = 'rhodecode.lib.index.whoosh'
34 default_location = '%(here)s/data/index'
34 default_location = '%(here)s/data/index'
35
35
36 ES_VERSION_2 = '2'
36 ES_VERSION_2 = '2'
37 ES_VERSION_6 = '6'
37 ES_VERSION_6 = '6'
38 # for legacy reasons we keep 2 compat as default
38 # for legacy reasons we keep 2 compat as default
39 DEFAULT_ES_VERSION = ES_VERSION_2
39 DEFAULT_ES_VERSION = ES_VERSION_2
40
40
41 try:
41 try:
42 from rhodecode_tools.lib.fts_index.elasticsearch_engine_6 import ES_CONFIG # pragma: no cover
42 from rhodecode_tools.lib.fts_index.elasticsearch_engine_6 import ES_CONFIG # pragma: no cover
43 except ImportError:
43 except ImportError:
44 log.warning('rhodecode_tools not available, use of full text search is limited')
44 log.warning('rhodecode_tools not available, use of full text search is limited')
45 pass
45 pass
46
46
47
47
48 class BaseSearcher(object):
48 class BaseSearcher(object):
49 query_lang_doc = ''
49 query_lang_doc = ''
50 es_version = None
50 es_version = None
51 name = None
51 name = None
52 DIRECTION_ASC = 'asc'
52 DIRECTION_ASC = 'asc'
53 DIRECTION_DESC = 'desc'
53 DIRECTION_DESC = 'desc'
54
54
55 def __init__(self):
55 def __init__(self):
56 pass
56 pass
57
57
58 def cleanup(self):
58 def cleanup(self):
59 pass
59 pass
60
60
61 def search(self, query, document_type, search_user,
61 def search(self, query, document_type, search_user,
62 repo_name=None, repo_group_name=None,
62 repo_name=None, repo_group_name=None,
63 raise_on_exc=True):
63 raise_on_exc=True):
64 raise Exception('NotImplemented')
64 raise Exception('NotImplemented')
65
65
66 @staticmethod
66 @staticmethod
67 def query_to_mark(query, default_field=None):
67 def query_to_mark(query, default_field=None):
68 """
68 """
69 Formats the query into a mark token for jquery.mark.js highlighting. ES may
69 Formats the query into a mark token for jquery.mark.js highlighting. ES may
70 optionally use a different format.
70 optionally use a different format.
71
71
72 :param default_field:
72 :param default_field:
73 :param query:
73 :param query:
74 """
74 """
75 return ' '.join(normalize_text_for_matching(query).split())
75 return ' '.join(normalize_text_for_matching(query).split())
76
76
77 @property
77 @property
78 def is_es_6(self):
78 def is_es_6(self):
79 return self.es_version == ES_VERSION_6
79 return self.es_version == ES_VERSION_6
80
80
81 def get_handlers(self):
81 def get_handlers(self):
82 return {}
82 return {}
83
83
84 @staticmethod
84 @staticmethod
85 def extract_search_tags(query):
85 def extract_search_tags(query):
86 return []
86 return []
87
87
88 @staticmethod
88 @staticmethod
89 def escape_specials(val):
89 def escape_specials(val):
90 """
90 """
91 Handle and escape reserved chars for search
91 Handle and escape reserved chars for search
92 """
92 """
93 return val
93 return val
94
94
95 def sort_def(self, search_type, direction, sort_field):
95 def sort_def(self, search_type, direction, sort_field):
96 """
96 """
97 Defines sorting for search. This function should decide whether, for a given
97 Defines sorting for search. This function should decide whether, for a given
98 search_type, sorting can be done with sort_field.
98 search_type, sorting can be done with sort_field.
99
99
100 It should also translate common sort fields into backend-specific ones, e.g. for Elasticsearch
100 It should also translate common sort fields into backend-specific ones, e.g. for Elasticsearch
101 """
101 """
102 raise NotImplementedError()
102 raise NotImplementedError()
103
103
104 @staticmethod
104 @staticmethod
105 def get_sort(search_type, search_val):
105 def get_sort(search_type, search_val):
106 """
106 """
107 Method used to parse the GET search sort value to a field and direction.
107 Method used to parse the GET search sort value to a field and direction.
108 e.g. asc:lines == asc, lines
108 e.g. asc:lines == asc, lines
109
109
110 There is also legacy support for newfirst/oldfirst, which defines commit
110 There is also legacy support for newfirst/oldfirst, which defines commit
111 sorting only
111 sorting only
112 """
112 """
113
113
114 direction = BaseSearcher.DIRECTION_ASC
114 direction = BaseSearcher.DIRECTION_ASC
115 sort_field = None
115 sort_field = None
116
116
117 if not search_val:
117 if not search_val:
118 return direction, sort_field
118 return direction, sort_field
119
119
120 if search_val.startswith('asc:'):
120 if search_val.startswith('asc:'):
121 sort_field = search_val[4:]
121 sort_field = search_val[4:]
122 direction = BaseSearcher.DIRECTION_ASC
122 direction = BaseSearcher.DIRECTION_ASC
123 elif search_val.startswith('desc:'):
123 elif search_val.startswith('desc:'):
124 sort_field = search_val[5:]
124 sort_field = search_val[5:]
125 direction = BaseSearcher.DIRECTION_DESC
125 direction = BaseSearcher.DIRECTION_DESC
126 elif search_val == 'newfirst' and search_type == 'commit':
126 elif search_val == 'newfirst' and search_type == 'commit':
127 sort_field = 'date'
127 sort_field = 'date'
128 direction = BaseSearcher.DIRECTION_DESC
128 direction = BaseSearcher.DIRECTION_DESC
129 elif search_val == 'oldfirst' and search_type == 'commit':
129 elif search_val == 'oldfirst' and search_type == 'commit':
130 sort_field = 'date'
130 sort_field = 'date'
131 direction = BaseSearcher.DIRECTION_ASC
131 direction = BaseSearcher.DIRECTION_ASC
132
132
133 return direction, sort_field
133 return direction, sort_field
134
134
135
135
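The parsing above, by example:

BaseSearcher.get_sort('commit', 'desc:date')  # -> ('desc', 'date')
BaseSearcher.get_sort('commit', 'newfirst')   # -> ('desc', 'date')
BaseSearcher.get_sort('content', '')          # -> ('asc', None)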
136 def search_config(config, prefix='search.'):
136 def search_config(config, prefix='search.'):
137 _config = {}
137 _config = {}
138 for key in config.keys():
138 for key in config.keys():
139 if key.startswith(prefix):
139 if key.startswith(prefix):
140 _config[key[len(prefix):]] = config[key]
140 _config[key[len(prefix):]] = config[key]
141 return _config
141 return _config
142
142
143
143
144 def searcher_from_config(config, prefix='search.'):
144 def searcher_from_config(config, prefix='search.'):
145 _config = search_config(config, prefix)
145 _config = search_config(config, prefix)
146
146
147 if 'location' not in _config:
147 if 'location' not in _config:
148 _config['location'] = default_location
148 _config['location'] = default_location
149 if 'es_version' not in _config:
149 if 'es_version' not in _config:
150 # use old legacy ES version set to 2
150 # use an old legacy ES version set to 2
151 _config['es_version'] = '2'
151 _config['es_version'] = '2'
152
152
153 imported = importlib.import_module(_config.get('module', default_searcher))
153 imported = importlib.import_module(_config.get('module', default_searcher))
154 searcher = imported.Searcher(config=_config)
154 searcher = imported.Searcher(config=_config)
155 return searcher
155 return searcher
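A hedged usage sketch: keys under the search. prefix are stripped and handed to the module named by search.module; the paths here are examples only:

config = {
    'search.module': 'rhodecode.lib.index.whoosh',
    'search.location': '/var/opt/rhodecode/index',
}
searcher = searcher_from_config(config)
# equivalent to importing 'rhodecode.lib.index.whoosh' and building
# Searcher(config={'module': ..., 'location': ..., 'es_version': '2'})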
@@ -1,278 +1,173 b''
1
1
2 # Copyright (C) 2010-2020 RhodeCode GmbH
2 # Copyright (C) 2010-2020 RhodeCode GmbH
3 #
3 #
4 # This program is free software: you can redistribute it and/or modify
4 # This program is free software: you can redistribute it and/or modify
5 # it under the terms of the GNU Affero General Public License, version 3
5 # it under the terms of the GNU Affero General Public License, version 3
6 # (only), as published by the Free Software Foundation.
6 # (only), as published by the Free Software Foundation.
7 #
7 #
8 # This program is distributed in the hope that it will be useful,
8 # This program is distributed in the hope that it will be useful,
9 # but WITHOUT ANY WARRANTY; without even the implied warranty of
9 # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
10 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 # GNU General Public License for more details.
11 # GNU General Public License for more details.
12 #
12 #
13 # You should have received a copy of the GNU Affero General Public License
13 # You should have received a copy of the GNU Affero General Public License
14 # along with this program. If not, see <http://www.gnu.org/licenses/>.
14 # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 #
15 #
16 # This program is dual-licensed. If you wish to learn more about the
16 # This program is dual-licensed. If you wish to learn more about the
17 # RhodeCode Enterprise Edition, including its added features, Support services,
17 # RhodeCode Enterprise Edition, including its added features, Support services,
18 # and proprietary license terms, please see https://rhodecode.com/licenses/
18 # and proprietary license terms, please see https://rhodecode.com/licenses/
19
19
20 import collections
21
22 import sqlalchemy
20 import sqlalchemy
23 from sqlalchemy import UnicodeText
21 from sqlalchemy import UnicodeText
24 from sqlalchemy.ext.mutable import Mutable
22 from sqlalchemy.ext.mutable import Mutable, \
23 MutableList as MutationList, \
24 MutableDict as MutationDict
25
25
26 from rhodecode.lib.ext_json import json
26 from rhodecode.lib import ext_json
27 from rhodecode.lib.utils2 import safe_unicode
28
27
29
28
30 class JsonRaw(str):
29 class JsonRaw(str):
31 """
30 """
32 Allows interacting with a JSON-typed field using a raw string.
31 Allows interacting with a JSON-typed field using a raw string.
33
32
34 For example::
33 For example::
35 db_instance = JsonTable()
34 db_instance = JsonTable()
36 db_instance.enabled = True
35 db_instance.enabled = True
37 db_instance.json_data = JsonRaw('{"a": 4}')
36 db_instance.json_data = JsonRaw('{"a": 4}')
38
37
39 This will bypass serialization and checks, and allow storing
38 This will bypass serialization and checks, and allow storing
40 raw values.
39 raw values.
41 """
40 """
42 pass
41 pass
43
42
44
43
45 # Set this to the standard dict if ordering is not required
46 DictClass = collections.OrderedDict
47
48
49 class JSONEncodedObj(sqlalchemy.types.TypeDecorator):
44 class JSONEncodedObj(sqlalchemy.types.TypeDecorator):
50 """
45 """
51 Represents an immutable structure as a json-encoded string.
46 Represents an immutable structure as a json-encoded string.
52
47
53 If default is, for example, a dict, then a NULL value in the
48 If default is, for example, a dict, then a NULL value in the
54 database will be exposed as an empty dict.
49 database will be exposed as an empty dict.
55 """
50 """
56
51
57 impl = UnicodeText
52 impl = UnicodeText
58 safe = True
53 safe = True
59 enforce_unicode = True
54 enforce_str = True
60
55
61 def __init__(self, *args, **kwargs):
56 def __init__(self, *args, **kwargs):
62 self.default = kwargs.pop('default', None)
57 self.default = kwargs.pop('default', None)
63 self.safe = kwargs.pop('safe_json', self.safe)
58 self.safe = kwargs.pop('safe_json', self.safe)
64 self.enforce_unicode = kwargs.pop('enforce_unicode', self.enforce_unicode)
59 self.enforce_str = kwargs.pop('enforce_str', self.enforce_str)
65 self.dialect_map = kwargs.pop('dialect_map', {})
60 self.dialect_map = kwargs.pop('dialect_map', {})
66 super(JSONEncodedObj, self).__init__(*args, **kwargs)
61 super(JSONEncodedObj, self).__init__(*args, **kwargs)
67
62
68 def load_dialect_impl(self, dialect):
63 def load_dialect_impl(self, dialect):
69 if dialect.name in self.dialect_map:
64 if dialect.name in self.dialect_map:
70 return dialect.type_descriptor(self.dialect_map[dialect.name])
65 return dialect.type_descriptor(self.dialect_map[dialect.name])
71 return dialect.type_descriptor(self.impl)
66 return dialect.type_descriptor(self.impl)
72
67
73 def process_bind_param(self, value, dialect):
68 def process_bind_param(self, value, dialect):
74 if isinstance(value, JsonRaw):
69 if isinstance(value, JsonRaw):
75 value = value
70 value = value
76 elif value is not None:
71 elif value is not None:
77 value = json.dumps(value)
72 if self.enforce_str:
78 if self.enforce_unicode:
73 value = ext_json.str_json(value)
79 value = safe_unicode(value)
74 else:
75 value = ext_json.json.dumps(value)
80 return value
76 return value
81
77
82 def process_result_value(self, value, dialect):
78 def process_result_value(self, value, dialect):
83 if self.default is not None and (not value or value == '""'):
79 if self.default is not None and (not value or value == '""'):
84 return self.default()
80 return self.default()
85
81
86 if value is not None:
82 if value is not None:
87 try:
83 try:
88 value = json.loads(value, object_pairs_hook=DictClass)
84 value = ext_json.json.loads(value)
89 except Exception as e:
85 except Exception:
90 if self.safe and self.default is not None:
86 if self.safe and self.default is not None:
91 return self.default()
87 return self.default()
92 else:
88 else:
93 raise
89 raise
94 return value
90 return value
95
91
96
92
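A hedged sketch of how this type decorator is typically attached to a model; the model and column names are invented for illustration::

    from sqlalchemy import Column, Integer
    from sqlalchemy.ext.declarative import declarative_base

    Base = declarative_base()

    class Example(Base):
        __tablename__ = 'example'
        example_id = Column(Integer, primary_key=True)
        # Stored as TEXT; values are JSON-serialized in process_bind_param
        # and decoded in process_result_value. With safe=True and a
        # default of dict, undecodable rows come back as {} instead of
        # raising.
        meta = Column(JSONEncodedObj(default=dict, safe_json=True))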
97 class MutationObj(Mutable):
93 class MutationObj(Mutable):
94
98 @classmethod
95 @classmethod
99 def coerce(cls, key, value):
96 def coerce(cls, key, value):
100 if isinstance(value, dict) and not isinstance(value, MutationDict):
97 if isinstance(value, dict) and not isinstance(value, MutationDict):
101 return MutationDict.coerce(key, value)
98 return MutationDict.coerce(key, value)
102 if isinstance(value, list) and not isinstance(value, MutationList):
99 if isinstance(value, list) and not isinstance(value, MutationList):
103 return MutationList.coerce(key, value)
100 return MutationList.coerce(key, value)
104 return value
101 return value
105
102
106 def de_coerce(self):
103 def de_coerce(self):
107 return self
104 return self
108
105
109 @classmethod
106 @classmethod
110 def _listen_on_attribute(cls, attribute, coerce, parent_cls):
107 def _listen_on_attribute(cls, attribute, coerce, parent_cls):
111 key = attribute.key
108 key = attribute.key
112 if parent_cls is not attribute.class_:
109 if parent_cls is not attribute.class_:
113 return
110 return
114
111
115 # rely on "propagate" here
112 # rely on "propagate" here
116 parent_cls = attribute.class_
113 parent_cls = attribute.class_
117
114
118 def load(state, *args):
115 def load(state, *args):
119 val = state.dict.get(key, None)
116 val = state.dict.get(key, None)
120 if coerce:
117 if coerce:
121 val = cls.coerce(key, val)
118 val = cls.coerce(key, val)
122 state.dict[key] = val
119 state.dict[key] = val
123 if isinstance(val, cls):
120 if isinstance(val, cls):
124 val._parents[state.obj()] = key
121 val._parents[state.obj()] = key
125
122
126 def set(target, value, oldvalue, initiator):
123 def set(target, value, oldvalue, initiator):
127 if not isinstance(value, cls):
124 if not isinstance(value, cls):
128 value = cls.coerce(key, value)
125 value = cls.coerce(key, value)
129 if isinstance(value, cls):
126 if isinstance(value, cls):
130 value._parents[target.obj()] = key
127 value._parents[target.obj()] = key
131 if isinstance(oldvalue, cls):
128 if isinstance(oldvalue, cls):
132 oldvalue._parents.pop(target.obj(), None)
129 oldvalue._parents.pop(target.obj(), None)
133 return value
130 return value
134
131
135 def pickle(state, state_dict):
132 def pickle(state, state_dict):
136 val = state.dict.get(key, None)
133 val = state.dict.get(key, None)
137 if isinstance(val, cls):
134 if isinstance(val, cls):
138 if 'ext.mutable.values' not in state_dict:
135 if 'ext.mutable.values' not in state_dict:
139 state_dict['ext.mutable.values'] = []
136 state_dict['ext.mutable.values'] = []
140 state_dict['ext.mutable.values'].append(val)
137 state_dict['ext.mutable.values'].append(val)
141
138
142 def unpickle(state, state_dict):
139 def unpickle(state, state_dict):
143 if 'ext.mutable.values' in state_dict:
140 if 'ext.mutable.values' in state_dict:
144 for val in state_dict['ext.mutable.values']:
141 for val in state_dict['ext.mutable.values']:
145 val._parents[state.obj()] = key
142 val._parents[state.obj()] = key
146
143
147 sqlalchemy.event.listen(parent_cls, 'load', load, raw=True,
144 sqlalchemy.event.listen(parent_cls, 'load', load, raw=True,
148 propagate=True)
145 propagate=True)
149 sqlalchemy.event.listen(parent_cls, 'refresh', load, raw=True,
146 sqlalchemy.event.listen(parent_cls, 'refresh', load, raw=True,
150 propagate=True)
147 propagate=True)
151 sqlalchemy.event.listen(parent_cls, 'pickle', pickle, raw=True,
148 sqlalchemy.event.listen(parent_cls, 'pickle', pickle, raw=True,
152 propagate=True)
149 propagate=True)
153 sqlalchemy.event.listen(attribute, 'set', set, raw=True, retval=True,
150 sqlalchemy.event.listen(attribute, 'set', set, raw=True, retval=True,
154 propagate=True)
151 propagate=True)
155 sqlalchemy.event.listen(parent_cls, 'unpickle', unpickle, raw=True,
152 sqlalchemy.event.listen(parent_cls, 'unpickle', unpickle, raw=True,
156 propagate=True)
153 propagate=True)
157
154
158
155
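Roughly, the coerce() dispatch above behaves like this; the asserted behaviour is an inference from the code, not a documented contract::

    val = MutationObj.coerce('settings', {'a': 1})
    assert isinstance(val, MutationDict)       # plain dict gets wrapped
    val = MutationObj.coerce('settings', [1, 2])
    assert isinstance(val, MutationList)       # plain list gets wrapped
    assert MutationObj.coerce('settings', 42) == 42  # scalars pass through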
159 class MutationDict(MutationObj, DictClass):
160 @classmethod
161 def coerce(cls, key, value):
162 """Convert plain dictionary to MutationDict"""
163 self = MutationDict(
164 (k, MutationObj.coerce(key, v)) for (k, v) in value.items())
165 self._key = key
166 return self
167
168 def de_coerce(self):
169 return dict(self)
170
171 def __setitem__(self, key, value):
172 # Due to the way OrderedDict works, this is called during __init__.
173 # At that point we don't have a key set yet and, moreover, the value
174 # being set has already been coerced, so special-case this and skip.
175 if hasattr(self, '_key'):
176 value = MutationObj.coerce(self._key, value)
177 DictClass.__setitem__(self, key, value)
178 self.changed()
179
180 def __delitem__(self, key):
181 DictClass.__delitem__(self, key)
182 self.changed()
183
184 def __setstate__(self, state):
185 self.__dict__ = state
186
187 def __reduce_ex__(self, proto):
188 # support pickling of MutationDicts
189 d = dict(self)
190 return (self.__class__, (d,))
191
192
193 class MutationList(MutationObj, list):
194 @classmethod
195 def coerce(cls, key, value):
196 """Convert plain list to MutationList"""
197 self = MutationList((MutationObj.coerce(key, v) for v in value))
198 self._key = key
199 return self
200
201 def de_coerce(self):
202 return list(self)
203
204 def __setitem__(self, idx, value):
205 list.__setitem__(self, idx, MutationObj.coerce(self._key, value))
206 self.changed()
207
208 def __setslice__(self, start, stop, values):
209 list.__setslice__(self, start, stop,
210 (MutationObj.coerce(self._key, v) for v in values))
211 self.changed()
212
213 def __delitem__(self, idx):
214 list.__delitem__(self, idx)
215 self.changed()
216
217 def __delslice__(self, start, stop):
218 list.__delslice__(self, start, stop)
219 self.changed()
220
221 def append(self, value):
222 list.append(self, MutationObj.coerce(self._key, value))
223 self.changed()
224
225 def insert(self, idx, value):
226 list.insert(self, idx, MutationObj.coerce(self._key, value))
227 self.changed()
228
229 def extend(self, values):
230 list.extend(self, (MutationObj.coerce(self._key, v) for v in values))
231 self.changed()
232
233 def pop(self, *args, **kw):
234 value = list.pop(self, *args, **kw)
235 self.changed()
236 return value
237
238 def remove(self, value):
239 list.remove(self, value)
240 self.changed()
241
242
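The point of all the overridden mutators above: a plain dict or list column edited in place never notifies the session, whereas these call self.changed() and flag the parent attribute dirty. A hedged sketch, assuming a model with a 'settings' column declared as Column(JSONDict) (see the aliases below) and an open session::

    inst = session.query(Example).get(1)   # 'Example' as sketched earlier
    inst.settings['flag'] = True   # MutationDict.__setitem__ -> changed()
    session.commit()               # the in-place edit is detected and flushed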
243 def JsonType(impl=None, **kwargs):
156 def JsonType(impl=None, **kwargs):
244 """
157 """
245 Helper for using a mutation obj; it allows using .with_variant easily.
158 Helper for using a mutation obj; it allows using .with_variant easily.
246 example::
159 example::
247
160
248 settings = Column('settings_json',
161 settings = Column('settings_json',
249 MutationObj.as_mutable(
162 MutationObj.as_mutable(
250 JsonType(dialect_map=dict(mysql=UnicodeText(16384))))
163 JsonType(dialect_map=dict(mysql=UnicodeText(16384))))
251 """
164 """
252
165
253 if impl == 'list':
166 if impl == 'list':
254 return JSONEncodedObj(default=list, **kwargs)
167 return JSONEncodedObj(default=list, **kwargs)
255 elif impl == 'dict':
168 elif impl == 'dict':
256 return JSONEncodedObj(default=DictClass, **kwargs)
169 return JSONEncodedObj(default=dict, **kwargs)
257 else:
170 else:
258 return JSONEncodedObj(**kwargs)
171 return JSONEncodedObj(**kwargs)
259
172
260
173
261 JSON = MutationObj.as_mutable(JsonType())
262 """
263 A type to encode/decode JSON on the fly
264
265 sqltype is the string type for the underlying DB column::
266
267 Column(JSON) (defaults to UnicodeText)
268 """
269
270 JSONDict = MutationObj.as_mutable(JsonType('dict'))
271 """
272 A type to encode/decode JSON dictionaries on the fly
273 """
274
275 JSONList = MutationObj.as_mutable(JsonType('list'))
276 """
277 A type to encode/decode JSON lists on the fly
278 """
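Putting the three aliases together, a hedged declaration sketch reusing the illustrative Base/Column/Integer from the earlier JSONEncodedObj example; model and column names are invented::

    class Widget(Base):
        __tablename__ = 'widget'
        widget_id = Column(Integer, primary_key=True)
        data = Column(JSON)          # any JSON value; NULL stays None
        settings = Column(JSONDict)  # NULL comes back as {}, mutation-tracked
        tags = Column(JSONList)      # NULL comes back as [], mutation-tracked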