##// END OF EJS Templates
caches: don't use deprecated md5 for key calculation....
marcink -
r2834:b65d885f default
parent child Browse files
Show More
@@ -1,295 +1,295 b''
1 1 # -*- coding: utf-8 -*-
2 2
3 3 # Copyright (C) 2015-2018 RhodeCode GmbH
4 4 #
5 5 # This program is free software: you can redistribute it and/or modify
6 6 # it under the terms of the GNU Affero General Public License, version 3
7 7 # (only), as published by the Free Software Foundation.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU Affero General Public License
15 15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 16 #
17 17 # This program is dual-licensed. If you wish to learn more about the
18 18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20 20 import functools
21 21
22 22 import beaker
23 23 import logging
24 24 import threading
25 25
26 26 from beaker.cache import _cache_decorate, cache_regions, region_invalidate
27 27 from sqlalchemy.exc import IntegrityError
28 28
29 from rhodecode.lib.utils import safe_str, md5
29 from rhodecode.lib.utils import safe_str, sha1
30 30 from rhodecode.model.db import Session, CacheKey
31 31
32 32 log = logging.getLogger(__name__)
33 33
# Well-known per-repo cache namespaces used across the views.
FILE_TREE = 'cache_file_tree'
FILE_TREE_META = 'cache_file_tree_metadata'
FILE_SEARCH_TREE_META = 'cache_file_search_metadata'
SUMMARY_STATS = 'cache_summary_stats'

# This list of caches gets purged when invalidation happens
USED_REPO_CACHES = (FILE_TREE, FILE_SEARCH_TREE_META)

# Fallback beaker cache-manager settings used when a region is not
# configured in the .ini file (see get_cache_manager).
DEFAULT_CACHE_MANAGER_CONFIG = {
    'type': 'memorylru_base',
    'max_items': 10240,
    'key_length': 256,
    'enabled': True
}
48 48
49 49
def get_default_cache_settings(settings):
    """
    Pull every cache-related option out of the application settings.

    Keys starting with ``beaker.cache.`` or ``cache.`` are returned with
    the prefix removed and both key and value whitespace-stripped.

    :param settings: dict of raw ini-style settings
    :return: dict of prefix-less cache options
    """
    extracted = {}
    for full_key, raw_value in settings.items():
        for known_prefix in ('beaker.cache.', 'cache.'):
            if full_key.startswith(known_prefix):
                short_name = full_key.split(known_prefix)[1].strip()
                extracted[short_name] = raw_value.strip()
    return extracted
58 58
59 59
# set cache regions for beaker so celery can utilise it
def configure_caches(settings, default_region_settings=None):
    """
    Configure every beaker cache region listed in ``settings``.

    The comma-separated ``regions`` option names the regions; any setting
    prefixed with a region name overrides the defaults for that region.

    :param settings: dict of application settings (ini values)
    :param default_region_settings: base settings applied to each region;
        defaults to the type from DEFAULT_CACHE_MANAGER_CONFIG
    """
    cache_settings = {'regions': None}
    # main cache settings used as default ...
    cache_settings.update(get_default_cache_settings(settings))
    default_region_settings = default_region_settings or \
        {'type': DEFAULT_CACHE_MANAGER_CONFIG['type']}
    if cache_settings['regions']:
        for region in cache_settings['regions'].split(','):
            region = region.strip()
            # each region starts from a copy of the defaults, then picks up
            # any settings whose key is prefixed with the region name
            region_settings = default_region_settings.copy()
            for key, value in cache_settings.items():
                if key.startswith(region):
                    region_settings[key.split(region + '.')[-1]] = value
            log.debug('Configuring cache region `%s` with settings %s',
                      region, region_settings)
            configure_cache_region(
                region, region_settings, cache_settings)
79 79
def configure_cache_region(
        region_name, region_settings, default_cache_kw, default_expire=60):
    """
    Register a single beaker cache region in the module-level registry.

    Missing per-region options fall back to the global cache settings
    (type, lock_dir, data_dir) and to ``default_expire`` for the TTL.

    :param region_name: name under which the region is registered
    :param region_settings: dict of options for this region; filled in place
    :param default_cache_kw: global cache settings used as fallbacks
    :param default_expire: TTL in seconds used when not configured
    """
    default_type = default_cache_kw.get('type', 'memory')
    default_lock_dir = default_cache_kw.get('lock_dir')
    default_data_dir = default_cache_kw.get('data_dir')

    region_settings['lock_dir'] = region_settings.get('lock_dir', default_lock_dir)
    region_settings['data_dir'] = region_settings.get('data_dir', default_data_dir)
    region_settings['type'] = region_settings.get('type', default_type)
    # expire always coerced to int so beaker arithmetic works
    region_settings['expire'] = int(region_settings.get('expire', default_expire))

    # beaker keeps its regions in a module-level dict
    beaker.cache.cache_regions[region_name] = region_settings
92 92
93 93
def get_cache_manager(region_name, cache_name, custom_ttl=None):
    """
    Creates a Beaker cache manager. Such instance can be used like that::

        _namespace = caches.get_repo_namespace_key(caches.XXX, repo_name)
        cache_manager = caches.get_cache_manager('repo_cache_long', _namespace)
        _cache_key = caches.compute_key_from_params(repo_name, commit.raw_id)
        def heavy_compute():
            ...
        result = cache_manager.get(_cache_key, createfunc=heavy_compute)

    :param region_name: region from ini file
    :param cache_name: custom cache name, usually prefix+repo_name. eg
        file_switcher_repo1
    :param custom_ttl: override .ini file timeout on this cache
    :return: instance of cache manager
    """

    cache_config = cache_regions.get(region_name, DEFAULT_CACHE_MANAGER_CONFIG)
    if custom_ttl:
        log.debug('Updating region %s with custom ttl: %s',
                  region_name, custom_ttl)
        # Copy before overriding: the dict obtained above is shared module
        # state (the registered region config, or DEFAULT_CACHE_MANAGER_CONFIG
        # itself). Updating it in place would permanently change the TTL for
        # every later caller of this region.
        cache_config = dict(cache_config, expire=custom_ttl)

    return beaker.cache.Cache._get_cache(cache_name, cache_config)
119 119
120 120
def clear_cache_manager(cache_manager):
    """
    Drop all values held by a beaker cache manager::

        namespace = 'foobar'
        cache_manager = get_cache_manager('repo_cache_long', namespace)
        clear_cache_manager(cache_manager)

    :param cache_manager: instance returned by :func:`get_cache_manager`
    """

    log.debug('Clearing all values for cache manager %s', cache_manager)
    cache_manager.clear()
130 130
131 131
def clear_repo_caches(repo_name):
    """Purge every per-repo cache namespace listed in USED_REPO_CACHES."""
    # invalidate cache manager for this repo
    for prefix in USED_REPO_CACHES:
        namespace = get_repo_namespace_key(prefix, repo_name)
        cache_manager = get_cache_manager('repo_cache_long', namespace)
        clear_cache_manager(cache_manager)
138 138
139 139
def compute_key_from_params(*args):
    """
    Helper to compute key from given params to be used in cache manager
    """
    # params are stringified and joined; sha1 (replacing the deprecated
    # md5) yields a fixed-length, filesystem-safe cache key
    return sha1("_".join(map(safe_str, args)))
145 145
146 146
def get_repo_namespace_key(prefix, repo_name):
    """Build a per-repo cache namespace: ``<prefix>_<hashed repo name>``."""
    hashed_name = compute_key_from_params(repo_name)
    return '%s_%s' % (prefix, hashed_name)
149 149
150 150
def conditional_cache(region, cache_namespace, condition, func):
    """
    Conditional caching function use like::
        def _c(arg):
            # heavy computation function
            return data

        # depending on the condition the compute is wrapped in cache or not
        compute = conditional_cache('short_term', 'cache_namespace_id',
                                    condition=True, func=func)
        return compute(arg)

    :param region: name of cache region
    :param cache_namespace: cache namespace
    :param condition: condition for cache to be triggered, and
        return data cached
    :param func: wrapped heavy function to compute

    :return: ``func`` unchanged when condition is falsy, otherwise a
        caching wrapper around it
    """
    wrapped = func
    if condition:
        # BUGFIX: arguments were previously passed as (region, func) while
        # the format string expects the function first, then the region
        log.debug('conditional_cache: True, wrapping call of '
                  'func: %s into %s region cache', func, region)

        def _cache_wrap(region_name, cache_namespace):
            """Return a caching wrapper"""

            def decorate(func):
                @functools.wraps(func)
                def cached(*args, **kwargs):
                    # kwargs cannot participate in key computation, so they
                    # are rejected outright
                    if kwargs:
                        raise AttributeError(
                            'Usage of kwargs is not allowed. '
                            'Use only positional arguments in wrapped function')
                    manager = get_cache_manager(region_name, cache_namespace)
                    cache_key = compute_key_from_params(*args)

                    def go():
                        return func(*args, **kwargs)

                    # save org function name
                    go.__name__ = '_cached_%s' % (func.__name__,)

                    return manager.get(cache_key, createfunc=go)
                return cached

            return decorate

        cached_region = _cache_wrap(region, cache_namespace)
        wrapped = cached_region(func)

    return wrapped
203 203
204 204
class ActiveRegionCache(object):
    """
    Returned by :class:`InvalidationContext` when the stored cache key is
    still marked active: values may be served straight from the cache.
    """
    def __init__(self, context):
        # the owning InvalidationContext (carries cache_key / compute_func)
        self.context = context

    def invalidate(self, *args, **kwargs):
        # cache is still valid, nothing to invalidate
        return False

    def compute(self):
        log.debug('Context cache: getting obj %s from cache', self.context)
        return self.context.compute_func(self.context.cache_key)
215 215
216 216
class FreshRegionCache(ActiveRegionCache):
    """
    Returned when the cache key is missing or inactive: the beaker region
    entry is invalidated so the value gets recomputed on next access.
    """
    def invalidate(self):
        log.debug('Context cache: invalidating cache for %s', self.context)
        # drop the value beaker stored under (compute_func, cache_key)
        region_invalidate(
            self.context.compute_func, None, self.context.cache_key)
        return True
223 223
224 224
class InvalidationContext(object):
    """
    Context manager tying a beaker-cached computation to a ``CacheKey``
    database row so invalidation works across processes.

    ``__enter__`` returns :class:`ActiveRegionCache` when the stored key is
    still active (serve from cache), else :class:`FreshRegionCache`
    (invalidate and recompute); ``__exit__`` re-activates the key in the DB
    after a fresh compute.
    """
    def __repr__(self):
        return '<InvalidationContext:{}[{}]>'.format(
            safe_str(self.repo_name), safe_str(self.cache_type))

    def __init__(self, compute_func, repo_name, cache_type,
                 raise_exception=False, thread_scoped=False):
        # compute_func is also what beaker's region_invalidate uses to
        # locate the cached value (see FreshRegionCache.invalidate)
        self.compute_func = compute_func
        self.repo_name = repo_name
        self.cache_type = cache_type
        self.cache_key = compute_key_from_params(
            repo_name, cache_type)
        # whether __exit__ re-raises commit failures
        self.raise_exception = raise_exception

        # Append the thread id to the cache key if this invalidation context
        # should be scoped to the current thread.
        if thread_scoped:
            thread_id = threading.current_thread().ident
            self.cache_key = '{cache_key}_{thread_id}'.format(
                cache_key=self.cache_key, thread_id=thread_id)

    def get_cache_obj(self):
        # fetch the DB-backed cache key; create a transient (not yet
        # persisted) one when it does not exist
        cache_key = CacheKey.get_cache_key(
            self.repo_name, self.cache_type)
        cache_obj = CacheKey.get_active_cache(cache_key)
        if not cache_obj:
            cache_obj = CacheKey(cache_key, self.repo_name)
        return cache_obj

    def __enter__(self):
        """
        Test if current object is valid, and return CacheRegion function
        that does invalidation and calculation
        """

        self.cache_obj = self.get_cache_obj()
        if self.cache_obj.cache_active:
            # means our cache obj is existing and marked as it's
            # cache is not outdated, we return BaseInvalidator
            self.skip_cache_active_change = True
            return ActiveRegionCache(self)

        # the key is either not existing or set to False, we return
        # the real invalidator which re-computes value. We additionally set
        # the flag to actually update the Database objects
        self.skip_cache_active_change = False
        return FreshRegionCache(self)

    def __exit__(self, exc_type, exc_val, exc_tb):
        # only persist the re-activated key when a fresh compute happened
        if self.skip_cache_active_change:
            return

        try:
            self.cache_obj.cache_active = True
            Session().add(self.cache_obj)
            Session().commit()
        except IntegrityError:
            # if we catch integrity error, it means we inserted this object
            # assumption is that's really an edge race-condition case and
            # it's safe is to skip it
            Session().rollback()
        except Exception:
            log.exception('Failed to commit on cache key update')
            Session().rollback()
            if self.raise_exception:
                raise
292 292
293 293
def includeme(config):
    """Pyramid inclusion hook: set up beaker cache regions from settings."""
    configure_caches(config.registry.settings)
@@ -1,779 +1,779 b''
1 1 # -*- coding: utf-8 -*-
2 2
3 3 # Copyright (C) 2010-2018 RhodeCode GmbH
4 4 #
5 5 # This program is free software: you can redistribute it and/or modify
6 6 # it under the terms of the GNU Affero General Public License, version 3
7 7 # (only), as published by the Free Software Foundation.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU Affero General Public License
15 15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 16 #
17 17 # This program is dual-licensed. If you wish to learn more about the
18 18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20 20
21 21 """
22 22 Utilities library for RhodeCode
23 23 """
24 24
25 25 import datetime
26 26 import decorator
27 27 import json
28 28 import logging
29 29 import os
30 30 import re
31 31 import shutil
32 32 import tempfile
33 33 import traceback
34 34 import tarfile
35 35 import warnings
36 36 import hashlib
37 37 from os.path import join as jn
38 38
39 39 import paste
40 40 import pkg_resources
41 41 from webhelpers.text import collapse, remove_formatting, strip_tags
42 42 from mako import exceptions
43 43 from pyramid.threadlocal import get_current_registry
44 44 from rhodecode.lib.request import Request
45 45
46 46 from rhodecode.lib.fakemod import create_module
47 47 from rhodecode.lib.vcs.backends.base import Config
48 48 from rhodecode.lib.vcs.exceptions import VCSError
49 49 from rhodecode.lib.vcs.utils.helpers import get_scm, get_scm_backend
50 50 from rhodecode.lib.utils2 import (
51 safe_str, safe_unicode, get_current_rhodecode_user, md5)
51 safe_str, safe_unicode, get_current_rhodecode_user, md5, sha1)
52 52 from rhodecode.model import meta
53 53 from rhodecode.model.db import (
54 54 Repository, User, RhodeCodeUi, UserLog, RepoGroup, UserGroup)
55 55 from rhodecode.model.meta import Session
56 56
57 57
58 58 log = logging.getLogger(__name__)
59 59
# Pattern matching directories of removed repos, which are renamed to
# rm__YYYYMMDD_HHMMSS_ffffff__<name> instead of being deleted outright.
REMOVED_REPO_PAT = re.compile(r'rm__\d{8}_\d{6}_\d{6}__.*')

# String which contains characters that are not allowed in slug names for
# repositories or repository groups. It is properly escaped to use it in
# regular expressions.
SLUG_BAD_CHARS = re.escape('`?=[]\;\'"<>,/~!@#$%^&*()+{}|:')

# Regex that matches forbidden characters in repo/group slugs.
SLUG_BAD_CHAR_RE = re.compile('[{}]'.format(SLUG_BAD_CHARS))

# Regex that matches allowed characters in repo/group slugs.
SLUG_GOOD_CHAR_RE = re.compile('[^{}]'.format(SLUG_BAD_CHARS))

# Regex that matches whole repo/group slugs.
SLUG_RE = re.compile('[^{}]+'.format(SLUG_BAD_CHARS))

# module-level license cache; not read or written in this part of the file
_license_cache = None
77 77
78 78
def repo_name_slug(value):
    """
    Return slug of name of repository
    This function is called on each creation/modification
    of repository to prevent bad names in repo

    :param value: proposed repository name
    :return: sanitized slug string
    """
    replacement_char = '-'

    # strip markup, drop forbidden characters, then collapse whitespace
    # runs and repeated dashes into single replacement chars
    slug = remove_formatting(value)
    slug = SLUG_BAD_CHAR_RE.sub('', slug)
    slug = re.sub('[\s]+', '-', slug)
    slug = collapse(slug, replacement_char)
    return slug
92 92
93 93
94 94 #==============================================================================
95 95 # PERM DECORATOR HELPERS FOR EXTRACTING NAMES FOR PERM CHECKS
96 96 #==============================================================================
def get_repo_slug(request):
    """
    Pull the repository name off a request object.

    Prefers the bound database repository (which resolves the
    ``example.com/_<id>`` form into the real name) and falls back to the
    pyramid ``matchdict``; a trailing slash is stripped from the result.
    """
    repo_name = ''

    if hasattr(request, 'db_repo'):
        # a db reference is authoritative; it translates _<id> urls
        # into proper repository names
        repo_name = request.db_repo.repo_name
    elif getattr(request, 'matchdict', None):
        # pyramid routing
        repo_name = request.matchdict.get('repo_name')

    return repo_name.rstrip('/') if repo_name else repo_name
111 111
112 112
def get_repo_group_slug(request):
    """
    Pull the repository group name off a request object.

    Prefers the bound database group (which resolves the
    ``example.com/_<id>`` form into the real name) and falls back to the
    pyramid ``matchdict``; a trailing slash is stripped from the result.
    """
    group_name = ''

    if hasattr(request, 'db_repo_group'):
        # a db reference is authoritative; it translates _<id> urls
        # into proper repo group names
        group_name = request.db_repo_group.group_name
    elif getattr(request, 'matchdict', None):
        # pyramid routing
        group_name = request.matchdict.get('repo_group_name')

    return group_name.rstrip('/') if group_name else group_name
127 127
128 128
def get_user_group_slug(request):
    """
    Extract the user group name from a request.

    Resolves either a bound ``db_user_group`` or pyramid matchdict values
    (id or name) to the group's name; returns None on lookup failure.
    """
    _user_group = ''

    if hasattr(request, 'db_user_group'):
        _user_group = request.db_user_group.users_group_name
    elif getattr(request, 'matchdict', None):
        # pyramid
        _user_group = request.matchdict.get('user_group_id')
        _user_group_name = request.matchdict.get('user_group_name')
        try:
            # resolve by id first, then by name
            if _user_group:
                _user_group = UserGroup.get(_user_group)
            elif _user_group_name:
                _user_group = UserGroup.get_by_group_name(_user_group_name)

            if _user_group:
                _user_group = _user_group.users_group_name
        except Exception:
            log.exception('Failed to get user group by id and name')
            # catch all failures here
            return None

    return _user_group
152 152
153 153
def get_filesystem_repos(path, recursive=False, skip_removed_repos=True):
    """
    Scans given path for repos and return (name,(type,path)) tuple

    :param path: path to scan for repositories
    :param recursive: recursive search and return names with subdirs in front
    :param skip_removed_repos: skip directories matching the removed-repo
        naming pattern (``rm__...``)
    :return: generator of (relative_name, scm_info) tuples
    """

    # remove ending slash for better results
    path = path.rstrip(os.sep)
    log.debug('now scanning in %s location recursive:%s...', path, recursive)

    def _get_repos(p):
        # recursive generator scanning one directory level
        dirpaths = _get_dirpaths(p)
        if not _is_dir_writable(p):
            log.warning('repo path without write access: %s', p)

        for dirpath in dirpaths:
            if os.path.isfile(os.path.join(p, dirpath)):
                continue
            cur_path = os.path.join(p, dirpath)

            # skip removed repos
            if skip_removed_repos and REMOVED_REPO_PAT.match(dirpath):
                continue

            # skip .<something> dirs
            if dirpath.startswith('.'):
                continue

            try:
                scm_info = get_scm(cur_path)
                # name is the path relative to the scan root
                yield scm_info[1].split(path, 1)[-1].lstrip(os.sep), scm_info
            except VCSError:
                if not recursive:
                    continue
                # check if this dir contains other repos for recursive scan
                rec_path = os.path.join(p, dirpath)
                if os.path.isdir(rec_path):
                    for inner_scm in _get_repos(rec_path):
                        yield inner_scm

    return _get_repos(path)
197 197
198 198
199 199 def _get_dirpaths(p):
200 200 try:
201 201 # OS-independable way of checking if we have at least read-only
202 202 # access or not.
203 203 dirpaths = os.listdir(p)
204 204 except OSError:
205 205 log.warning('ignoring repo path without read access: %s', p)
206 206 return []
207 207
208 208 # os.listpath has a tweak: If a unicode is passed into it, then it tries to
209 209 # decode paths and suddenly returns unicode objects itself. The items it
210 210 # cannot decode are returned as strings and cause issues.
211 211 #
212 212 # Those paths are ignored here until a solid solution for path handling has
213 213 # been built.
214 214 expected_type = type(p)
215 215
216 216 def _has_correct_type(item):
217 217 if type(item) is not expected_type:
218 218 log.error(
219 219 u"Ignoring path %s since it cannot be decoded into unicode.",
220 220 # Using "repr" to make sure that we see the byte value in case
221 221 # of support.
222 222 repr(item))
223 223 return False
224 224 return True
225 225
226 226 dirpaths = [item for item in dirpaths if _has_correct_type(item)]
227 227
228 228 return dirpaths
229 229
230 230
231 231 def _is_dir_writable(path):
232 232 """
233 233 Probe if `path` is writable.
234 234
235 235 Due to trouble on Cygwin / Windows, this is actually probing if it is
236 236 possible to create a file inside of `path`, stat does not produce reliable
237 237 results in this case.
238 238 """
239 239 try:
240 240 with tempfile.TemporaryFile(dir=path):
241 241 pass
242 242 except OSError:
243 243 return False
244 244 return True
245 245
246 246
def is_valid_repo(repo_name, base_path, expect_scm=None, explicit_scm=None, config=None):
    """
    Returns True if given path is a valid repository False otherwise.
    If expect_scm param is given also, compare if given scm is the same
    as expected from scm parameter. If explicit_scm is given don't try to
    detect the scm, just use the given one to check if repo is valid

    :param repo_name: repository name (relative path under base_path)
    :param base_path: base directory the repo lives under
    :param expect_scm: optional scm alias that must match the detected one
    :param explicit_scm: optional scm alias to use instead of detection
    :param config: optional vcs config passed to the explicit backend

    :return True: if given path is a valid repository
    """
    full_path = os.path.join(safe_str(base_path), safe_str(repo_name))
    log.debug('Checking if `%s` is a valid path for repository. '
              'Explicit type: %s', repo_name, explicit_scm)

    try:
        if explicit_scm:
            # trust the caller's scm type; just instantiate the backend
            detected_scms = [get_scm_backend(explicit_scm)(
                full_path, config=config).alias]
        else:
            detected_scms = get_scm(full_path)

        if expect_scm:
            return detected_scms[0] == expect_scm
        log.debug('path: %s is an vcs object:%s', full_path, detected_scms)
        return True
    except VCSError:
        log.debug('path: %s is not a valid repo !', full_path)
        return False
280 280
281 281
def is_valid_repo_group(repo_group_name, base_path, skip_path_check=False):
    """
    Returns True if given path is a repository group, False otherwise

    :param repo_group_name: group name (relative path under base_path)
    :param base_path: base directory to resolve against
    :param skip_path_check: treat the path as existing without stat-ing it
    """
    full_path = os.path.join(safe_str(base_path), safe_str(repo_group_name))
    log.debug('Checking if `%s` is a valid path for repository group',
              repo_group_name)

    # check if it's not a repo
    if is_valid_repo(repo_group_name, base_path):
        log.debug('Repo called %s exist, it is not a valid '
                  'repo group' % repo_group_name)
        return False

    try:
        # we need to check bare git repos at higher level
        # since we might match branches/hooks/info/objects or possible
        # other things inside bare git repo
        scm_ = get_scm(os.path.dirname(full_path))
        log.debug('path: %s is a vcs object:%s, not valid '
                  'repo group' % (full_path, scm_))
        return False
    except VCSError:
        pass

    # check if it's a valid path
    if skip_path_check or os.path.isdir(full_path):
        log.debug('path: %s is a valid repo group !', full_path)
        return True

    log.debug('path: %s is not a valid repo group !', full_path)
    return False
317 317
318 318
def ask_ok(prompt, retries=4, complaint='[y]es or [n]o please!'):
    """
    Interactively ask a yes/no question on stdin.

    NOTE(review): uses ``raw_input`` and is therefore Python 2 only.

    :param prompt: text shown to the user
    :param retries: number of attempts before giving up
    :param complaint: message printed after an unrecognized answer
    :raises IOError: when all retries are exhausted
    """
    while True:
        ok = raw_input(prompt)
        if ok.lower() in ('y', 'ye', 'yes'):
            return True
        if ok.lower() in ('n', 'no', 'nop', 'nope'):
            return False
        retries = retries - 1
        if retries < 0:
            raise IOError
        print(complaint)
330 330
# Known hgrc configuration section names, propagated from mercurial
# documentation.
ui_sections = [
    'alias', 'auth',
    'decode/encode', 'defaults',
    'diff', 'email',
    'extensions', 'format',
    'merge-patterns', 'merge-tools',
    'hooks', 'http_proxy',
    'smtp', 'patch',
    'paths', 'profiling',
    'server', 'trusted',
    'ui', 'web', ]
343 343
344 344
def config_data_from_db(clear_session=True, repo=None):
    """
    Read the configuration data from the database and return configuration
    tuples.

    :param clear_session: remove the sqlalchemy session after reading
    :param repo: optional repository to scope the ui settings to
    :return: list of (section, key, value) tuples
    """
    from rhodecode.model.settings import VcsSettingsModel

    config = []

    sa = meta.Session()
    settings_model = VcsSettingsModel(repo=repo, sa=sa)

    ui_settings = settings_model.get_ui_settings()

    ui_data = []
    for setting in ui_settings:
        if setting.active:
            ui_data.append((setting.section, setting.key, setting.value))
            config.append((
                safe_str(setting.section), safe_str(setting.key),
                safe_str(setting.value)))
            if setting.key == 'push_ssl':
                # force set push_ssl requirement to False, rhodecode
                # handles that
                config.append((
                    safe_str(setting.section), safe_str(setting.key), False))
    log.debug(
        'settings ui from db: %s',
        ','.join(map(lambda s: '[{}] {}={}'.format(*s), ui_data)))
    if clear_session:
        meta.Session.remove()

    # TODO: mikhail: probably it makes no sense to re-read hooks information.
    # It's already there and activated/deactivated
    skip_entries = []
    enabled_hook_classes = get_enabled_hook_classes(ui_settings)
    if 'pull' not in enabled_hook_classes:
        skip_entries.append(('hooks', RhodeCodeUi.HOOK_PRE_PULL))
    if 'push' not in enabled_hook_classes:
        skip_entries.append(('hooks', RhodeCodeUi.HOOK_PRE_PUSH))
        skip_entries.append(('hooks', RhodeCodeUi.HOOK_PRETX_PUSH))
        skip_entries.append(('hooks', RhodeCodeUi.HOOK_PUSH_KEY))

    # drop entries whose (section, key) belong to disabled hooks
    config = [entry for entry in config if entry[:2] not in skip_entries]

    return config
391 391
392 392
def make_db_config(clear_session=True, repo=None):
    """
    Create a :class:`Config` instance based on the values in the database.

    :param clear_session: remove the sqlalchemy session after reading
    :param repo: optional repository to scope the settings to
    """
    config = Config()
    config_data = config_data_from_db(clear_session=clear_session, repo=repo)
    for section, option, value in config_data:
        config.set(section, option, value)
    return config
402 402
403 403
def get_enabled_hook_classes(ui_settings):
    """
    Return the enabled hook classes.

    :param ui_settings: List of ui_settings as returned
        by :meth:`VcsSettingsModel.get_ui_settings`

    :return: a list with the enabled hook classes. The order is not guaranteed.
    :rtype: list
    """
    # map the ui keys of interest onto their hook class names
    hook_names = {
        RhodeCodeUi.HOOK_PUSH: 'push',
        RhodeCodeUi.HOOK_PULL: 'pull',
        RhodeCodeUi.HOOK_REPO_SIZE: 'repo_size'
    }

    # keep active entries from the 'hooks' section that we know about
    return [
        hook_names[key]
        for section, key, value, active in ui_settings
        if section == 'hooks' and active and key in hook_names]
431 431
432 432
def set_rhodecode_config(config):
    """
    Updates pyramid config with new settings from database

    :param config: dict-like settings object, updated in place
    """
    from rhodecode.model.settings import SettingsModel
    app_settings = SettingsModel().get_all_settings()

    for k, v in app_settings.items():
        config[k] = v
444 444
445 445
def get_rhodecode_realm():
    """
    Return the rhodecode realm from database.
    """
    from rhodecode.model.settings import SettingsModel
    realm = SettingsModel().get_setting_by_name('realm')
    # coerce to a plain str for consumers of the realm value
    return safe_str(realm.app_settings_value)
453 453
454 454
def get_rhodecode_base_path():
    """
    Returns the base path. The base path is the filesystem path which points
    to the repository store.
    """
    from rhodecode.model.settings import SettingsModel
    # the store location is kept as the '/' entry of the 'paths' ui section
    paths_ui = SettingsModel().get_ui_by_section_and_key('paths', '/')
    return safe_str(paths_ui.ui_value)
463 463
464 464
def map_groups(path):
    """
    Given a full path to a repository, create all nested groups that this
    repo is inside. This function creates parent-child relationships between
    groups and creates default perms for all new groups.

    :param path: full path to repository
    :return: the innermost (deepest) RepoGroup, or None when the repo is
        top-level
    """
    from rhodecode.model.repo_group import RepoGroupModel
    sa = meta.Session()
    groups = path.split(Repository.NAME_SEP)
    parent = None
    group = None

    # last element is repo in nested groups structure
    groups = groups[:-1]
    rgm = RepoGroupModel(sa)
    owner = User.get_first_super_admin()
    for lvl, group_name in enumerate(groups):
        # rebuild the full group path up to this nesting level
        group_name = '/'.join(groups[:lvl] + [group_name])
        group = RepoGroup.get_by_group_name(group_name)
        desc = '%s group' % group_name

        # skip folders that are now removed repos
        if REMOVED_REPO_PAT.match(group_name):
            break

        if group is None:
            log.debug('creating group level: %s group_name: %s',
                      lvl, group_name)
            group = RepoGroup(group_name, parent)
            group.group_description = desc
            group.user = owner
            sa.add(group)
            perm_obj = rgm._create_default_perms(group)
            sa.add(perm_obj)
            sa.flush()

        parent = group
    return group
505 505
506 506
def repo2db_mapper(initial_repo_list, remove_obsolete=False):
    """
    maps all repos given in initial_repo_list, non existing repositories
    are created, if remove_obsolete is True it also checks for db entries
    that are not in initial_repo_list and removes them.

    :param initial_repo_list: list of repositories found by scanning methods
    :param remove_obsolete: check for obsolete entries in database
    :return: tuple of (added_names, removed_names)
    """
    from rhodecode.model.repo import RepoModel
    from rhodecode.model.repo_group import RepoGroupModel
    from rhodecode.model.settings import SettingsModel

    sa = meta.Session()
    repo_model = RepoModel()
    user = User.get_first_super_admin()
    added = []

    # creation defaults
    defs = SettingsModel().get_default_repo_settings(strip_prefix=True)
    enable_statistics = defs.get('repo_enable_statistics')
    enable_locking = defs.get('repo_enable_locking')
    enable_downloads = defs.get('repo_enable_downloads')
    private = defs.get('repo_private')

    for name, repo in initial_repo_list.items():
        # make sure all parent groups exist before creating the repo
        group = map_groups(name)
        unicode_name = safe_unicode(name)
        db_repo = repo_model.get_by_repo_name(unicode_name)
        # found repo that is on filesystem not in RhodeCode database
        if not db_repo:
            log.info('repository %s not found, creating now', name)
            added.append(name)
            desc = (repo.description
                    if repo.description != 'unknown'
                    else '%s repository' % name)

            db_repo = repo_model._create_repo(
                repo_name=name,
                repo_type=repo.alias,
                description=desc,
                repo_group=getattr(group, 'group_id', None),
                owner=user,
                enable_locking=enable_locking,
                enable_downloads=enable_downloads,
                enable_statistics=enable_statistics,
                private=private,
                state=Repository.STATE_CREATED
            )
            sa.commit()
            # we added that repo just now, and make sure we updated server info
            if db_repo.repo_type == 'git':
                git_repo = db_repo.scm_instance()
                # update repository server-info
                log.debug('Running update server info')
                git_repo._update_server_info()

            db_repo.update_commit_cache()

        # (re)install hooks for every scanned repo, new or existing
        config = db_repo._config
        config.set('extensions', 'largefiles', '')
        repo = db_repo.scm_instance(config=config)
        repo.install_hooks()

    removed = []
    if remove_obsolete:
        # remove from database those repositories that are not in the filesystem
        for repo in sa.query(Repository).all():
            if repo.repo_name not in initial_repo_list.keys():
                log.debug("Removing non-existing repository found in db `%s`",
                          repo.repo_name)
                try:
                    RepoModel(sa).delete(repo, forks='detach', fs_remove=False)
                    sa.commit()
                    removed.append(repo.repo_name)
                except Exception:
                    # don't hold further removals on error
                    log.error(traceback.format_exc())
                    sa.rollback()

        def splitter(full_repo_name):
            # return the parent group path of a repo name, or None
            _parts = full_repo_name.rsplit(RepoGroup.url_sep(), 1)
            gr_name = None
            if len(_parts) == 2:
                gr_name = _parts[0]
            return gr_name

        initial_repo_group_list = [splitter(x) for x in
                                   initial_repo_list.keys() if splitter(x)]

        # remove from database those repository groups that are not in the
        # filesystem due to parent child relationships we need to delete them
        # in a specific order of most nested first
        all_groups = [x.group_name for x in sa.query(RepoGroup).all()]
        nested_sort = lambda gr: len(gr.split('/'))
        for group_name in sorted(all_groups, key=nested_sort, reverse=True):
            if group_name not in initial_repo_group_list:
                repo_group = RepoGroup.get_by_group_name(group_name)
                if (repo_group.children.all() or
                        not RepoGroupModel().check_exist_filesystem(
                            group_name=group_name, exc_on_failure=False)):
                    continue

                log.info(
                    'Removing non-existing repository group found in db `%s`',
                    group_name)
                try:
                    RepoGroupModel(sa).delete(group_name, fs_remove=False)
                    sa.commit()
                    removed.append(group_name)
                except Exception:
                    # don't hold further removals on error
                    log.exception(
                        'Unable to remove repository group `%s`',
                        group_name)
                    sa.rollback()
                    raise

    return added, removed
626 626
627 627
def load_rcextensions(root_path):
    """
    Load the optional ``rcextensions`` package located under *root_path*
    and register it as ``rhodecode.EXTENSIONS``. Does nothing when the
    package is not present.
    """
    import rhodecode
    from rhodecode.config import conf

    init_path = os.path.join(root_path, 'rcextensions', '__init__.py')
    if not os.path.isfile(init_path):
        return

    ext_module = create_module('rc', init_path)
    rhodecode.EXTENSIONS = ext_module
    log.debug('Found rcextensions now loading %s...', ext_module)

    # extension -> language mappings that are not shipped with pygments
    conf.LANGUAGES_EXTENSIONS_MAP.update(
        getattr(ext_module, 'EXTRA_MAPPINGS', {}))
647 647
648 648
def get_custom_lexer(extension):
    """
    returns a custom lexer if it is defined in rcextensions module, or None
    if there's no custom lexer defined
    """
    import rhodecode
    from pygments import lexers

    # hard-coded override made by RhodeCode
    if extension == 'mako':
        return lexers.get_lexer_by_name('html+mako')

    # was this extension registered as another lexer via rcextensions?
    extra_lexers = (rhodecode.EXTENSIONS
                    and getattr(rhodecode.EXTENSIONS, 'EXTRA_LEXERS', None))
    if extra_lexers and extension in extra_lexers:
        return lexers.get_lexer_by_name(extra_lexers[extension])
666 666
667 667
668 668 #==============================================================================
669 669 # TEST FUNCTIONS AND CREATORS
670 670 #==============================================================================
def create_test_index(repo_location, config):
    """
    Build the default search index for tests by unpacking the pre-built
    index shipped inside the ``rc_testdata`` package.
    """
    import rc_testdata

    index_destination = os.path.dirname(config['search.location'])
    rc_testdata.extract_search_index('vcs_search_index', index_destination)
679 679
680 680
def create_test_directory(test_path):
    """
    Ensure that *test_path* exists as a directory, creating it when missing.
    """
    if os.path.isdir(test_path):
        return
    log.debug('Creating testdir %s', test_path)
    os.makedirs(test_path)
688 688
689 689
def create_test_database(test_path, config):
    """
    Makes a fresh database for tests.

    :param test_path: base directory the generated settings paths are
        rooted at (passed to the config prompt)
    :param config: settings mapping; ``sqlalchemy.db1.url`` selects the
        database, ``here`` is the application root
    """
    from rhodecode.lib.db_manage import DbManage

    # PART ONE create db
    dbconf = config['sqlalchemy.db1.url']
    log.debug('making test db %s', dbconf)

    # force_ask avoids the interactive confirmation when re-creating tables
    dbmanage = DbManage(log_sql=False, dbconf=dbconf, root=config['here'],
                        tests=True, cli_args={'force_ask': True})
    dbmanage.create_tables(override=True)
    dbmanage.set_db_version()
    # for tests dynamically set new root paths based on generated content
    dbmanage.create_settings(dbmanage.config_prompt(test_path))
    dbmanage.create_default_user()
    dbmanage.create_test_admin_and_users()
    dbmanage.create_permissions()
    dbmanage.populate_default_permissions()
    Session().commit()
711 711
712 712
def create_test_repositories(test_path, config):
    """
    Creates test repositories in the temporary directory. Repositories are
    extracted from archives within the rc_testdata package.

    :param test_path: directory the hg/git/svn test repositories are
        extracted into
    :param config: settings mapping; ``search.location`` and ``cache_dir``
        are wiped before extraction
    """
    import rc_testdata
    from rhodecode.tests import HG_REPO, GIT_REPO, SVN_REPO

    log.debug('making test vcs repositories')

    idx_path = config['search.location']
    data_path = config['cache_dir']

    # clean index and data
    if idx_path and os.path.exists(idx_path):
        log.debug('remove %s', idx_path)
        shutil.rmtree(idx_path)

    if data_path and os.path.exists(data_path):
        log.debug('remove %s', data_path)
        shutil.rmtree(data_path)

    rc_testdata.extract_hg_dump('vcs_test_hg', jn(test_path, HG_REPO))
    rc_testdata.extract_git_dump('vcs_test_git', jn(test_path, GIT_REPO))

    # Note: Subversion is in the process of being integrated with the system,
    # until we have a properly packed version of the test svn repository, this
    # tries to copy over the repo from a package "rc_testdata"
    svn_repo_path = rc_testdata.get_svn_repo_archive()
    # NOTE(review): extractall is only safe here because the archive is
    # trusted bundled test data; do not reuse for user-supplied tarballs
    with tarfile.open(svn_repo_path) as tar:
        tar.extractall(jn(test_path, SVN_REPO))
744 744
745 745
def password_changed(auth_user, session):
    """
    Check whether ``auth_user``'s password differs from the hash recorded in
    the web ``session``, i.e. whether it was changed since the session began.

    :param auth_user: user object exposing username, user_id and password
    :param session: session mapping holding the 'rhodecode_user' dict
    :return: True when the password no longer matches the session copy
    """
    # Never report password change in case of default user or anonymous user.
    if auth_user.username == User.DEFAULT_USER or auth_user.user_id is None:
        return False

    # NOTE(review): md5 is used only as a change-detection fingerprint of the
    # (already hashed) password, not for storage; swapping it for a stronger
    # hash would invalidate every live session — confirm with the code that
    # populates session['rhodecode_user'] before changing it.
    password_hash = md5(auth_user.password) if auth_user.password else None
    rhodecode_user = session.get('rhodecode_user', {})
    session_password_hash = rhodecode_user.get('password', '')
    return password_hash != session_password_hash
755 755
756 756
def read_opensource_licenses():
    """
    Return the parsed contents of ``config/licenses.json``, memoizing the
    result in the module-level ``_license_cache`` on first use.
    """
    global _license_cache

    if not _license_cache:
        raw_licenses = pkg_resources.resource_string(
            'rhodecode', 'config/licenses.json')
        _license_cache = json.loads(raw_licenses)

    return _license_cache
766 766
767 767
def generate_platform_uuid():
    """
    Generates platform UUID based on it's name.

    :return: hex sha256 digest of the platform identification string, or
        the literal ``'UNDEFINED'`` when platform detection fails
    """
    import platform

    try:
        uuid_list = [platform.platform()]
        # encode explicitly: hashlib requires bytes on python3, and the
        # platform string is ascii so this is a no-op on python2 str
        return hashlib.sha256(':'.join(uuid_list).encode('utf-8')).hexdigest()
    except Exception as e:
        # lazy %-args: formatting only happens if the record is emitted
        log.error('Failed to generate host uuid: %s', e)
        return 'UNDEFINED'
@@ -1,996 +1,1004 b''
1 1 # -*- coding: utf-8 -*-
2 2
3 3 # Copyright (C) 2011-2018 RhodeCode GmbH
4 4 #
5 5 # This program is free software: you can redistribute it and/or modify
6 6 # it under the terms of the GNU Affero General Public License, version 3
7 7 # (only), as published by the Free Software Foundation.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU Affero General Public License
15 15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 16 #
17 17 # This program is dual-licensed. If you wish to learn more about the
18 18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20 20
21 21
22 22 """
23 23 Some simple helper functions
24 24 """
25 25
26 26 import collections
27 27 import datetime
28 28 import dateutil.relativedelta
29 29 import hashlib
30 30 import logging
31 31 import re
32 32 import sys
33 33 import time
34 34 import urllib
35 35 import urlobject
36 36 import uuid
37 37 import getpass
38 38
39 39 import pygments.lexers
40 40 import sqlalchemy
41 41 import sqlalchemy.engine.url
42 42 import sqlalchemy.exc
43 43 import sqlalchemy.sql
44 44 import webob
45 45 import pyramid.threadlocal
46 46
47 47 import rhodecode
48 48 from rhodecode.translation import _, _pluralize
49 49
50 50
def md5(s):
    """Return the hex md5 digest of *s* (bytes on python3)."""
    digest = hashlib.md5(s)
    return digest.hexdigest()
53 53
54 54
def md5_safe(s):
    """Return the hex md5 digest of *s* after coercing it to a safe str."""
    return md5(safe_str(s))
57 57
58 58
def sha1(s):
    """Return the hex sha1 digest of *s* (bytes on python3)."""
    digest = hashlib.sha1(s)
    return digest.hexdigest()
61
62
def sha1_safe(s):
    """Return the hex sha1 digest of *s* after coercing it to a safe str."""
    return sha1(safe_str(s))
65
66
def __get_lem(extra_mapping=None):
    """
    Get language extension map based on what's inside pygments lexers

    :param extra_mapping: optional {extension: lexer-name} dict merged in
        for extensions pygments does not know about
    :return: dict mapping lowercase file extension -> list of language names
    """
    d = collections.defaultdict(lambda: [])

    def __clean(s):
        # pygments filename patterns look like '*.ext' or '*.ext[suffixes]'
        s = s.lstrip('*')
        s = s.lstrip('.')

        if s.find('[') != -1:
            exts = []
            start, stop = s.find('['), s.find(']')

            # expand the character class, e.g. 'pl[56]' -> ['pl5', 'pl6']
            for suffix in s[start + 1:stop]:
                exts.append(s[:s.find('[')] + suffix)
            return [e.lower() for e in exts]
        else:
            return [s.lower()]

    for lx, t in sorted(pygments.lexers.LEXERS.items()):
        # t[-2] holds the tuple of filename patterns for the lexer
        # NOTE(review): relies on py2 builtins — map() returning a list and
        # builtin reduce(); needs list(map(...))/functools.reduce on python3
        m = map(__clean, t[-2])
        if m:
            m = reduce(lambda x, y: x + y, m)
            for ext in m:
                desc = lx.replace('Lexer', '')
                d[ext].append(desc)

    data = dict(d)

    extra_mapping = extra_mapping or {}
    if extra_mapping:
        for k, v in extra_mapping.items():
            if k not in data:
                # register new mapping2lexer
                data[k] = [v]

    return data
97 105
98 106
def str2bool(_str):
    """
    returns True/False value from given string, it tries to translate the
    string into boolean

    :param _str: string value to translate into boolean
    :rtype: boolean
    :returns: boolean from given string
    """
    if _str is None:
        return False
    # keep the membership test: numeric 0/1 compare equal to False/True and
    # pass through unchanged, same as the historic behaviour
    if _str in (True, False):
        return _str
    normalized = str(_str).strip().lower()
    return normalized in ('t', 'true', 'y', 'yes', 'on', '1')
114 122
115 123
def aslist(obj, sep=None, strip=True):
    """
    Returns given value as a list.

    A string is split on *sep* (each item optionally stripped), a list or
    tuple passes through unchanged, None becomes an empty list and any other
    value is wrapped in a single-element list.

    :param obj:
    :param sep:
    :param strip:
    """
    if isinstance(obj, (basestring,)):
        parts = obj.split(sep)
        if strip:
            return [part.strip() for part in parts]
        return parts
    if isinstance(obj, (list, tuple)):
        return obj
    if obj is None:
        return []
    return [obj]
135 143
136 144
def convert_line_endings(line, mode):
    """
    Converts the line endings of *line* to the style selected by *mode*.

    Available modes are::
        0 - Unix (LF)
        1 - Mac (CR)
        2 - DOS (CRLF)

    :param line: given line to convert
    :param mode: mode to convert to
    :rtype: str
    :return: converted line according to mode; unknown modes pass through
    """
    if mode == 0:
        return line.replace('\r\n', '\n').replace('\r', '\n')
    if mode == 1:
        return line.replace('\r\n', '\r').replace('\n', '\r')
    if mode == 2:
        # normalize any lone CR or lone LF into CRLF
        return re.sub('\r(?!\n)|(?<!\r)\n', '\r\n', line)
    return line
160 168
161 169
def detect_mode(line, default):
    """
    Detects line break for given line, if line break couldn't be found
    given default value is returned.

    :param line: str line
    :param default: default
    :rtype: int
    :return: 0 (Unix), 1 (Mac), 2 (DOS), or *default* when undetectable
    """
    # order matters: CRLF must be tested before its LF / CR components
    for ending, mode in (('\r\n', 2), ('\n', 0), ('\r', 1)):
        if line.endswith(ending):
            return mode
    return default
180 188
181 189
def safe_int(val, default=None):
    """
    Returns int() of val if val is not convertable to int use default
    instead.

    :param val:
    :param default:
    """
    try:
        return int(val)
    except (ValueError, TypeError):
        return default
197 205
198 206
def safe_unicode(str_, from_encoding=None):
    """
    safe unicode function. Does few trick to turn str_ into unicode

    In case of UnicodeDecode error, we try to return it with encoding detected
    by chardet library if it fails fallback to unicode with errors replaced

    :param str_: string to decode
    :param from_encoding: single encoding, or list of encodings, to try
        before falling back to chardet detection
    :rtype: unicode
    :returns: unicode object
    """
    if isinstance(str_, unicode):
        return str_

    if not from_encoding:
        # encodings come from the app config, comma separated
        DEFAULT_ENCODINGS = aslist(rhodecode.CONFIG.get('default_encoding',
                                                        'utf8'), sep=',')
        from_encoding = DEFAULT_ENCODINGS

    if not isinstance(from_encoding, (list, tuple)):
        from_encoding = [from_encoding]

    # fast path: plain ascii decodes without an explicit encoding
    try:
        return unicode(str_)
    except UnicodeDecodeError:
        pass

    for enc in from_encoding:
        try:
            return unicode(str_, enc)
        except UnicodeDecodeError:
            pass

    # last resort: try to detect the encoding; chardet is optional
    try:
        import chardet
        encoding = chardet.detect(str_)['encoding']
        if encoding is None:
            raise Exception()
        return str_.decode(encoding)
    except (ImportError, UnicodeDecodeError, Exception):
        # give up and replace the undecodable bytes
        return unicode(str_, from_encoding[0], 'replace')
240 248
241 249
def safe_str(unicode_, to_encoding=None):
    """
    safe str function. Does few trick to turn unicode_ into string

    In case of UnicodeEncodeError, we try to return it with encoding detected
    by chardet library if it fails fallback to string with errors replaced

    :param unicode_: unicode to encode
    :param to_encoding: single encoding, or list of encodings, to try before
        falling back to chardet detection
    :rtype: str
    :returns: str object
    """

    # if it's not basestr cast to str
    if not isinstance(unicode_, basestring):
        return str(unicode_)

    if isinstance(unicode_, str):
        return unicode_

    if not to_encoding:
        # encodings come from the app config, comma separated
        DEFAULT_ENCODINGS = aslist(rhodecode.CONFIG.get('default_encoding',
                                                        'utf8'), sep=',')
        to_encoding = DEFAULT_ENCODINGS

    if not isinstance(to_encoding, (list, tuple)):
        to_encoding = [to_encoding]

    for enc in to_encoding:
        try:
            return unicode_.encode(enc)
        except UnicodeEncodeError:
            pass

    # last resort: try to detect the encoding; chardet is optional
    try:
        import chardet
        encoding = chardet.detect(unicode_)['encoding']
        if encoding is None:
            raise UnicodeEncodeError()
        return unicode_.encode(encoding)
    except (ImportError, UnicodeEncodeError):
        # give up and replace the unencodable characters
        return unicode_.encode(to_encoding[0], 'replace')
284 292
285 293
def remove_suffix(s, suffix):
    """
    Return *s* with *suffix* removed from its end, when present.

    Guards against an empty suffix: the previous ``s[:-1 * len(suffix)]``
    evaluated to ``s[:0]`` and wrongly returned ``''`` for ``suffix == ''``.
    """
    if suffix and s.endswith(suffix):
        s = s[:-len(suffix)]
    return s
290 298
291 299
def remove_prefix(s, prefix):
    """Return *s* with *prefix* removed from its start, when present."""
    return s[len(prefix):] if s.startswith(prefix) else s
296 304
297 305
def find_calling_context(ignore_modules=None):
    """
    Look through the calling stack and return the frame which called
    this function and is part of core module ( ie. rhodecode.* )

    :param ignore_modules: list of modules to ignore eg. ['rhodecode.lib']
    :return: the matching frame object, or None when no core frame is found
    """

    ignore_modules = ignore_modules or []

    # start two frames up: skip this function and its immediate caller
    f = sys._getframe(2)
    while f.f_back is not None:
        name = f.f_globals.get('__name__')
        # match any module from the same top-level package as this file
        if name and name.startswith(__name__.split('.')[0]):
            if name not in ignore_modules:
                return f
        f = f.f_back
    return None
316 324
317 325
def ping_connection(connection, branch):
    """
    SQLAlchemy ``engine_connect`` listener: issue a ``SELECT 1`` to validate
    the checked-out connection, transparently reconnecting when the pool
    handed back a stale/disconnected handle.

    :param connection: the SQLAlchemy connection being checked out
    :param branch: True for a sub-connection of an already checked-out
        connection; those are skipped
    """
    if branch:
        # "branch" refers to a sub-connection of a connection,
        # we don't want to bother pinging on these.
        return

    # turn off "close with result". This flag is only used with
    # "connectionless" execution, otherwise will be False in any case
    save_should_close_with_result = connection.should_close_with_result
    connection.should_close_with_result = False

    try:
        # run a SELECT 1. use a core select() so that
        # the SELECT of a scalar value without a table is
        # appropriately formatted for the backend
        connection.scalar(sqlalchemy.sql.select([1]))
    except sqlalchemy.exc.DBAPIError as err:
        # catch SQLAlchemy's DBAPIError, which is a wrapper
        # for the DBAPI's exception. It includes a .connection_invalidated
        # attribute which specifies if this connection is a "disconnect"
        # condition, which is based on inspection of the original exception
        # by the dialect in use.
        if err.connection_invalidated:
            # run the same SELECT again - the connection will re-validate
            # itself and establish a new connection. The disconnect detection
            # here also causes the whole connection pool to be invalidated
            # so that all stale connections are discarded.
            connection.scalar(sqlalchemy.sql.select([1]))
        else:
            raise
    finally:
        # restore "close with result"
        connection.should_close_with_result = save_should_close_with_result
351 359
352 360
def engine_from_config(configuration, prefix='sqlalchemy.', **kwargs):
    """Custom engine_from_config functions.

    Wraps ``sqlalchemy.engine_from_config`` and, depending on settings,
    attaches the connection ping listener and (in debug mode) colored query
    logging that reports the calling context of every statement.
    """
    log = logging.getLogger('sqlalchemy.engine')
    _ping_connection = configuration.pop('sqlalchemy.db1.ping_connection', None)

    engine = sqlalchemy.engine_from_config(configuration, prefix, **kwargs)

    def color_sql(sql):
        # wrap *sql* in ANSI escapes so it stands out in terminal logs
        color_seq = '\033[1;33m' # This is yellow: code 33
        normal = '\x1b[0m'
        return ''.join([color_seq, sql, normal])

    if configuration['debug'] or _ping_connection:
        sqlalchemy.event.listen(engine, "engine_connect", ping_connection)

    if configuration['debug']:
        # attach events only for debug configuration

        def before_cursor_execute(conn, cursor, statement,
                                  parameters, context, executemany):
            setattr(conn, 'query_start_time', time.time())
            log.info(color_sql(">>>>> STARTING QUERY >>>>>"))
            calling_context = find_calling_context(ignore_modules=[
                'rhodecode.lib.caching_query',
                'rhodecode.model.settings',
            ])
            if calling_context:
                log.info(color_sql('call context %s:%s' % (
                    calling_context.f_code.co_filename,
                    calling_context.f_lineno,
                )))

        def after_cursor_execute(conn, cursor, statement,
                                 parameters, context, executemany):
            delattr(conn, 'query_start_time')

        sqlalchemy.event.listen(engine, "before_cursor_execute",
                                before_cursor_execute)
        sqlalchemy.event.listen(engine, "after_cursor_execute",
                                after_cursor_execute)

    return engine
395 403
396 404
def get_encryption_key(config):
    """
    Return the secret used for encrypted values: the dedicated
    ``rhodecode.encrypted_values.secret`` setting when present, otherwise
    the mandatory beaker session secret.
    """
    dedicated_secret = config.get('rhodecode.encrypted_values.secret')
    # always resolved so a missing beaker secret fails loudly
    fallback_secret = config['beaker.session.secret']
    return dedicated_secret or fallback_secret
401 409
402 410
def age(prevdate, now=None, show_short_version=False, show_suffix=True,
        short_format=False):
    """
    Turns a datetime into an age string.
    If show_short_version is True, this generates a shorter string with
    an approximate age; ex. '1 day ago', rather than '1 day and 23 hours ago'.

    * IMPORTANT*
    Code of this function is written in special way so it's easier to
    backport it to javascript. If you mean to update it, please also update
    `jquery.timeago-extension.js` file

    :param prevdate: datetime object
    :param now: get current time, if not define we use
        `datetime.datetime.now()`
    :param show_short_version: if it should approximate the date and
        return a shorter string
    :param show_suffix: whether to append/prepend the 'ago'/'in' wording
    :param short_format: show short format, eg 2D instead of 2 days
    :rtype: unicode
    :returns: unicode words describing age
    """

    def _get_relative_delta(now, prevdate):
        # split the difference into calendar-aware parts
        base = dateutil.relativedelta.relativedelta(now, prevdate)
        return {
            'year': base.years,
            'month': base.months,
            'day': base.days,
            'hour': base.hours,
            'minute': base.minutes,
            'second': base.seconds,
        }

    def _is_leap_year(year):
        return year % 4 == 0 and (year % 100 != 0 or year % 400 == 0)

    def get_month(prevdate):
        return prevdate.month

    def get_year(prevdate):
        return prevdate.year

    now = now or datetime.datetime.now()
    order = ['year', 'month', 'day', 'hour', 'minute', 'second']
    deltas = {}
    future = False

    # for future dates swap the operands so all deltas stay positive
    if prevdate > now:
        now_old = now
        now = prevdate
        prevdate = now_old
        future = True
    if future:
        prevdate = prevdate.replace(microsecond=0)
    # Get date parts deltas
    for part in order:
        rel_delta = _get_relative_delta(now, prevdate)
        deltas[part] = rel_delta[part]

    # Fix negative offsets (there is 1 second between 10:59:59 and 11:00:00,
    # not 1 hour, -59 minutes and -59 seconds)
    offsets = [[5, 60], [4, 60], [3, 24]]
    for element in offsets: # seconds, minutes, hours
        num = element[0]
        length = element[1]

        part = order[num]
        carry_part = order[num - 1]

        if deltas[part] < 0:
            deltas[part] += length
            deltas[carry_part] -= 1

    # Same thing for days except that the increment depends on the (variable)
    # number of days in the month
    month_lengths = [31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31]
    if deltas['day'] < 0:
        if get_month(prevdate) == 2 and _is_leap_year(get_year(prevdate)):
            deltas['day'] += 29
        else:
            deltas['day'] += month_lengths[get_month(prevdate) - 1]

        deltas['month'] -= 1

    if deltas['month'] < 0:
        deltas['month'] += 12
        deltas['year'] -= 1

    # Format the result
    if short_format:
        fmt_funcs = {
            'year': lambda d: u'%dy' % d,
            'month': lambda d: u'%dm' % d,
            'day': lambda d: u'%dd' % d,
            'hour': lambda d: u'%dh' % d,
            'minute': lambda d: u'%dmin' % d,
            'second': lambda d: u'%dsec' % d,
        }
    else:
        fmt_funcs = {
            'year': lambda d: _pluralize(u'${num} year', u'${num} years', d, mapping={'num': d}).interpolate(),
            'month': lambda d: _pluralize(u'${num} month', u'${num} months', d, mapping={'num': d}).interpolate(),
            'day': lambda d: _pluralize(u'${num} day', u'${num} days', d, mapping={'num': d}).interpolate(),
            'hour': lambda d: _pluralize(u'${num} hour', u'${num} hours', d, mapping={'num': d}).interpolate(),
            'minute': lambda d: _pluralize(u'${num} minute', u'${num} minutes', d, mapping={'num': d}).interpolate(),
            'second': lambda d: _pluralize(u'${num} second', u'${num} seconds', d, mapping={'num': d}).interpolate(),
        }

    # walk from the largest unit down; the first non-zero part (optionally
    # combined with the next smaller one) determines the output
    i = 0
    for part in order:
        value = deltas[part]
        if value != 0:

            if i < 5:
                sub_part = order[i + 1]
                sub_value = deltas[sub_part]
            else:
                sub_value = 0

            if sub_value == 0 or show_short_version:
                _val = fmt_funcs[part](value)
                if future:
                    if show_suffix:
                        return _(u'in ${ago}', mapping={'ago': _val})
                    else:
                        return _(_val)

                else:
                    if show_suffix:
                        return _(u'${ago} ago', mapping={'ago': _val})
                    else:
                        return _(_val)

            val = fmt_funcs[part](value)
            val_detail = fmt_funcs[sub_part](sub_value)
            mapping = {'val': val, 'detail': val_detail}

            if short_format:
                datetime_tmpl = _(u'${val}, ${detail}', mapping=mapping)
                if show_suffix:
                    datetime_tmpl = _(u'${val}, ${detail} ago', mapping=mapping)
                    if future:
                        datetime_tmpl = _(u'in ${val}, ${detail}', mapping=mapping)
            else:
                datetime_tmpl = _(u'${val} and ${detail}', mapping=mapping)
                if show_suffix:
                    datetime_tmpl = _(u'${val} and ${detail} ago', mapping=mapping)
                    if future:
                        datetime_tmpl = _(u'in ${val} and ${detail}', mapping=mapping)

            return datetime_tmpl
        i += 1
    return _(u'just now')
557 565
558 566
def cleaned_uri(uri):
    """
    Percent-encodes characters that RFC3986 does not allow inside an uri,
    such as '[' and ']'; '@', '$', ':' and '/' are kept intact.

    :param uri:
    :return: uri with such chars quoted
    """
    return urllib.quote(uri, safe='@$:/')
567 575
568 576
def uri_filter(uri):
    """
    Removes user:password from given url string

    :param uri:
    :rtype: unicode
    :returns: filtered [proto, host, port] parts (empty entries dropped),
        or '' for a falsy input
    """
    if not uri:
        return ''

    proto = ''
    for known_scheme in ('https://', 'http://'):
        if uri.startswith(known_scheme):
            proto = known_scheme
            uri = uri[len(known_scheme):]
            break

    # drop everything up to (and including) the credentials separator
    uri = uri[uri.find('@') + 1:]

    # split off the port, when present
    colon_pos = uri.find(':')
    if colon_pos == -1:
        host, port = uri, None
    else:
        host, port = uri[:colon_pos], uri[colon_pos + 1:]

    return filter(None, [proto, host, port])
599 607
600 608
def credentials_filter(uri):
    """
    Returns a url with removed credentials

    :param uri:
    """

    parts = uri_filter(uri)
    # re-attach the ':' separator to the port part, when present
    if len(parts) > 2 and parts[2]:
        parts[2] = ':' + parts[2]
    return ''.join(parts)
614 622
615 623
def get_clone_url(request, uri_tmpl, repo_name, repo_id, **override):
    """
    Render a clone url from the *uri_tmpl* template.

    Substitutes the {scheme}, {user}, {sys_user}, {netloc}, {hostname},
    {prefix}, {repo} and {repoid} placeholders from the current request;
    entries in *override* take precedence over the computed defaults.
    """
    qualifed_home_url = request.route_url('home')
    parsed_url = urlobject.URLObject(qualifed_home_url)
    decoded_path = safe_unicode(urllib.unquote(parsed_url.path.rstrip('/')))

    args = {
        'scheme': parsed_url.scheme,
        'user': '',
        'sys_user': getpass.getuser(),
        # path if we use proxy-prefix
        'netloc': parsed_url.netloc+decoded_path,
        'hostname': parsed_url.hostname,
        'prefix': decoded_path,
        'repo': repo_name,
        'repoid': str(repo_id)
    }
    args.update(override)
    # quote the user part so special characters survive in the url
    args['user'] = urllib.quote(safe_str(args['user']))

    for k, v in args.items():
        uri_tmpl = uri_tmpl.replace('{%s}' % k, v)

    # remove leading @ sign if it's present. Case of empty user
    url_obj = urlobject.URLObject(uri_tmpl)
    url = url_obj.with_netloc(url_obj.netloc.lstrip('@'))

    return safe_unicode(url)
643 651
644 652
def get_commit_safe(repo, commit_id=None, commit_idx=None, pre_load=None):
    """
    Safe version of get_commit if this commit doesn't exists for a
    repository it returns a Dummy one instead

    :param repo: repository instance
    :param commit_id: commit id as str
    :param commit_idx: numeric commit index
    :param pre_load: optional list of commit attributes to load
    """
    # TODO(skreft): remove these circular imports
    from rhodecode.lib.vcs.backends.base import BaseRepository, EmptyCommit
    from rhodecode.lib.vcs.exceptions import RepositoryError
    if not isinstance(repo, BaseRepository):
        # interpolate via % — previously type(repo) was passed as a second
        # Exception argument, so the %s placeholder was never filled in
        raise Exception(
            'You must pass an Repository object as first argument got %s'
            % type(repo))

    try:
        commit = repo.get_commit(
            commit_id=commit_id, commit_idx=commit_idx, pre_load=pre_load)
    except (RepositoryError, LookupError):
        commit = EmptyCommit()
    return commit
667 675
668 676
def datetime_to_time(dt):
    """Convert a datetime into a unix float timestamp; falsy input gives None."""
    if not dt:
        return None
    return time.mktime(dt.timetuple())
672 680
673 681
def time_to_datetime(tm):
    """
    Convert a unix timestamp (number or numeric string) into a local
    datetime; falsy or unparsable input yields None.
    """
    if not tm:
        return None
    if isinstance(tm, basestring):
        try:
            tm = float(tm)
        except ValueError:
            return None
    return datetime.datetime.fromtimestamp(tm)
682 690
683 691
def time_to_utcdatetime(tm):
    """
    Convert a unix timestamp (number or numeric string) into a naive UTC
    datetime; falsy or unparsable input yields None.
    """
    if not tm:
        return None
    if isinstance(tm, basestring):
        try:
            tm = float(tm)
        except ValueError:
            return None
    return datetime.datetime.utcfromtimestamp(tm)
692 700
693 701
# Matches @username mentions: the '@' must start the string or follow a
# character that cannot be part of a username; group 1 captures the name.
MENTIONS_REGEX = re.compile(
    # ^@ or @ without any special chars in front
    r'(?:^@|[^a-zA-Z0-9\-\_\.]@)'
    # main body starts with letter, then can be . - _
    r'([a-zA-Z0-9]{1}[a-zA-Z0-9\-\_\.]+)',
    re.VERBOSE | re.MULTILINE)
700 708
701 709
def extract_mentioned_users(s):
    """
    Returns unique usernames from given string s that have @mention,
    sorted case-insensitively.

    :param s: string to get mentions
    """
    found = set(MENTIONS_REGEX.findall(s))
    return sorted(found, key=lambda name: name.lower())
713 721
714 722
class AttributeDictBase(dict):
    """
    Base dict that exposes its items as attributes and pickles via the
    instance ``__dict__``.
    """
    def __getstate__(self):
        # pickle the attribute dictionary, not the dict items
        odict = self.__dict__ # get attribute dictionary
        return odict

    def __setstate__(self, state):
        # parameter renamed from `dict`, which shadowed the builtin;
        # pickle always calls this positionally so callers are unaffected
        self.__dict__ = state

    # route attribute assignment/deletion to the dict items
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__
725 733
726 734
class StrictAttributeDict(AttributeDictBase):
    """
    Strict Version of Attribute dict which raises an Attribute error when
    requested attribute is not set
    """
    def __getattr__(self, attr):
        if attr in self:
            return self[attr]
        raise AttributeError('%s object has no attribute %s' % (
            self.__class__, attr))
738 746
739 747
class AttributeDict(AttributeDictBase):
    """Attribute dict returning None for missing keys."""
    def __getattr__(self, attr):
        return dict.get(self, attr, None)
743 751
744 752
745 753
def fix_PATH(os_=None):
    """
    Get current active python path, and append it to PATH variable to fix
    issues of subprocess calls and different python versions

    :param os_: optional os-like module, mainly useful for testing
    """
    if os_ is None:
        import os
    else:
        os = os_

    python_dir = os.path.split(sys.executable)[0]
    if not os.environ['PATH'].startswith(python_dir):
        os.environ['PATH'] = '%s:%s' % (python_dir, os.environ['PATH'])
759 767
760 768
def obfuscate_url_pw(engine):
    """
    Return the engine url as text with any password replaced by XXXXX;
    unparsable input is returned as-is.
    """
    url = engine or ''
    try:
        url = sqlalchemy.engine.url.make_url(engine)
        if url.password:
            url.password = 'XXXXX'
    except Exception:
        # best effort: fall back to the raw value when parsing fails
        pass
    return unicode(url)
770 778
771 779
def get_server_url(environ):
    """Return the external server url (host url + script name) for *environ*."""
    request = webob.Request(environ)
    return request.host_url + request.script_name
775 783
776 784
def unique_id(hexlen=32):
    """Return a short random unique id of at most *hexlen* characters."""
    charset = "23456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghjklmnpqrstuvwxyz"
    return suuid(truncate_to=hexlen, alphabet=charset)
780 788
781 789
def suuid(url=None, truncate_to=22, alphabet=None):
    """
    Generate and return a short URL safe UUID.

    If the url parameter is provided, set the namespace to the provided
    URL and generate a UUID.

    :param url to get the uuid for
    :truncate_to: truncate the basic 22 UUID to shorter version
    :param alphabet: optional custom alphabet used to encode the UUID int

    The IDs won't be universally unique any longer, but the probability of
    a collision will still be very low.
    """
    # Define our alphabet.
    _ALPHABET = alphabet or "23456789ABCDEFGHJKLMNPQRSTUVWXYZ"

    # If no URL is given, generate a random UUID.
    if url is None:
        unique_id = uuid.uuid4().int
    else:
        unique_id = uuid.uuid3(uuid.NAMESPACE_URL, url).int

    alphabet_length = len(_ALPHABET)
    output = []
    while unique_id > 0:
        # divmod keeps full integer precision; the previous
        # int(unique_id / alphabet_length) truncates via float true division
        # on Python 3, and a 128-bit UUID int does not fit in a float
        unique_id, digit = divmod(unique_id, alphabet_length)
        output.append(_ALPHABET[digit])
    return "".join(output)[:truncate_to]
811 819
812 820
def get_current_rhodecode_user(request=None):
    """
    Gets rhodecode user from request; falls back to the threadlocal pyramid
    request when none is given. Returns None when no user is attached.
    """
    pyramid_request = request or pyramid.threadlocal.get_current_request()

    if pyramid_request:
        # web request carries `.user`, API request carries `.rpc_user`
        if hasattr(pyramid_request, 'user'):
            return pyramid_request.user
        if hasattr(pyramid_request, 'rpc_user'):
            return pyramid_request.rpc_user

    return None
828 836
829 837
def action_logger_generic(action, namespace=''):
    """
    A generic logger for actions useful to the system overview, tries to find
    an acting user for the context of the call otherwise reports unknown user

    :param action: logging message eg 'comment 5 deleted'
    :param type: string

    :param namespace: namespace of the logging message eg. 'repo.comments'
    :param type: string

    """
    logger_name = 'rhodecode.actions'
    if namespace:
        logger_name = '%s.%s' % (logger_name, namespace)

    action_log = logging.getLogger(logger_name)

    # try to resolve the acting user; downgrade to a warning when unknown
    acting_user = get_current_rhodecode_user()
    if acting_user:
        emit = action_log.info
    else:
        acting_user = '<unknown user>'
        emit = action_log.warning

    emit('Logging action by {}: {}'.format(acting_user, action))
860 868
861 869
def escape_split(text, sep=',', maxsplit=-1):
    r"""
    Split *text* on *sep* while honouring backslash-escaped separators,
    e.g. arg='foo\, bar' keeps 'foo, bar' as a single element.

    It should be noted that the way bash et. al. do command line parsing,
    those single quotes are required.
    """
    sep_escaped = '\\' + sep

    if sep_escaped not in text:
        return text.split(sep, maxsplit)

    head, _sep, tail = text.partition(sep_escaped)
    pieces = head.split(sep, maxsplit)  # plain split is safe: no escapes here
    pending = pieces.pop()  # last piece continues past the escaped separator

    # recurse, since the tail may hold further escaped separators
    rest = escape_split(tail, sep, maxsplit)

    # rest[0] is the remainder of the element interrupted by the escape;
    # reattach it with a literal separator
    pending += sep + rest[0]

    return pieces + [pending] + rest[1:]
887 895
888 896
class OptionalAttr(object):
    """
    Special Optional Option that defines other attribute. Example::

        def test(apiuser, userid=Optional(OAttr('apiuser')):
            user = Optional.extract(userid)
            # calls

    """

    def __init__(self, attr_name):
        self.attr_name = attr_name

    def __call__(self):
        # calling an OptionalAttr simply yields the instance itself
        return self

    def __repr__(self):
        return '<OptionalAttr:{}>'.format(self.attr_name)


# short alias commonly used in API method signatures
OAttr = OptionalAttr
911 919
912 920
class Optional(object):
    """
    Defines an optional parameter::

        param = param.getval() if isinstance(param, Optional) else param
        param = param() if isinstance(param, Optional) else param

    is equivalent of::

        param = Optional.extract(param)

    """

    def __init__(self, type_):
        self.type_ = type_

    def __repr__(self):
        return '<Optional:%s>' % repr(self.type_)

    def __call__(self):
        return self.getval()

    def getval(self):
        """
        returns value from this Optional instance
        """
        # an OAttr wrapper resolves to the name of the attribute it points at
        if isinstance(self.type_, OAttr):
            return self.type_.attr_name
        return self.type_

    @classmethod
    def extract(cls, val):
        """
        Extracts value from Optional() instance

        :param val:
        :return: original value if it's not Optional instance else
            value of instance
        """
        return val.getval() if isinstance(val, cls) else val
956 964
957 965
def glob2re(pat):
    """
    Translate a shell PATTERN to a regular expression.

    There is no way to quote meta-characters.
    """
    pos, length = 0, len(pat)
    parts = []
    while pos < length:
        ch = pat[pos]
        pos += 1
        if ch == '*':
            # '*' matches within a single path segment only (no '/')
            parts.append('[^/]*')
        elif ch == '?':
            # '?' matches any single character except '/'
            parts.append('[^/]')
        elif ch == '[':
            # scan for the closing bracket of a character class
            end = pos
            if end < length and pat[end] == '!':
                end += 1
            if end < length and pat[end] == ']':
                end += 1
            while end < length and pat[end] != ']':
                end += 1
            if end >= length:
                # unterminated class: treat '[' as a literal
                parts.append('\\[')
            else:
                body = pat[pos:end].replace('\\', '\\\\')
                pos = end + 1
                if body[0] == '!':
                    body = '^' + body[1:]
                elif body[0] == '^':
                    body = '\\' + body
                parts.append('[%s]' % body)
        else:
            parts.append(re.escape(ch))
    return ''.join(parts) + '\Z(?ms)'
General Comments 0
You need to be logged in to leave comments. Login now