##// END OF EJS Templates
pick_port: unified code for testing/hooks
super-admin -
r4866:6b029be8 default
parent child Browse files
Show More
@@ -1,357 +1,347 b''
1 1 # -*- coding: utf-8 -*-
2 2
3 3 # Copyright (C) 2010-2020 RhodeCode GmbH
4 4 #
5 5 # This program is free software: you can redistribute it and/or modify
6 6 # it under the terms of the GNU Affero General Public License, version 3
7 7 # (only), as published by the Free Software Foundation.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU Affero General Public License
15 15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 16 #
17 17 # This program is dual-licensed. If you wish to learn more about the
18 18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20 20
21 21 import os
22 22 import time
23 23 import logging
24 24 import tempfile
25 25 import traceback
26 26 import threading
27 27 import socket
28 28 import random
29 29
30 30 from BaseHTTPServer import BaseHTTPRequestHandler
31 31 from SocketServer import TCPServer
32 32
33 33 import rhodecode
34 34 from rhodecode.lib.exceptions import HTTPLockedRC, HTTPBranchProtected
35 35 from rhodecode.model import meta
36 36 from rhodecode.lib.base import bootstrap_request, bootstrap_config
37 37 from rhodecode.lib import hooks_base
38 38 from rhodecode.lib.utils2 import AttributeDict
39 39 from rhodecode.lib.ext_json import json
40 40 from rhodecode.lib import rc_cache
41 41
42 42 log = logging.getLogger(__name__)
43 43
44 44
45 45 class HooksHttpHandler(BaseHTTPRequestHandler):
46 46
47 47 def do_POST(self):
48 48 method, extras = self._read_request()
49 49 txn_id = getattr(self.server, 'txn_id', None)
50 50 if txn_id:
51 51 log.debug('Computing TXN_ID based on `%s`:`%s`',
52 52 extras['repository'], extras['txn_id'])
53 53 computed_txn_id = rc_cache.utils.compute_key_from_params(
54 54 extras['repository'], extras['txn_id'])
55 55 if txn_id != computed_txn_id:
56 56 raise Exception(
57 57 'TXN ID fail: expected {} got {} instead'.format(
58 58 txn_id, computed_txn_id))
59 59
60 60 try:
61 61 result = self._call_hook(method, extras)
62 62 except Exception as e:
63 63 exc_tb = traceback.format_exc()
64 64 result = {
65 65 'exception': e.__class__.__name__,
66 66 'exception_traceback': exc_tb,
67 67 'exception_args': e.args
68 68 }
69 69 self._write_response(result)
70 70
71 71 def _read_request(self):
72 72 length = int(self.headers['Content-Length'])
73 73 body = self.rfile.read(length).decode('utf-8')
74 74 data = json.loads(body)
75 75 return data['method'], data['extras']
76 76
77 77 def _write_response(self, result):
78 78 self.send_response(200)
79 79 self.send_header("Content-type", "text/json")
80 80 self.end_headers()
81 81 self.wfile.write(json.dumps(result))
82 82
83 83 def _call_hook(self, method, extras):
84 84 hooks = Hooks()
85 85 try:
86 86 result = getattr(hooks, method)(extras)
87 87 finally:
88 88 meta.Session.remove()
89 89 return result
90 90
91 91 def log_message(self, format, *args):
92 92 """
93 93 This is an overridden method of BaseHTTPRequestHandler which logs using
94 94 logging library instead of writing directly to stderr.
95 95 """
96 96
97 97 message = format % args
98 98
99 99 log.debug(
100 100 "%s - - [%s] %s", self.client_address[0],
101 101 self.log_date_time_string(), message)
102 102
103 103
104 104 class DummyHooksCallbackDaemon(object):
105 105 hooks_uri = ''
106 106
107 107 def __init__(self):
108 108 self.hooks_module = Hooks.__module__
109 109
110 110 def __enter__(self):
111 111 log.debug('Running `%s` callback daemon', self.__class__.__name__)
112 112 return self
113 113
114 114 def __exit__(self, exc_type, exc_val, exc_tb):
115 115 log.debug('Exiting `%s` callback daemon', self.__class__.__name__)
116 116
117 117
118 118 class ThreadedHookCallbackDaemon(object):
119 119
120 120 _callback_thread = None
121 121 _daemon = None
122 122 _done = False
123 123
124 124 def __init__(self, txn_id=None, host=None, port=None):
125 125 self._prepare(txn_id=txn_id, host=host, port=port)
126 126
127 127 def __enter__(self):
128 128 log.debug('Running `%s` callback daemon', self.__class__.__name__)
129 129 self._run()
130 130 return self
131 131
132 132 def __exit__(self, exc_type, exc_val, exc_tb):
133 133 log.debug('Exiting `%s` callback daemon', self.__class__.__name__)
134 134 self._stop()
135 135
136 136 def _prepare(self, txn_id=None, host=None, port=None):
137 137 raise NotImplementedError()
138 138
139 139 def _run(self):
140 140 raise NotImplementedError()
141 141
142 142 def _stop(self):
143 143 raise NotImplementedError()
144 144
145 145
146 146 class HttpHooksCallbackDaemon(ThreadedHookCallbackDaemon):
147 147 """
148 148 Context manager which will run a callback daemon in a background thread.
149 149 """
150 150
151 151 hooks_uri = None
152 152
153 153 # From Python docs: Polling reduces our responsiveness to a shutdown
154 154 # request and wastes cpu at all other times.
155 155 POLL_INTERVAL = 0.01
156 156
157 157 def get_hostname(self):
158 158 return socket.gethostname() or '127.0.0.1'
159 159
160 160 def get_available_port(self, min_port=20000, max_port=65535):
161 sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
162 hostname = self.get_hostname()
163
164 for _ in range(min_port, max_port):
165 pick_port = random.randint(min_port, max_port)
166 try:
167 sock.bind((hostname, pick_port))
168 sock.close()
169 del sock
170 return pick_port
171 except OSError:
172 pass
161 from rhodecode.lib.utils2 import get_available_port as _get_port
162 return _get_port(min_port, max_port)
173 163
174 164 def _prepare(self, txn_id=None, host=None, port=None):
175 165 if not host or host == "*":
176 166 host = self.get_hostname()
177 167 if not port:
178 168 port = self.get_available_port()
179 169
180 170 server_address = (host, port)
181 171 self.hooks_uri = '{}:{}'.format(host, port)
182 172 self.txn_id = txn_id
183 173 self._done = False
184 174
185 175 log.debug(
186 176 "Preparing HTTP callback daemon at `%s` and registering hook object: %s",
187 177 self.hooks_uri, HooksHttpHandler)
188 178
189 179 self._daemon = TCPServer(server_address, HooksHttpHandler)
190 180 # inject transaction_id for later verification
191 181 self._daemon.txn_id = self.txn_id
192 182
193 183 def _run(self):
194 184 log.debug("Running event loop of callback daemon in background thread")
195 185 callback_thread = threading.Thread(
196 186 target=self._daemon.serve_forever,
197 187 kwargs={'poll_interval': self.POLL_INTERVAL})
198 188 callback_thread.daemon = True
199 189 callback_thread.start()
200 190 self._callback_thread = callback_thread
201 191
202 192 def _stop(self):
203 193 log.debug("Waiting for background thread to finish.")
204 194 self._daemon.shutdown()
205 195 self._callback_thread.join()
206 196 self._daemon = None
207 197 self._callback_thread = None
208 198 if self.txn_id:
209 199 txn_id_file = get_txn_id_data_path(self.txn_id)
210 200 log.debug('Cleaning up TXN ID %s', txn_id_file)
211 201 if os.path.isfile(txn_id_file):
212 202 os.remove(txn_id_file)
213 203
214 204 log.debug("Background thread done.")
215 205
216 206
217 207 def get_txn_id_data_path(txn_id):
218 208 import rhodecode
219 209
220 210 root = rhodecode.CONFIG.get('cache_dir') or tempfile.gettempdir()
221 211 final_dir = os.path.join(root, 'svn_txn_id')
222 212
223 213 if not os.path.isdir(final_dir):
224 214 os.makedirs(final_dir)
225 215 return os.path.join(final_dir, 'rc_txn_id_{}'.format(txn_id))
226 216
227 217
228 218 def store_txn_id_data(txn_id, data_dict):
229 219 if not txn_id:
230 220 log.warning('Cannot store txn_id because it is empty')
231 221 return
232 222
233 223 path = get_txn_id_data_path(txn_id)
234 224 try:
235 225 with open(path, 'wb') as f:
236 226 f.write(json.dumps(data_dict))
237 227 except Exception:
238 228 log.exception('Failed to write txn_id metadata')
239 229
240 230
241 231 def get_txn_id_from_store(txn_id):
242 232 """
243 233 Reads txn_id from store and if present returns the data for callback manager
244 234 """
245 235 path = get_txn_id_data_path(txn_id)
246 236 try:
247 237 with open(path, 'rb') as f:
248 238 return json.loads(f.read())
249 239 except Exception:
250 240 return {}
251 241
252 242
253 243 def prepare_callback_daemon(extras, protocol, host, use_direct_calls, txn_id=None):
254 244 txn_details = get_txn_id_from_store(txn_id)
255 245 port = txn_details.get('port', 0)
256 246 if use_direct_calls:
257 247 callback_daemon = DummyHooksCallbackDaemon()
258 248 extras['hooks_module'] = callback_daemon.hooks_module
259 249 else:
260 250 if protocol == 'http':
261 251 callback_daemon = HttpHooksCallbackDaemon(
262 252 txn_id=txn_id, host=host, port=port)
263 253 else:
264 254 log.error('Unsupported callback daemon protocol "%s"', protocol)
265 255 raise Exception('Unsupported callback daemon protocol.')
266 256
267 257 extras['hooks_uri'] = callback_daemon.hooks_uri
268 258 extras['hooks_protocol'] = protocol
269 259 extras['time'] = time.time()
270 260
271 261 # register txn_id
272 262 extras['txn_id'] = txn_id
273 263 log.debug('Prepared a callback daemon: %s at url `%s`',
274 264 callback_daemon.__class__.__name__, callback_daemon.hooks_uri)
275 265 return callback_daemon, extras
276 266
277 267
278 268 class Hooks(object):
279 269 """
280 270 Exposes the hooks for remote call backs
281 271 """
282 272
283 273 def repo_size(self, extras):
284 274 log.debug("Called repo_size of %s object", self)
285 275 return self._call_hook(hooks_base.repo_size, extras)
286 276
287 277 def pre_pull(self, extras):
288 278 log.debug("Called pre_pull of %s object", self)
289 279 return self._call_hook(hooks_base.pre_pull, extras)
290 280
291 281 def post_pull(self, extras):
292 282 log.debug("Called post_pull of %s object", self)
293 283 return self._call_hook(hooks_base.post_pull, extras)
294 284
295 285 def pre_push(self, extras):
296 286 log.debug("Called pre_push of %s object", self)
297 287 return self._call_hook(hooks_base.pre_push, extras)
298 288
299 289 def post_push(self, extras):
300 290 log.debug("Called post_push of %s object", self)
301 291 return self._call_hook(hooks_base.post_push, extras)
302 292
303 293 def _call_hook(self, hook, extras):
304 294 extras = AttributeDict(extras)
305 295 server_url = extras['server_url']
306 296 request = bootstrap_request(application_url=server_url)
307 297
308 298 bootstrap_config(request) # inject routes and other interfaces
309 299
310 300 # inject the user for usage in hooks
311 301 request.user = AttributeDict({'username': extras.username,
312 302 'ip_addr': extras.ip,
313 303 'user_id': extras.user_id})
314 304
315 305 extras.request = request
316 306
317 307 try:
318 308 result = hook(extras)
319 309 if result is None:
320 310 raise Exception(
321 311 'Failed to obtain hook result from func: {}'.format(hook))
322 312 except HTTPBranchProtected as handled_error:
323 313 # Those special cases doesn't need error reporting. It's a case of
324 314 # locked repo or protected branch
325 315 result = AttributeDict({
326 316 'status': handled_error.code,
327 317 'output': handled_error.explanation
328 318 })
329 319 except (HTTPLockedRC, Exception) as error:
330 320 # locked needs different handling since we need to also
331 321 # handle PULL operations
332 322 exc_tb = ''
333 323 if not isinstance(error, HTTPLockedRC):
334 324 exc_tb = traceback.format_exc()
335 325 log.exception('Exception when handling hook %s', hook)
336 326 error_args = error.args
337 327 return {
338 328 'status': 128,
339 329 'output': '',
340 330 'exception': type(error).__name__,
341 331 'exception_traceback': exc_tb,
342 332 'exception_args': error_args,
343 333 }
344 334 finally:
345 335 meta.Session.remove()
346 336
347 337 log.debug('Got hook call response %s', result)
348 338 return {
349 339 'status': result.status,
350 340 'output': result.output,
351 341 }
352 342
353 343 def __enter__(self):
354 344 return self
355 345
356 346 def __exit__(self, exc_type, exc_val, exc_tb):
357 347 pass
@@ -1,1173 +1,1193 b''
1 1 # -*- coding: utf-8 -*-
2 2
3 3 # Copyright (C) 2011-2020 RhodeCode GmbH
4 4 #
5 5 # This program is free software: you can redistribute it and/or modify
6 6 # it under the terms of the GNU Affero General Public License, version 3
7 7 # (only), as published by the Free Software Foundation.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU Affero General Public License
15 15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 16 #
17 17 # This program is dual-licensed. If you wish to learn more about the
18 18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20 20
21 21
22 22 """
23 23 Some simple helper functions
24 24 """
25 25
26 26 import collections
27 27 import datetime
28 28 import dateutil.relativedelta
29 29 import hashlib
30 30 import logging
31 31 import re
32 32 import sys
33 33 import time
34 34 import urllib
35 35 import urlobject
36 36 import uuid
37 37 import getpass
38 import socket
39 import random
38 40 from functools import update_wrapper, partial, wraps
39 41
40 42 import pygments.lexers
41 43 import sqlalchemy
42 44 import sqlalchemy.engine.url
43 45 import sqlalchemy.exc
44 46 import sqlalchemy.sql
45 47 import webob
46 48 import pyramid.threadlocal
47 49 from pyramid import compat
48 50 from pyramid.settings import asbool
49 51
50 52 import rhodecode
51 53 from rhodecode.translation import _, _pluralize
52 54
53 55
54 56 def md5(s):
55 57 return hashlib.md5(s).hexdigest()
56 58
57 59
58 60 def md5_safe(s):
59 61 return md5(safe_str(s))
60 62
61 63
62 64 def sha1(s):
63 65 return hashlib.sha1(s).hexdigest()
64 66
65 67
66 68 def sha1_safe(s):
67 69 return sha1(safe_str(s))
68 70
69 71
70 72 def __get_lem(extra_mapping=None):
71 73 """
72 74 Get language extension map based on what's inside pygments lexers
73 75 """
74 76 d = collections.defaultdict(lambda: [])
75 77
76 78 def __clean(s):
77 79 s = s.lstrip('*')
78 80 s = s.lstrip('.')
79 81
80 82 if s.find('[') != -1:
81 83 exts = []
82 84 start, stop = s.find('['), s.find(']')
83 85
84 86 for suffix in s[start + 1:stop]:
85 87 exts.append(s[:s.find('[')] + suffix)
86 88 return [e.lower() for e in exts]
87 89 else:
88 90 return [s.lower()]
89 91
90 92 for lx, t in sorted(pygments.lexers.LEXERS.items()):
91 93 m = map(__clean, t[-2])
92 94 if m:
93 95 m = reduce(lambda x, y: x + y, m)
94 96 for ext in m:
95 97 desc = lx.replace('Lexer', '')
96 98 d[ext].append(desc)
97 99
98 100 data = dict(d)
99 101
100 102 extra_mapping = extra_mapping or {}
101 103 if extra_mapping:
102 104 for k, v in extra_mapping.items():
103 105 if k not in data:
104 106 # register new mapping2lexer
105 107 data[k] = [v]
106 108
107 109 return data
108 110
109 111
110 112 def str2bool(_str):
111 113 """
112 114 returns True/False value from given string, it tries to translate the
113 115 string into boolean
114 116
115 117 :param _str: string value to translate into boolean
116 118 :rtype: boolean
117 119 :returns: boolean from given string
118 120 """
119 121 if _str is None:
120 122 return False
121 123 if _str in (True, False):
122 124 return _str
123 125 _str = str(_str).strip().lower()
124 126 return _str in ('t', 'true', 'y', 'yes', 'on', '1')
125 127
126 128
127 129 def aslist(obj, sep=None, strip=True):
128 130 """
129 131 Returns given string separated by sep as list
130 132
131 133 :param obj:
132 134 :param sep:
133 135 :param strip:
134 136 """
135 137 if isinstance(obj, (basestring,)):
136 138 lst = obj.split(sep)
137 139 if strip:
138 140 lst = [v.strip() for v in lst]
139 141 return lst
140 142 elif isinstance(obj, (list, tuple)):
141 143 return obj
142 144 elif obj is None:
143 145 return []
144 146 else:
145 147 return [obj]
146 148
147 149
148 150 def convert_line_endings(line, mode):
149 151 """
150 152 Converts a given line "line end" accordingly to given mode
151 153
152 154 Available modes are::
153 155 0 - Unix
154 156 1 - Mac
155 157 2 - DOS
156 158
157 159 :param line: given line to convert
158 160 :param mode: mode to convert to
159 161 :rtype: str
160 162 :return: converted line according to mode
161 163 """
162 164 if mode == 0:
163 165 line = line.replace('\r\n', '\n')
164 166 line = line.replace('\r', '\n')
165 167 elif mode == 1:
166 168 line = line.replace('\r\n', '\r')
167 169 line = line.replace('\n', '\r')
168 170 elif mode == 2:
169 171 line = re.sub('\r(?!\n)|(?<!\r)\n', '\r\n', line)
170 172 return line
171 173
172 174
173 175 def detect_mode(line, default):
174 176 """
175 177 Detects line break for given line, if line break couldn't be found
176 178 given default value is returned
177 179
178 180 :param line: str line
179 181 :param default: default
180 182 :rtype: int
181 183 :return: value of line end on of 0 - Unix, 1 - Mac, 2 - DOS
182 184 """
183 185 if line.endswith('\r\n'):
184 186 return 2
185 187 elif line.endswith('\n'):
186 188 return 0
187 189 elif line.endswith('\r'):
188 190 return 1
189 191 else:
190 192 return default
191 193
192 194
193 195 def safe_int(val, default=None):
194 196 """
195 197 Returns int() of val if val is not convertable to int use default
196 198 instead
197 199
198 200 :param val:
199 201 :param default:
200 202 """
201 203
202 204 try:
203 205 val = int(val)
204 206 except (ValueError, TypeError):
205 207 val = default
206 208
207 209 return val
208 210
209 211
210 212 def safe_unicode(str_, from_encoding=None, use_chardet=False):
211 213 """
212 214 safe unicode function. Does few trick to turn str_ into unicode
213 215
214 216 In case of UnicodeDecode error, we try to return it with encoding detected
215 217 by chardet library if it fails fallback to unicode with errors replaced
216 218
217 219 :param str_: string to decode
218 220 :rtype: unicode
219 221 :returns: unicode object
220 222 """
221 223 if isinstance(str_, unicode):
222 224 return str_
223 225
224 226 if not from_encoding:
225 227 DEFAULT_ENCODINGS = aslist(rhodecode.CONFIG.get('default_encoding',
226 228 'utf8'), sep=',')
227 229 from_encoding = DEFAULT_ENCODINGS
228 230
229 231 if not isinstance(from_encoding, (list, tuple)):
230 232 from_encoding = [from_encoding]
231 233
232 234 try:
233 235 return unicode(str_)
234 236 except UnicodeDecodeError:
235 237 pass
236 238
237 239 for enc in from_encoding:
238 240 try:
239 241 return unicode(str_, enc)
240 242 except UnicodeDecodeError:
241 243 pass
242 244
243 245 if use_chardet:
244 246 try:
245 247 import chardet
246 248 encoding = chardet.detect(str_)['encoding']
247 249 if encoding is None:
248 250 raise Exception()
249 251 return str_.decode(encoding)
250 252 except (ImportError, UnicodeDecodeError, Exception):
251 253 return unicode(str_, from_encoding[0], 'replace')
252 254 else:
253 255 return unicode(str_, from_encoding[0], 'replace')
254 256
255 257 def safe_str(unicode_, to_encoding=None, use_chardet=False):
256 258 """
257 259 safe str function. Does few trick to turn unicode_ into string
258 260
259 261 In case of UnicodeEncodeError, we try to return it with encoding detected
260 262 by chardet library if it fails fallback to string with errors replaced
261 263
262 264 :param unicode_: unicode to encode
263 265 :rtype: str
264 266 :returns: str object
265 267 """
266 268
267 269 # if it's not basestr cast to str
268 270 if not isinstance(unicode_, compat.string_types):
269 271 return str(unicode_)
270 272
271 273 if isinstance(unicode_, str):
272 274 return unicode_
273 275
274 276 if not to_encoding:
275 277 DEFAULT_ENCODINGS = aslist(rhodecode.CONFIG.get('default_encoding',
276 278 'utf8'), sep=',')
277 279 to_encoding = DEFAULT_ENCODINGS
278 280
279 281 if not isinstance(to_encoding, (list, tuple)):
280 282 to_encoding = [to_encoding]
281 283
282 284 for enc in to_encoding:
283 285 try:
284 286 return unicode_.encode(enc)
285 287 except UnicodeEncodeError:
286 288 pass
287 289
288 290 if use_chardet:
289 291 try:
290 292 import chardet
291 293 encoding = chardet.detect(unicode_)['encoding']
292 294 if encoding is None:
293 295 raise UnicodeEncodeError()
294 296
295 297 return unicode_.encode(encoding)
296 298 except (ImportError, UnicodeEncodeError):
297 299 return unicode_.encode(to_encoding[0], 'replace')
298 300 else:
299 301 return unicode_.encode(to_encoding[0], 'replace')
300 302
301 303
302 304 def remove_suffix(s, suffix):
303 305 if s.endswith(suffix):
304 306 s = s[:-1 * len(suffix)]
305 307 return s
306 308
307 309
308 310 def remove_prefix(s, prefix):
309 311 if s.startswith(prefix):
310 312 s = s[len(prefix):]
311 313 return s
312 314
313 315
314 316 def find_calling_context(ignore_modules=None):
315 317 """
316 318 Look through the calling stack and return the frame which called
317 319 this function and is part of core module ( ie. rhodecode.* )
318 320
319 321 :param ignore_modules: list of modules to ignore eg. ['rhodecode.lib']
320 322 """
321 323
322 324 ignore_modules = ignore_modules or []
323 325
324 326 f = sys._getframe(2)
325 327 while f.f_back is not None:
326 328 name = f.f_globals.get('__name__')
327 329 if name and name.startswith(__name__.split('.')[0]):
328 330 if name not in ignore_modules:
329 331 return f
330 332 f = f.f_back
331 333 return None
332 334
333 335
334 336 def ping_connection(connection, branch):
335 337 if branch:
336 338 # "branch" refers to a sub-connection of a connection,
337 339 # we don't want to bother pinging on these.
338 340 return
339 341
340 342 # turn off "close with result". This flag is only used with
341 343 # "connectionless" execution, otherwise will be False in any case
342 344 save_should_close_with_result = connection.should_close_with_result
343 345 connection.should_close_with_result = False
344 346
345 347 try:
346 348 # run a SELECT 1. use a core select() so that
347 349 # the SELECT of a scalar value without a table is
348 350 # appropriately formatted for the backend
349 351 connection.scalar(sqlalchemy.sql.select([1]))
350 352 except sqlalchemy.exc.DBAPIError as err:
351 353 # catch SQLAlchemy's DBAPIError, which is a wrapper
352 354 # for the DBAPI's exception. It includes a .connection_invalidated
353 355 # attribute which specifies if this connection is a "disconnect"
354 356 # condition, which is based on inspection of the original exception
355 357 # by the dialect in use.
356 358 if err.connection_invalidated:
357 359 # run the same SELECT again - the connection will re-validate
358 360 # itself and establish a new connection. The disconnect detection
359 361 # here also causes the whole connection pool to be invalidated
360 362 # so that all stale connections are discarded.
361 363 connection.scalar(sqlalchemy.sql.select([1]))
362 364 else:
363 365 raise
364 366 finally:
365 367 # restore "close with result"
366 368 connection.should_close_with_result = save_should_close_with_result
367 369
368 370
369 371 def engine_from_config(configuration, prefix='sqlalchemy.', **kwargs):
370 372 """Custom engine_from_config functions."""
371 373 log = logging.getLogger('sqlalchemy.engine')
372 374 use_ping_connection = asbool(configuration.pop('sqlalchemy.db1.ping_connection', None))
373 375 debug = asbool(configuration.pop('sqlalchemy.db1.debug_query', None))
374 376
375 377 engine = sqlalchemy.engine_from_config(configuration, prefix, **kwargs)
376 378
377 379 def color_sql(sql):
378 380 color_seq = '\033[1;33m' # This is yellow: code 33
379 381 normal = '\x1b[0m'
380 382 return ''.join([color_seq, sql, normal])
381 383
382 384 if use_ping_connection:
383 385 log.debug('Adding ping_connection on the engine config.')
384 386 sqlalchemy.event.listen(engine, "engine_connect", ping_connection)
385 387
386 388 if debug:
387 389 # attach events only for debug configuration
388 390 def before_cursor_execute(conn, cursor, statement,
389 391 parameters, context, executemany):
390 392 setattr(conn, 'query_start_time', time.time())
391 393 log.info(color_sql(">>>>> STARTING QUERY >>>>>"))
392 394 calling_context = find_calling_context(ignore_modules=[
393 395 'rhodecode.lib.caching_query',
394 396 'rhodecode.model.settings',
395 397 ])
396 398 if calling_context:
397 399 log.info(color_sql('call context %s:%s' % (
398 400 calling_context.f_code.co_filename,
399 401 calling_context.f_lineno,
400 402 )))
401 403
402 404 def after_cursor_execute(conn, cursor, statement,
403 405 parameters, context, executemany):
404 406 delattr(conn, 'query_start_time')
405 407
406 408 sqlalchemy.event.listen(engine, "before_cursor_execute", before_cursor_execute)
407 409 sqlalchemy.event.listen(engine, "after_cursor_execute", after_cursor_execute)
408 410
409 411 return engine
410 412
411 413
412 414 def get_encryption_key(config):
413 415 secret = config.get('rhodecode.encrypted_values.secret')
414 416 default = config['beaker.session.secret']
415 417 return secret or default
416 418
417 419
418 420 def age(prevdate, now=None, show_short_version=False, show_suffix=True,
419 421 short_format=False):
420 422 """
421 423 Turns a datetime into an age string.
422 424 If show_short_version is True, this generates a shorter string with
423 425 an approximate age; ex. '1 day ago', rather than '1 day and 23 hours ago'.
424 426
425 427 * IMPORTANT*
426 428 Code of this function is written in special way so it's easier to
427 429 backport it to javascript. If you mean to update it, please also update
428 430 `jquery.timeago-extension.js` file
429 431
430 432 :param prevdate: datetime object
431 433 :param now: get current time, if not define we use
432 434 `datetime.datetime.now()`
433 435 :param show_short_version: if it should approximate the date and
434 436 return a shorter string
435 437 :param show_suffix:
436 438 :param short_format: show short format, eg 2D instead of 2 days
437 439 :rtype: unicode
438 440 :returns: unicode words describing age
439 441 """
440 442
441 443 def _get_relative_delta(now, prevdate):
442 444 base = dateutil.relativedelta.relativedelta(now, prevdate)
443 445 return {
444 446 'year': base.years,
445 447 'month': base.months,
446 448 'day': base.days,
447 449 'hour': base.hours,
448 450 'minute': base.minutes,
449 451 'second': base.seconds,
450 452 }
451 453
452 454 def _is_leap_year(year):
453 455 return year % 4 == 0 and (year % 100 != 0 or year % 400 == 0)
454 456
455 457 def get_month(prevdate):
456 458 return prevdate.month
457 459
458 460 def get_year(prevdate):
459 461 return prevdate.year
460 462
461 463 now = now or datetime.datetime.now()
462 464 order = ['year', 'month', 'day', 'hour', 'minute', 'second']
463 465 deltas = {}
464 466 future = False
465 467
466 468 if prevdate > now:
467 469 now_old = now
468 470 now = prevdate
469 471 prevdate = now_old
470 472 future = True
471 473 if future:
472 474 prevdate = prevdate.replace(microsecond=0)
473 475 # Get date parts deltas
474 476 for part in order:
475 477 rel_delta = _get_relative_delta(now, prevdate)
476 478 deltas[part] = rel_delta[part]
477 479
478 480 # Fix negative offsets (there is 1 second between 10:59:59 and 11:00:00,
479 481 # not 1 hour, -59 minutes and -59 seconds)
480 482 offsets = [[5, 60], [4, 60], [3, 24]]
481 483 for element in offsets: # seconds, minutes, hours
482 484 num = element[0]
483 485 length = element[1]
484 486
485 487 part = order[num]
486 488 carry_part = order[num - 1]
487 489
488 490 if deltas[part] < 0:
489 491 deltas[part] += length
490 492 deltas[carry_part] -= 1
491 493
492 494 # Same thing for days except that the increment depends on the (variable)
493 495 # number of days in the month
494 496 month_lengths = [31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31]
495 497 if deltas['day'] < 0:
496 498 if get_month(prevdate) == 2 and _is_leap_year(get_year(prevdate)):
497 499 deltas['day'] += 29
498 500 else:
499 501 deltas['day'] += month_lengths[get_month(prevdate) - 1]
500 502
501 503 deltas['month'] -= 1
502 504
503 505 if deltas['month'] < 0:
504 506 deltas['month'] += 12
505 507 deltas['year'] -= 1
506 508
507 509 # Format the result
508 510 if short_format:
509 511 fmt_funcs = {
510 512 'year': lambda d: u'%dy' % d,
511 513 'month': lambda d: u'%dm' % d,
512 514 'day': lambda d: u'%dd' % d,
513 515 'hour': lambda d: u'%dh' % d,
514 516 'minute': lambda d: u'%dmin' % d,
515 517 'second': lambda d: u'%dsec' % d,
516 518 }
517 519 else:
518 520 fmt_funcs = {
519 521 'year': lambda d: _pluralize(u'${num} year', u'${num} years', d, mapping={'num': d}).interpolate(),
520 522 'month': lambda d: _pluralize(u'${num} month', u'${num} months', d, mapping={'num': d}).interpolate(),
521 523 'day': lambda d: _pluralize(u'${num} day', u'${num} days', d, mapping={'num': d}).interpolate(),
522 524 'hour': lambda d: _pluralize(u'${num} hour', u'${num} hours', d, mapping={'num': d}).interpolate(),
523 525 'minute': lambda d: _pluralize(u'${num} minute', u'${num} minutes', d, mapping={'num': d}).interpolate(),
524 526 'second': lambda d: _pluralize(u'${num} second', u'${num} seconds', d, mapping={'num': d}).interpolate(),
525 527 }
526 528
527 529 i = 0
528 530 for part in order:
529 531 value = deltas[part]
530 532 if value != 0:
531 533
532 534 if i < 5:
533 535 sub_part = order[i + 1]
534 536 sub_value = deltas[sub_part]
535 537 else:
536 538 sub_value = 0
537 539
538 540 if sub_value == 0 or show_short_version:
539 541 _val = fmt_funcs[part](value)
540 542 if future:
541 543 if show_suffix:
542 544 return _(u'in ${ago}', mapping={'ago': _val})
543 545 else:
544 546 return _(_val)
545 547
546 548 else:
547 549 if show_suffix:
548 550 return _(u'${ago} ago', mapping={'ago': _val})
549 551 else:
550 552 return _(_val)
551 553
552 554 val = fmt_funcs[part](value)
553 555 val_detail = fmt_funcs[sub_part](sub_value)
554 556 mapping = {'val': val, 'detail': val_detail}
555 557
556 558 if short_format:
557 559 datetime_tmpl = _(u'${val}, ${detail}', mapping=mapping)
558 560 if show_suffix:
559 561 datetime_tmpl = _(u'${val}, ${detail} ago', mapping=mapping)
560 562 if future:
561 563 datetime_tmpl = _(u'in ${val}, ${detail}', mapping=mapping)
562 564 else:
563 565 datetime_tmpl = _(u'${val} and ${detail}', mapping=mapping)
564 566 if show_suffix:
565 567 datetime_tmpl = _(u'${val} and ${detail} ago', mapping=mapping)
566 568 if future:
567 569 datetime_tmpl = _(u'in ${val} and ${detail}', mapping=mapping)
568 570
569 571 return datetime_tmpl
570 572 i += 1
571 573 return _(u'just now')
572 574
573 575
574 576 def age_from_seconds(seconds):
575 577 seconds = safe_int(seconds) or 0
576 578 prevdate = time_to_datetime(time.time() + seconds)
577 579 return age(prevdate, show_suffix=False, show_short_version=True)
578 580
579 581
580 582 def cleaned_uri(uri):
581 583 """
582 584 Quotes '[' and ']' from uri if there is only one of them.
583 585 according to RFC3986 we cannot use such chars in uri
584 586 :param uri:
585 587 :return: uri without this chars
586 588 """
587 589 return urllib.quote(uri, safe='@$:/')
588 590
589 591
590 592 def credentials_filter(uri):
591 593 """
592 594 Returns a url with removed credentials
593 595
594 596 :param uri:
595 597 """
596 598 import urlobject
597 599 if isinstance(uri, rhodecode.lib.encrypt.InvalidDecryptedValue):
598 600 return 'InvalidDecryptionKey'
599 601
600 602 url_obj = urlobject.URLObject(cleaned_uri(uri))
601 603 url_obj = url_obj.without_password().without_username()
602 604
603 605 return url_obj
604 606
605 607
606 608 def get_host_info(request):
607 609 """
608 610 Generate host info, to obtain full url e.g https://server.com
609 611 use this
610 612 `{scheme}://{netloc}`
611 613 """
612 614 if not request:
613 615 return {}
614 616
615 617 qualified_home_url = request.route_url('home')
616 618 parsed_url = urlobject.URLObject(qualified_home_url)
617 619 decoded_path = safe_unicode(urllib.unquote(parsed_url.path.rstrip('/')))
618 620
619 621 return {
620 622 'scheme': parsed_url.scheme,
621 623 'netloc': parsed_url.netloc+decoded_path,
622 624 'hostname': parsed_url.hostname,
623 625 }
624 626
625 627
626 628 def get_clone_url(request, uri_tmpl, repo_name, repo_id, repo_type, **override):
627 629 qualified_home_url = request.route_url('home')
628 630 parsed_url = urlobject.URLObject(qualified_home_url)
629 631 decoded_path = safe_unicode(urllib.unquote(parsed_url.path.rstrip('/')))
630 632
631 633 args = {
632 634 'scheme': parsed_url.scheme,
633 635 'user': '',
634 636 'sys_user': getpass.getuser(),
635 637 # path if we use proxy-prefix
636 638 'netloc': parsed_url.netloc+decoded_path,
637 639 'hostname': parsed_url.hostname,
638 640 'prefix': decoded_path,
639 641 'repo': repo_name,
640 642 'repoid': str(repo_id),
641 643 'repo_type': repo_type
642 644 }
643 645 args.update(override)
644 646 args['user'] = urllib.quote(safe_str(args['user']))
645 647
646 648 for k, v in args.items():
647 649 uri_tmpl = uri_tmpl.replace('{%s}' % k, v)
648 650
649 651 # special case for SVN clone url
650 652 if repo_type == 'svn':
651 653 uri_tmpl = uri_tmpl.replace('ssh://', 'svn+ssh://')
652 654
653 655 # remove leading @ sign if it's present. Case of empty user
654 656 url_obj = urlobject.URLObject(uri_tmpl)
655 657 url = url_obj.with_netloc(url_obj.netloc.lstrip('@'))
656 658
657 659 return safe_unicode(url)
658 660
659 661
660 662 def get_commit_safe(repo, commit_id=None, commit_idx=None, pre_load=None,
661 663 maybe_unreachable=False, reference_obj=None):
662 664 """
663 665 Safe version of get_commit if this commit doesn't exists for a
664 666 repository it returns a Dummy one instead
665 667
666 668 :param repo: repository instance
667 669 :param commit_id: commit id as str
668 670 :param commit_idx: numeric commit index
669 671 :param pre_load: optional list of commit attributes to load
670 672 :param maybe_unreachable: translate unreachable commits on git repos
671 673 :param reference_obj: explicitly search via a reference obj in git. E.g "branch:123" would mean branch "123"
672 674 """
673 675 # TODO(skreft): remove these circular imports
674 676 from rhodecode.lib.vcs.backends.base import BaseRepository, EmptyCommit
675 677 from rhodecode.lib.vcs.exceptions import RepositoryError
676 678 if not isinstance(repo, BaseRepository):
677 679 raise Exception('You must pass an Repository '
678 680 'object as first argument got %s', type(repo))
679 681
680 682 try:
681 683 commit = repo.get_commit(
682 684 commit_id=commit_id, commit_idx=commit_idx, pre_load=pre_load,
683 685 maybe_unreachable=maybe_unreachable, reference_obj=reference_obj)
684 686 except (RepositoryError, LookupError):
685 687 commit = EmptyCommit()
686 688 return commit
687 689
688 690
689 691 def datetime_to_time(dt):
690 692 if dt:
691 693 return time.mktime(dt.timetuple())
692 694
693 695
694 696 def time_to_datetime(tm):
695 697 if tm:
696 698 if isinstance(tm, compat.string_types):
697 699 try:
698 700 tm = float(tm)
699 701 except ValueError:
700 702 return
701 703 return datetime.datetime.fromtimestamp(tm)
702 704
703 705
704 706 def time_to_utcdatetime(tm):
705 707 if tm:
706 708 if isinstance(tm, compat.string_types):
707 709 try:
708 710 tm = float(tm)
709 711 except ValueError:
710 712 return
711 713 return datetime.datetime.utcfromtimestamp(tm)
712 714
713 715
714 716 MENTIONS_REGEX = re.compile(
715 717 # ^@ or @ without any special chars in front
716 718 r'(?:^@|[^a-zA-Z0-9\-\_\.]@)'
717 719 # main body starts with letter, then can be . - _
718 720 r'([a-zA-Z0-9]{1}[a-zA-Z0-9\-\_\.]+)',
719 721 re.VERBOSE | re.MULTILINE)
720 722
721 723
722 724 def extract_mentioned_users(s):
723 725 """
724 726 Returns unique usernames from given string s that have @mention
725 727
726 728 :param s: string to get mentions
727 729 """
728 730 usrs = set()
729 731 for username in MENTIONS_REGEX.findall(s):
730 732 usrs.add(username)
731 733
732 734 return sorted(list(usrs), key=lambda k: k.lower())
733 735
734 736
735 737 class AttributeDictBase(dict):
736 738 def __getstate__(self):
737 739 odict = self.__dict__ # get attribute dictionary
738 740 return odict
739 741
740 742 def __setstate__(self, dict):
741 743 self.__dict__ = dict
742 744
743 745 __setattr__ = dict.__setitem__
744 746 __delattr__ = dict.__delitem__
745 747
746 748
747 749 class StrictAttributeDict(AttributeDictBase):
748 750 """
749 751 Strict Version of Attribute dict which raises an Attribute error when
750 752 requested attribute is not set
751 753 """
752 754 def __getattr__(self, attr):
753 755 try:
754 756 return self[attr]
755 757 except KeyError:
756 758 raise AttributeError('%s object has no attribute %s' % (
757 759 self.__class__, attr))
758 760
759 761
760 762 class AttributeDict(AttributeDictBase):
761 763 def __getattr__(self, attr):
762 764 return self.get(attr, None)
763 765
764 766
765 767
766 768 class OrderedDefaultDict(collections.OrderedDict, collections.defaultdict):
767 769 def __init__(self, default_factory=None, *args, **kwargs):
768 770 # in python3 you can omit the args to super
769 771 super(OrderedDefaultDict, self).__init__(*args, **kwargs)
770 772 self.default_factory = default_factory
771 773
772 774
773 775 def fix_PATH(os_=None):
774 776 """
775 777 Get current active python path, and append it to PATH variable to fix
776 778 issues of subprocess calls and different python versions
777 779 """
778 780 if os_ is None:
779 781 import os
780 782 else:
781 783 os = os_
782 784
783 785 cur_path = os.path.split(sys.executable)[0]
784 786 if not os.environ['PATH'].startswith(cur_path):
785 787 os.environ['PATH'] = '%s:%s' % (cur_path, os.environ['PATH'])
786 788
787 789
788 790 def obfuscate_url_pw(engine):
789 791 _url = engine or ''
790 792 try:
791 793 _url = sqlalchemy.engine.url.make_url(engine)
792 794 if _url.password:
793 795 _url.password = 'XXXXX'
794 796 except Exception:
795 797 pass
796 798 return unicode(_url)
797 799
798 800
799 801 def get_server_url(environ):
800 802 req = webob.Request(environ)
801 803 return req.host_url + req.script_name
802 804
803 805
804 806 def unique_id(hexlen=32):
805 807 alphabet = "23456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghjklmnpqrstuvwxyz"
806 808 return suuid(truncate_to=hexlen, alphabet=alphabet)
807 809
808 810
809 811 def suuid(url=None, truncate_to=22, alphabet=None):
810 812 """
811 813 Generate and return a short URL safe UUID.
812 814
813 815 If the url parameter is provided, set the namespace to the provided
814 816 URL and generate a UUID.
815 817
816 818 :param url to get the uuid for
817 819 :truncate_to: truncate the basic 22 UUID to shorter version
818 820
819 821 The IDs won't be universally unique any longer, but the probability of
820 822 a collision will still be very low.
821 823 """
822 824 # Define our alphabet.
823 825 _ALPHABET = alphabet or "23456789ABCDEFGHJKLMNPQRSTUVWXYZ"
824 826
825 827 # If no URL is given, generate a random UUID.
826 828 if url is None:
827 829 unique_id = uuid.uuid4().int
828 830 else:
829 831 unique_id = uuid.uuid3(uuid.NAMESPACE_URL, url).int
830 832
831 833 alphabet_length = len(_ALPHABET)
832 834 output = []
833 835 while unique_id > 0:
834 836 digit = unique_id % alphabet_length
835 837 output.append(_ALPHABET[digit])
836 838 unique_id = int(unique_id / alphabet_length)
837 839 return "".join(output)[:truncate_to]
838 840
839 841
840 842 def get_current_rhodecode_user(request=None):
841 843 """
842 844 Gets rhodecode user from request
843 845 """
844 846 pyramid_request = request or pyramid.threadlocal.get_current_request()
845 847
846 848 # web case
847 849 if pyramid_request and hasattr(pyramid_request, 'user'):
848 850 return pyramid_request.user
849 851
850 852 # api case
851 853 if pyramid_request and hasattr(pyramid_request, 'rpc_user'):
852 854 return pyramid_request.rpc_user
853 855
854 856 return None
855 857
856 858
857 859 def action_logger_generic(action, namespace=''):
858 860 """
859 861 A generic logger for actions useful to the system overview, tries to find
860 862 an acting user for the context of the call otherwise reports unknown user
861 863
862 864 :param action: logging message eg 'comment 5 deleted'
863 865 :param type: string
864 866
865 867 :param namespace: namespace of the logging message eg. 'repo.comments'
866 868 :param type: string
867 869
868 870 """
869 871
870 872 logger_name = 'rhodecode.actions'
871 873
872 874 if namespace:
873 875 logger_name += '.' + namespace
874 876
875 877 log = logging.getLogger(logger_name)
876 878
877 879 # get a user if we can
878 880 user = get_current_rhodecode_user()
879 881
880 882 logfunc = log.info
881 883
882 884 if not user:
883 885 user = '<unknown user>'
884 886 logfunc = log.warning
885 887
886 888 logfunc('Logging action by {}: {}'.format(user, action))
887 889
888 890
889 891 def escape_split(text, sep=',', maxsplit=-1):
890 892 r"""
891 893 Allows for escaping of the separator: e.g. arg='foo\, bar'
892 894
893 895 It should be noted that the way bash et. al. do command line parsing, those
894 896 single quotes are required.
895 897 """
896 898 escaped_sep = r'\%s' % sep
897 899
898 900 if escaped_sep not in text:
899 901 return text.split(sep, maxsplit)
900 902
901 903 before, _mid, after = text.partition(escaped_sep)
902 904 startlist = before.split(sep, maxsplit) # a regular split is fine here
903 905 unfinished = startlist[-1]
904 906 startlist = startlist[:-1]
905 907
906 908 # recurse because there may be more escaped separators
907 909 endlist = escape_split(after, sep, maxsplit)
908 910
909 911 # finish building the escaped value. we use endlist[0] becaue the first
910 912 # part of the string sent in recursion is the rest of the escaped value.
911 913 unfinished += sep + endlist[0]
912 914
913 915 return startlist + [unfinished] + endlist[1:] # put together all the parts
914 916
915 917
916 918 class OptionalAttr(object):
917 919 """
918 920 Special Optional Option that defines other attribute. Example::
919 921
920 922 def test(apiuser, userid=Optional(OAttr('apiuser')):
921 923 user = Optional.extract(userid)
922 924 # calls
923 925
924 926 """
925 927
926 928 def __init__(self, attr_name):
927 929 self.attr_name = attr_name
928 930
929 931 def __repr__(self):
930 932 return '<OptionalAttr:%s>' % self.attr_name
931 933
932 934 def __call__(self):
933 935 return self
934 936
935 937
936 938 # alias
937 939 OAttr = OptionalAttr
938 940
939 941
940 942 class Optional(object):
941 943 """
942 944 Defines an optional parameter::
943 945
944 946 param = param.getval() if isinstance(param, Optional) else param
945 947 param = param() if isinstance(param, Optional) else param
946 948
947 949 is equivalent of::
948 950
949 951 param = Optional.extract(param)
950 952
951 953 """
952 954
953 955 def __init__(self, type_):
954 956 self.type_ = type_
955 957
956 958 def __repr__(self):
957 959 return '<Optional:%s>' % self.type_.__repr__()
958 960
959 961 def __call__(self):
960 962 return self.getval()
961 963
962 964 def getval(self):
963 965 """
964 966 returns value from this Optional instance
965 967 """
966 968 if isinstance(self.type_, OAttr):
967 969 # use params name
968 970 return self.type_.attr_name
969 971 return self.type_
970 972
971 973 @classmethod
972 974 def extract(cls, val):
973 975 """
974 976 Extracts value from Optional() instance
975 977
976 978 :param val:
977 979 :return: original value if it's not Optional instance else
978 980 value of instance
979 981 """
980 982 if isinstance(val, cls):
981 983 return val.getval()
982 984 return val
983 985
984 986
985 987 def glob2re(pat):
986 988 """
987 989 Translate a shell PATTERN to a regular expression.
988 990
989 991 There is no way to quote meta-characters.
990 992 """
991 993
992 994 i, n = 0, len(pat)
993 995 res = ''
994 996 while i < n:
995 997 c = pat[i]
996 998 i = i+1
997 999 if c == '*':
998 1000 #res = res + '.*'
999 1001 res = res + '[^/]*'
1000 1002 elif c == '?':
1001 1003 #res = res + '.'
1002 1004 res = res + '[^/]'
1003 1005 elif c == '[':
1004 1006 j = i
1005 1007 if j < n and pat[j] == '!':
1006 1008 j = j+1
1007 1009 if j < n and pat[j] == ']':
1008 1010 j = j+1
1009 1011 while j < n and pat[j] != ']':
1010 1012 j = j+1
1011 1013 if j >= n:
1012 1014 res = res + '\\['
1013 1015 else:
1014 1016 stuff = pat[i:j].replace('\\','\\\\')
1015 1017 i = j+1
1016 1018 if stuff[0] == '!':
1017 1019 stuff = '^' + stuff[1:]
1018 1020 elif stuff[0] == '^':
1019 1021 stuff = '\\' + stuff
1020 1022 res = '%s[%s]' % (res, stuff)
1021 1023 else:
1022 1024 res = res + re.escape(c)
1023 1025 return res + '\Z(?ms)'
1024 1026
1025 1027
1026 1028 def parse_byte_string(size_str):
1027 1029 match = re.match(r'(\d+)(MB|KB)', size_str, re.IGNORECASE)
1028 1030 if not match:
1029 1031 raise ValueError('Given size:%s is invalid, please make sure '
1030 1032 'to use format of <num>(MB|KB)' % size_str)
1031 1033
1032 1034 _parts = match.groups()
1033 1035 num, type_ = _parts
1034 1036 return long(num) * {'mb': 1024*1024, 'kb': 1024}[type_.lower()]
1035 1037
1036 1038
1037 1039 class CachedProperty(object):
1038 1040 """
1039 1041 Lazy Attributes. With option to invalidate the cache by running a method
1040 1042
1041 1043 >>> class Foo(object):
1042 1044 ...
1043 1045 ... @CachedProperty
1044 1046 ... def heavy_func(self):
1045 1047 ... return 'super-calculation'
1046 1048 ...
1047 1049 ... foo = Foo()
1048 1050 ... foo.heavy_func() # first computation
1049 1051 ... foo.heavy_func() # fetch from cache
1050 1052 ... foo._invalidate_prop_cache('heavy_func')
1051 1053
1052 1054 # at this point calling foo.heavy_func() will be re-computed
1053 1055 """
1054 1056
1055 1057 def __init__(self, func, func_name=None):
1056 1058
1057 1059 if func_name is None:
1058 1060 func_name = func.__name__
1059 1061 self.data = (func, func_name)
1060 1062 update_wrapper(self, func)
1061 1063
1062 1064 def __get__(self, inst, class_):
1063 1065 if inst is None:
1064 1066 return self
1065 1067
1066 1068 func, func_name = self.data
1067 1069 value = func(inst)
1068 1070 inst.__dict__[func_name] = value
1069 1071 if '_invalidate_prop_cache' not in inst.__dict__:
1070 1072 inst.__dict__['_invalidate_prop_cache'] = partial(
1071 1073 self._invalidate_prop_cache, inst)
1072 1074 return value
1073 1075
1074 1076 def _invalidate_prop_cache(self, inst, name):
1075 1077 inst.__dict__.pop(name, None)
1076 1078
1077 1079
1078 1080 def retry(func=None, exception=Exception, n_tries=5, delay=5, backoff=1, logger=True):
1079 1081 """
1080 1082 Retry decorator with exponential backoff.
1081 1083
1082 1084 Parameters
1083 1085 ----------
1084 1086 func : typing.Callable, optional
1085 1087 Callable on which the decorator is applied, by default None
1086 1088 exception : Exception or tuple of Exceptions, optional
1087 1089 Exception(s) that invoke retry, by default Exception
1088 1090 n_tries : int, optional
1089 1091 Number of tries before giving up, by default 5
1090 1092 delay : int, optional
1091 1093 Initial delay between retries in seconds, by default 5
1092 1094 backoff : int, optional
1093 1095 Backoff multiplier e.g. value of 2 will double the delay, by default 1
1094 1096 logger : bool, optional
1095 1097 Option to log or print, by default False
1096 1098
1097 1099 Returns
1098 1100 -------
1099 1101 typing.Callable
1100 1102 Decorated callable that calls itself when exception(s) occur.
1101 1103
1102 1104 Examples
1103 1105 --------
1104 1106 >>> import random
1105 1107 >>> @retry(exception=Exception, n_tries=3)
1106 1108 ... def test_random(text):
1107 1109 ... x = random.random()
1108 1110 ... if x < 0.5:
1109 1111 ... raise Exception("Fail")
1110 1112 ... else:
1111 1113 ... print("Success: ", text)
1112 1114 >>> test_random("It works!")
1113 1115 """
1114 1116
1115 1117 if func is None:
1116 1118 return partial(
1117 1119 retry,
1118 1120 exception=exception,
1119 1121 n_tries=n_tries,
1120 1122 delay=delay,
1121 1123 backoff=backoff,
1122 1124 logger=logger,
1123 1125 )
1124 1126
1125 1127 @wraps(func)
1126 1128 def wrapper(*args, **kwargs):
1127 1129 _n_tries, n_delay = n_tries, delay
1128 1130 log = logging.getLogger('rhodecode.retry')
1129 1131
1130 1132 while _n_tries > 1:
1131 1133 try:
1132 1134 return func(*args, **kwargs)
1133 1135 except exception as e:
1134 1136 e_details = repr(e)
1135 1137 msg = "Exception on calling func {func}: {e}, " \
1136 1138 "Retrying in {n_delay} seconds..."\
1137 1139 .format(func=func, e=e_details, n_delay=n_delay)
1138 1140 if logger:
1139 1141 log.warning(msg)
1140 1142 else:
1141 1143 print(msg)
1142 1144 time.sleep(n_delay)
1143 1145 _n_tries -= 1
1144 1146 n_delay *= backoff
1145 1147
1146 1148 return func(*args, **kwargs)
1147 1149
1148 1150 return wrapper
1149 1151
1150 1152
1151 1153 def user_agent_normalizer(user_agent_raw, safe=True):
1152 1154 log = logging.getLogger('rhodecode.user_agent_normalizer')
1153 1155 ua = (user_agent_raw or '').strip().lower()
1154 1156 ua = ua.replace('"', '')
1155 1157
1156 1158 try:
1157 1159 if 'mercurial/proto-1.0' in ua:
1158 1160 ua = ua.replace('mercurial/proto-1.0', '')
1159 1161 ua = ua.replace('(', '').replace(')', '').strip()
1160 1162 ua = ua.replace('mercurial ', 'mercurial/')
1161 1163 elif ua.startswith('git'):
1162 1164 parts = ua.split(' ')
1163 1165 if parts:
1164 1166 ua = parts[0]
1165 1167 ua = re.sub('\.windows\.\d', '', ua).strip()
1166 1168
1167 1169 return ua
1168 1170 except Exception:
1169 1171 log.exception('Failed to parse scm user-agent')
1170 1172 if not safe:
1171 1173 raise
1172 1174
1173 1175 return ua
1176
1177
1178 def get_available_port(min_port=40000, max_port=55555):
1179 sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
1180 hostname = '127.0.0.1'
1181 pick_port = min_port
1182
1183 for _ in range(min_port, max_port):
1184 pick_port = random.randint(min_port, max_port)
1185 try:
1186 sock.bind((hostname, pick_port))
1187 sock.close()
1188 break
1189 except OSError:
1190 pass
1191
1192 del sock
1193 return pick_port
@@ -1,295 +1,285 b''
1 1 # -*- coding: utf-8 -*-
2 2
3 3 # Copyright (C) 2010-2020 RhodeCode GmbH
4 4 #
5 5 # This program is free software: you can redistribute it and/or modify
6 6 # it under the terms of the GNU Affero General Public License, version 3
7 7 # (only), as published by the Free Software Foundation.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU Affero General Public License
15 15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 16 #
17 17 # This program is dual-licensed. If you wish to learn more about the
18 18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20 20
21 21 import json
22 22 import platform
23 23 import socket
24 24 import random
25 25 import pytest
26 26
27 27 from rhodecode.lib.pyramid_utils import get_app_config
28 28 from rhodecode.tests.fixture import TestINI
29 29 from rhodecode.tests.server_utils import RcVCSServer
30 30
31 31
32 32 def _parse_json(value):
33 33 return json.loads(value) if value else None
34 34
35 35
36 36 def pytest_addoption(parser):
37 37 parser.addoption(
38 38 '--test-loglevel', dest='test_loglevel',
39 39 help="Set default Logging level for tests, critical(default), error, warn , info, debug")
40 40 group = parser.getgroup('pylons')
41 41 group.addoption(
42 42 '--with-pylons', dest='pyramid_config',
43 43 help="Set up a Pylons environment with the specified config file.")
44 44 group.addoption(
45 45 '--ini-config-override', action='store', type=_parse_json,
46 46 default=None, dest='pyramid_config_override', help=(
47 47 "Overrides the .ini file settings. Should be specified in JSON"
48 48 " format, e.g. '{\"section\": {\"parameter\": \"value\", ...}}'"
49 49 )
50 50 )
51 51 parser.addini(
52 52 'pyramid_config',
53 53 "Set up a Pyramid environment with the specified config file.")
54 54
55 55 vcsgroup = parser.getgroup('vcs')
56 56 vcsgroup.addoption(
57 57 '--without-vcsserver', dest='with_vcsserver', action='store_false',
58 58 help="Do not start the VCSServer in a background process.")
59 59 vcsgroup.addoption(
60 60 '--with-vcsserver-http', dest='vcsserver_config_http',
61 61 help="Start the HTTP VCSServer with the specified config file.")
62 62 vcsgroup.addoption(
63 63 '--vcsserver-protocol', dest='vcsserver_protocol',
64 64 help="Start the VCSServer with HTTP protocol support.")
65 65 vcsgroup.addoption(
66 66 '--vcsserver-config-override', action='store', type=_parse_json,
67 67 default=None, dest='vcsserver_config_override', help=(
68 68 "Overrides the .ini file settings for the VCSServer. "
69 69 "Should be specified in JSON "
70 70 "format, e.g. '{\"section\": {\"parameter\": \"value\", ...}}'"
71 71 )
72 72 )
73 73 vcsgroup.addoption(
74 74 '--vcsserver-port', action='store', type=int,
75 75 default=None, help=(
76 76 "Allows to set the port of the vcsserver. Useful when testing "
77 77 "against an already running server and random ports cause "
78 78 "trouble."))
79 79 parser.addini(
80 80 'vcsserver_config_http',
81 81 "Start the HTTP VCSServer with the specified config file.")
82 82 parser.addini(
83 83 'vcsserver_protocol',
84 84 "Start the VCSServer with HTTP protocol support.")
85 85
86 86
87 87 @pytest.fixture(scope='session')
88 88 def vcsserver(request, vcsserver_port, vcsserver_factory):
89 89 """
90 90 Session scope VCSServer.
91 91
92 92 Tests wich need the VCSServer have to rely on this fixture in order
93 93 to ensure it will be running.
94 94
95 95 For specific needs, the fixture vcsserver_factory can be used. It allows to
96 96 adjust the configuration file for the test run.
97 97
98 98 Command line args:
99 99
100 100 --without-vcsserver: Allows to switch this fixture off. You have to
101 101 manually start the server.
102 102
103 103 --vcsserver-port: Will expect the VCSServer to listen on this port.
104 104 """
105 105
106 106 if not request.config.getoption('with_vcsserver'):
107 107 return None
108 108
109 109 return vcsserver_factory(
110 110 request, vcsserver_port=vcsserver_port)
111 111
112 112
113 113 @pytest.fixture(scope='session')
114 114 def vcsserver_factory(tmpdir_factory):
115 115 """
116 116 Use this if you need a running vcsserver with a special configuration.
117 117 """
118 118
119 119 def factory(request, overrides=(), vcsserver_port=None,
120 120 log_file=None):
121 121
122 122 if vcsserver_port is None:
123 123 vcsserver_port = get_available_port()
124 124
125 125 overrides = list(overrides)
126 126 overrides.append({'server:main': {'port': vcsserver_port}})
127 127
128 128 option_name = 'vcsserver_config_http'
129 129 override_option_name = 'vcsserver_config_override'
130 130 config_file = get_config(
131 131 request.config, option_name=option_name,
132 132 override_option_name=override_option_name, overrides=overrides,
133 133 basetemp=tmpdir_factory.getbasetemp().strpath,
134 134 prefix='test_vcs_')
135 135
136 136 server = RcVCSServer(config_file, log_file)
137 137 server.start()
138 138
139 139 @request.addfinalizer
140 140 def cleanup():
141 141 server.shutdown()
142 142
143 143 server.wait_until_ready()
144 144 return server
145 145
146 146 return factory
147 147
148 148
149 149 def is_cygwin():
150 150 return 'cygwin' in platform.system().lower()
151 151
152 152
153 153 def _use_log_level(config):
154 154 level = config.getoption('test_loglevel') or 'critical'
155 155 return level.upper()
156 156
157 157
158 158 @pytest.fixture(scope='session')
159 159 def ini_config(request, tmpdir_factory, rcserver_port, vcsserver_port):
160 160 option_name = 'pyramid_config'
161 161 log_level = _use_log_level(request.config)
162 162
163 163 overrides = [
164 164 {'server:main': {'port': rcserver_port}},
165 165 {'app:main': {
166 166 'vcs.server': 'localhost:%s' % vcsserver_port,
167 167 # johbo: We will always start the VCSServer on our own based on the
168 168 # fixtures of the test cases. For the test run it must always be
169 169 # off in the INI file.
170 170 'vcs.start_server': 'false',
171 171
172 172 'vcs.server.protocol': 'http',
173 173 'vcs.scm_app_implementation': 'http',
174 174 'vcs.hooks.protocol': 'http',
175 175 'vcs.hooks.host': '127.0.0.1',
176 176 }},
177 177
178 178 {'handler_console': {
179 179 'class ': 'StreamHandler',
180 180 'args ': '(sys.stderr,)',
181 181 'level': log_level,
182 182 }},
183 183
184 184 ]
185 185
186 186 filename = get_config(
187 187 request.config, option_name=option_name,
188 188 override_option_name='{}_override'.format(option_name),
189 189 overrides=overrides,
190 190 basetemp=tmpdir_factory.getbasetemp().strpath,
191 191 prefix='test_rce_')
192 192 return filename
193 193
194 194
195 195 @pytest.fixture(scope='session')
196 196 def ini_settings(ini_config):
197 197 ini_path = ini_config
198 198 return get_app_config(ini_path)
199 199
200 200
201 201 def get_available_port(min_port=40000, max_port=55555):
202 sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
203 hostname = '127.0.0.1'
204
205 for _ in range(min_port, max_port):
206 pick_port = random.randint(min_port, max_port)
207 try:
208 sock.bind((hostname, pick_port))
209 sock.close()
210 del sock
211 return pick_port
212 except OSError:
213 pass
202 from rhodecode.lib.utils2 import get_available_port as _get_port
203 return _get_port(min_port, max_port)
214 204
215 205
216 206 @pytest.fixture(scope='session')
217 207 def rcserver_port(request):
218 208 port = get_available_port()
219 209 print('Using rcserver port {}'.format(port))
220 210 return port
221 211
222 212
223 213 @pytest.fixture(scope='session')
224 214 def vcsserver_port(request):
225 215 port = request.config.getoption('--vcsserver-port')
226 216 if port is None:
227 217 port = get_available_port()
228 218 print('Using vcsserver port {}'.format(port))
229 219 return port
230 220
231 221
232 222 @pytest.fixture(scope='session')
233 223 def available_port_factory():
234 224 """
235 225 Returns a callable which returns free port numbers.
236 226 """
237 227 return get_available_port
238 228
239 229
240 230 @pytest.fixture()
241 231 def available_port(available_port_factory):
242 232 """
243 233 Gives you one free port for the current test.
244 234
245 235 Uses "available_port_factory" to retrieve the port.
246 236 """
247 237 return available_port_factory()
248 238
249 239
250 240 @pytest.fixture(scope='session')
251 241 def testini_factory(tmpdir_factory, ini_config):
252 242 """
253 243 Factory to create an INI file based on TestINI.
254 244
255 245 It will make sure to place the INI file in the correct directory.
256 246 """
257 247 basetemp = tmpdir_factory.getbasetemp().strpath
258 248 return TestIniFactory(basetemp, ini_config)
259 249
260 250
261 251 class TestIniFactory(object):
262 252
263 253 def __init__(self, basetemp, template_ini):
264 254 self._basetemp = basetemp
265 255 self._template_ini = template_ini
266 256
267 257 def __call__(self, ini_params, new_file_prefix='test'):
268 258 ini_file = TestINI(
269 259 self._template_ini, ini_params=ini_params,
270 260 new_file_prefix=new_file_prefix, dir=self._basetemp)
271 261 result = ini_file.create()
272 262 return result
273 263
274 264
275 265 def get_config(
276 266 config, option_name, override_option_name, overrides=None,
277 267 basetemp=None, prefix='test'):
278 268 """
279 269 Find a configuration file and apply overrides for the given `prefix`.
280 270 """
281 271 config_file = (
282 272 config.getoption(option_name) or config.getini(option_name))
283 273 if not config_file:
284 274 pytest.exit(
285 275 "Configuration error, could not extract {}.".format(option_name))
286 276
287 277 overrides = overrides or []
288 278 config_override = config.getoption(override_option_name)
289 279 if config_override:
290 280 overrides.append(config_override)
291 281 temp_ini_file = TestINI(
292 282 config_file, ini_params=overrides, new_file_prefix=prefix,
293 283 dir=basetemp)
294 284
295 285 return temp_ini_file.create()
General Comments 0
You need to be logged in to leave comments. Login now