pick_port: unified code for testing/hooks
super-admin | r4866:6b029be8 default
@@ -1,357 +1,347 @@
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2
2
3 # Copyright (C) 2010-2020 RhodeCode GmbH
3 # Copyright (C) 2010-2020 RhodeCode GmbH
4 #
4 #
5 # This program is free software: you can redistribute it and/or modify
5 # This program is free software: you can redistribute it and/or modify
6 # it under the terms of the GNU Affero General Public License, version 3
6 # it under the terms of the GNU Affero General Public License, version 3
7 # (only), as published by the Free Software Foundation.
7 # (only), as published by the Free Software Foundation.
8 #
8 #
9 # This program is distributed in the hope that it will be useful,
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 # GNU General Public License for more details.
12 # GNU General Public License for more details.
13 #
13 #
14 # You should have received a copy of the GNU Affero General Public License
14 # You should have received a copy of the GNU Affero General Public License
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 #
16 #
17 # This program is dual-licensed. If you wish to learn more about the
17 # This program is dual-licensed. If you wish to learn more about the
18 # RhodeCode Enterprise Edition, including its added features, Support services,
18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20
20
21 import os
21 import os
22 import time
22 import time
23 import logging
23 import logging
24 import tempfile
24 import tempfile
25 import traceback
25 import traceback
26 import threading
26 import threading
27 import socket
27 import socket
28 import random
28 import random
29
29
30 from BaseHTTPServer import BaseHTTPRequestHandler
30 from BaseHTTPServer import BaseHTTPRequestHandler
31 from SocketServer import TCPServer
31 from SocketServer import TCPServer
32
32
33 import rhodecode
33 import rhodecode
34 from rhodecode.lib.exceptions import HTTPLockedRC, HTTPBranchProtected
34 from rhodecode.lib.exceptions import HTTPLockedRC, HTTPBranchProtected
35 from rhodecode.model import meta
35 from rhodecode.model import meta
36 from rhodecode.lib.base import bootstrap_request, bootstrap_config
36 from rhodecode.lib.base import bootstrap_request, bootstrap_config
37 from rhodecode.lib import hooks_base
37 from rhodecode.lib import hooks_base
38 from rhodecode.lib.utils2 import AttributeDict
38 from rhodecode.lib.utils2 import AttributeDict
39 from rhodecode.lib.ext_json import json
39 from rhodecode.lib.ext_json import json
40 from rhodecode.lib import rc_cache
40 from rhodecode.lib import rc_cache
41
41
42 log = logging.getLogger(__name__)
42 log = logging.getLogger(__name__)
43
43
44
44
45 class HooksHttpHandler(BaseHTTPRequestHandler):
45 class HooksHttpHandler(BaseHTTPRequestHandler):
46
46
47 def do_POST(self):
47 def do_POST(self):
48 method, extras = self._read_request()
48 method, extras = self._read_request()
49 txn_id = getattr(self.server, 'txn_id', None)
49 txn_id = getattr(self.server, 'txn_id', None)
50 if txn_id:
50 if txn_id:
51 log.debug('Computing TXN_ID based on `%s`:`%s`',
51 log.debug('Computing TXN_ID based on `%s`:`%s`',
52 extras['repository'], extras['txn_id'])
52 extras['repository'], extras['txn_id'])
53 computed_txn_id = rc_cache.utils.compute_key_from_params(
53 computed_txn_id = rc_cache.utils.compute_key_from_params(
54 extras['repository'], extras['txn_id'])
54 extras['repository'], extras['txn_id'])
55 if txn_id != computed_txn_id:
55 if txn_id != computed_txn_id:
56 raise Exception(
56 raise Exception(
57 'TXN ID fail: expected {} got {} instead'.format(
57 'TXN ID fail: expected {} got {} instead'.format(
58 txn_id, computed_txn_id))
58 txn_id, computed_txn_id))
59
59
60 try:
60 try:
61 result = self._call_hook(method, extras)
61 result = self._call_hook(method, extras)
62 except Exception as e:
62 except Exception as e:
63 exc_tb = traceback.format_exc()
63 exc_tb = traceback.format_exc()
64 result = {
64 result = {
65 'exception': e.__class__.__name__,
65 'exception': e.__class__.__name__,
66 'exception_traceback': exc_tb,
66 'exception_traceback': exc_tb,
67 'exception_args': e.args
67 'exception_args': e.args
68 }
68 }
69 self._write_response(result)
69 self._write_response(result)
70
70
71 def _read_request(self):
71 def _read_request(self):
72 length = int(self.headers['Content-Length'])
72 length = int(self.headers['Content-Length'])
73 body = self.rfile.read(length).decode('utf-8')
73 body = self.rfile.read(length).decode('utf-8')
74 data = json.loads(body)
74 data = json.loads(body)
75 return data['method'], data['extras']
75 return data['method'], data['extras']
76
76
77 def _write_response(self, result):
77 def _write_response(self, result):
78 self.send_response(200)
78 self.send_response(200)
79 self.send_header("Content-type", "text/json")
79 self.send_header("Content-type", "text/json")
80 self.end_headers()
80 self.end_headers()
81 self.wfile.write(json.dumps(result))
81 self.wfile.write(json.dumps(result))
82
82
83 def _call_hook(self, method, extras):
83 def _call_hook(self, method, extras):
84 hooks = Hooks()
84 hooks = Hooks()
85 try:
85 try:
86 result = getattr(hooks, method)(extras)
86 result = getattr(hooks, method)(extras)
87 finally:
87 finally:
88 meta.Session.remove()
88 meta.Session.remove()
89 return result
89 return result
90
90
91 def log_message(self, format, *args):
91 def log_message(self, format, *args):
92 """
92 """
93 This is an overridden method of BaseHTTPRequestHandler which logs using
93 This is an overridden method of BaseHTTPRequestHandler which logs using
94 logging library instead of writing directly to stderr.
94 logging library instead of writing directly to stderr.
95 """
95 """
96
96
97 message = format % args
97 message = format % args
98
98
99 log.debug(
99 log.debug(
100 "%s - - [%s] %s", self.client_address[0],
100 "%s - - [%s] %s", self.client_address[0],
101 self.log_date_time_string(), message)
101 self.log_date_time_string(), message)
102
102
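For orientation, the handler above speaks a small JSON-over-HTTP protocol: the client POSTs a JSON body with a method name and an extras dict, and receives the hook result back as JSON. A minimal client-side sketch, hypothetical and written for Python 2 to match the BaseHTTPServer-era code above (the extras shown are an illustrative subset, not the full set the hooks expect):

import json
import urllib2  # Python 2 stdlib, same era as the BaseHTTPServer import above


def call_hooks_daemon(hooks_uri, method, extras):
    # hooks_uri is the 'host:port' string built by HttpHooksCallbackDaemon._prepare
    payload = json.dumps({'method': method, 'extras': extras})
    request = urllib2.Request('http://%s/' % hooks_uri, data=payload)
    return json.loads(urllib2.urlopen(request).read())

# hypothetical call; real extras carry many more keys (see prepare_callback_daemon below)
result = call_hooks_daemon('127.0.0.1:20123', 'repo_size',
                           {'repository': 'some-repo', 'txn_id': None})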
103
103
104 class DummyHooksCallbackDaemon(object):
104 class DummyHooksCallbackDaemon(object):
105 hooks_uri = ''
105 hooks_uri = ''
106
106
107 def __init__(self):
107 def __init__(self):
108 self.hooks_module = Hooks.__module__
108 self.hooks_module = Hooks.__module__
109
109
110 def __enter__(self):
110 def __enter__(self):
111 log.debug('Running `%s` callback daemon', self.__class__.__name__)
111 log.debug('Running `%s` callback daemon', self.__class__.__name__)
112 return self
112 return self
113
113
114 def __exit__(self, exc_type, exc_val, exc_tb):
114 def __exit__(self, exc_type, exc_val, exc_tb):
115 log.debug('Exiting `%s` callback daemon', self.__class__.__name__)
115 log.debug('Exiting `%s` callback daemon', self.__class__.__name__)
116
116
117
117
118 class ThreadedHookCallbackDaemon(object):
118 class ThreadedHookCallbackDaemon(object):
119
119
120 _callback_thread = None
120 _callback_thread = None
121 _daemon = None
121 _daemon = None
122 _done = False
122 _done = False
123
123
124 def __init__(self, txn_id=None, host=None, port=None):
124 def __init__(self, txn_id=None, host=None, port=None):
125 self._prepare(txn_id=txn_id, host=host, port=port)
125 self._prepare(txn_id=txn_id, host=host, port=port)
126
126
127 def __enter__(self):
127 def __enter__(self):
128 log.debug('Running `%s` callback daemon', self.__class__.__name__)
128 log.debug('Running `%s` callback daemon', self.__class__.__name__)
129 self._run()
129 self._run()
130 return self
130 return self
131
131
132 def __exit__(self, exc_type, exc_val, exc_tb):
132 def __exit__(self, exc_type, exc_val, exc_tb):
133 log.debug('Exiting `%s` callback daemon', self.__class__.__name__)
133 log.debug('Exiting `%s` callback daemon', self.__class__.__name__)
134 self._stop()
134 self._stop()
135
135
136 def _prepare(self, txn_id=None, host=None, port=None):
136 def _prepare(self, txn_id=None, host=None, port=None):
137 raise NotImplementedError()
137 raise NotImplementedError()
138
138
139 def _run(self):
139 def _run(self):
140 raise NotImplementedError()
140 raise NotImplementedError()
141
141
142 def _stop(self):
142 def _stop(self):
143 raise NotImplementedError()
143 raise NotImplementedError()
144
144
145
145
146 class HttpHooksCallbackDaemon(ThreadedHookCallbackDaemon):
146 class HttpHooksCallbackDaemon(ThreadedHookCallbackDaemon):
147 """
147 """
148 Context manager which will run a callback daemon in a background thread.
148 Context manager which will run a callback daemon in a background thread.
149 """
149 """
150
150
151 hooks_uri = None
151 hooks_uri = None
152
152
153 # From Python docs: Polling reduces our responsiveness to a shutdown
153 # From Python docs: Polling reduces our responsiveness to a shutdown
154 # request and wastes cpu at all other times.
154 # request and wastes cpu at all other times.
155 POLL_INTERVAL = 0.01
155 POLL_INTERVAL = 0.01
156
156
157 def get_hostname(self):
157 def get_hostname(self):
158 return socket.gethostname() or '127.0.0.1'
158 return socket.gethostname() or '127.0.0.1'
159
159
160  160     def get_available_port(self, min_port=20000, max_port=65535):
161       -      sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
162       -      hostname = self.get_hostname()
163       -
164       -      for _ in range(min_port, max_port):
165       -          pick_port = random.randint(min_port, max_port)
166       -          try:
167       -              sock.bind((hostname, pick_port))
168       -              sock.close()
169       -              del sock
170       -              return pick_port
171       -          except OSError:
172       -              pass
     161  +      from rhodecode.lib.utils2 import get_available_port as _get_port
     162  +      return _get_port(min_port, max_port)
173  163
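The new body delegates to a shared helper in rhodecode.lib.utils2 (which is why that module gains socket and random imports in the second hunk below). The helper itself is outside this excerpt; a minimal sketch of what it plausibly looks like, assuming it keeps the bind-and-probe approach of the removed inline code and the positional (min_port, max_port) signature used at the call site:

import random
import socket


def get_available_port(min_port=20000, max_port=65535):
    # Probe random ports in [min_port, max_port] until one binds; the attempt
    # count mirrors the removed loop (one try per port in the range).
    hostname = socket.gethostname() or '127.0.0.1'
    for _ in range(min_port, max_port):
        pick_port = random.randint(min_port, max_port)
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            sock.bind((hostname, pick_port))
            return pick_port
        except (OSError, socket.error):  # socket.error needed on Python 2
            pass
        finally:
            sock.close()

Opening a fresh socket per attempt (rather than reusing one, as the removed inline code did) avoids depending on a socket staying usable after a failed bind; that detail is an assumption about the shared helper, not something visible in this diff.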
174 def _prepare(self, txn_id=None, host=None, port=None):
164 def _prepare(self, txn_id=None, host=None, port=None):
175 if not host or host == "*":
165 if not host or host == "*":
176 host = self.get_hostname()
166 host = self.get_hostname()
177 if not port:
167 if not port:
178 port = self.get_available_port()
168 port = self.get_available_port()
179
169
180 server_address = (host, port)
170 server_address = (host, port)
181 self.hooks_uri = '{}:{}'.format(host, port)
171 self.hooks_uri = '{}:{}'.format(host, port)
182 self.txn_id = txn_id
172 self.txn_id = txn_id
183 self._done = False
173 self._done = False
184
174
185 log.debug(
175 log.debug(
186 "Preparing HTTP callback daemon at `%s` and registering hook object: %s",
176 "Preparing HTTP callback daemon at `%s` and registering hook object: %s",
187 self.hooks_uri, HooksHttpHandler)
177 self.hooks_uri, HooksHttpHandler)
188
178
189 self._daemon = TCPServer(server_address, HooksHttpHandler)
179 self._daemon = TCPServer(server_address, HooksHttpHandler)
190 # inject transaction_id for later verification
180 # inject transaction_id for later verification
191 self._daemon.txn_id = self.txn_id
181 self._daemon.txn_id = self.txn_id
192
182
193 def _run(self):
183 def _run(self):
194 log.debug("Running event loop of callback daemon in background thread")
184 log.debug("Running event loop of callback daemon in background thread")
195 callback_thread = threading.Thread(
185 callback_thread = threading.Thread(
196 target=self._daemon.serve_forever,
186 target=self._daemon.serve_forever,
197 kwargs={'poll_interval': self.POLL_INTERVAL})
187 kwargs={'poll_interval': self.POLL_INTERVAL})
198 callback_thread.daemon = True
188 callback_thread.daemon = True
199 callback_thread.start()
189 callback_thread.start()
200 self._callback_thread = callback_thread
190 self._callback_thread = callback_thread
201
191
202 def _stop(self):
192 def _stop(self):
203 log.debug("Waiting for background thread to finish.")
193 log.debug("Waiting for background thread to finish.")
204 self._daemon.shutdown()
194 self._daemon.shutdown()
205 self._callback_thread.join()
195 self._callback_thread.join()
206 self._daemon = None
196 self._daemon = None
207 self._callback_thread = None
197 self._callback_thread = None
208 if self.txn_id:
198 if self.txn_id:
209 txn_id_file = get_txn_id_data_path(self.txn_id)
199 txn_id_file = get_txn_id_data_path(self.txn_id)
210 log.debug('Cleaning up TXN ID %s', txn_id_file)
200 log.debug('Cleaning up TXN ID %s', txn_id_file)
211 if os.path.isfile(txn_id_file):
201 if os.path.isfile(txn_id_file):
212 os.remove(txn_id_file)
202 os.remove(txn_id_file)
213
203
214 log.debug("Background thread done.")
204 log.debug("Background thread done.")
215
205
216
206
217 def get_txn_id_data_path(txn_id):
207 def get_txn_id_data_path(txn_id):
218 import rhodecode
208 import rhodecode
219
209
220 root = rhodecode.CONFIG.get('cache_dir') or tempfile.gettempdir()
210 root = rhodecode.CONFIG.get('cache_dir') or tempfile.gettempdir()
221 final_dir = os.path.join(root, 'svn_txn_id')
211 final_dir = os.path.join(root, 'svn_txn_id')
222
212
223 if not os.path.isdir(final_dir):
213 if not os.path.isdir(final_dir):
224 os.makedirs(final_dir)
214 os.makedirs(final_dir)
225 return os.path.join(final_dir, 'rc_txn_id_{}'.format(txn_id))
215 return os.path.join(final_dir, 'rc_txn_id_{}'.format(txn_id))
226
216
227
217
228 def store_txn_id_data(txn_id, data_dict):
218 def store_txn_id_data(txn_id, data_dict):
229 if not txn_id:
219 if not txn_id:
230 log.warning('Cannot store txn_id because it is empty')
220 log.warning('Cannot store txn_id because it is empty')
231 return
221 return
232
222
233 path = get_txn_id_data_path(txn_id)
223 path = get_txn_id_data_path(txn_id)
234 try:
224 try:
235 with open(path, 'wb') as f:
225 with open(path, 'wb') as f:
236 f.write(json.dumps(data_dict))
226 f.write(json.dumps(data_dict))
237 except Exception:
227 except Exception:
238 log.exception('Failed to write txn_id metadata')
228 log.exception('Failed to write txn_id metadata')
239
229
240
230
241 def get_txn_id_from_store(txn_id):
231 def get_txn_id_from_store(txn_id):
242 """
232 """
243 Reads txn_id from store and if present returns the data for callback manager
233 Reads txn_id from store and if present returns the data for callback manager
244 """
234 """
245 path = get_txn_id_data_path(txn_id)
235 path = get_txn_id_data_path(txn_id)
246 try:
236 try:
247 with open(path, 'rb') as f:
237 with open(path, 'rb') as f:
248 return json.loads(f.read())
238 return json.loads(f.read())
249 except Exception:
239 except Exception:
250 return {}
240 return {}
251
241
252
242
253 def prepare_callback_daemon(extras, protocol, host, use_direct_calls, txn_id=None):
243 def prepare_callback_daemon(extras, protocol, host, use_direct_calls, txn_id=None):
254 txn_details = get_txn_id_from_store(txn_id)
244 txn_details = get_txn_id_from_store(txn_id)
255 port = txn_details.get('port', 0)
245 port = txn_details.get('port', 0)
256 if use_direct_calls:
246 if use_direct_calls:
257 callback_daemon = DummyHooksCallbackDaemon()
247 callback_daemon = DummyHooksCallbackDaemon()
258 extras['hooks_module'] = callback_daemon.hooks_module
248 extras['hooks_module'] = callback_daemon.hooks_module
259 else:
249 else:
260 if protocol == 'http':
250 if protocol == 'http':
261 callback_daemon = HttpHooksCallbackDaemon(
251 callback_daemon = HttpHooksCallbackDaemon(
262 txn_id=txn_id, host=host, port=port)
252 txn_id=txn_id, host=host, port=port)
263 else:
253 else:
264 log.error('Unsupported callback daemon protocol "%s"', protocol)
254 log.error('Unsupported callback daemon protocol "%s"', protocol)
265 raise Exception('Unsupported callback daemon protocol.')
255 raise Exception('Unsupported callback daemon protocol.')
266
256
267 extras['hooks_uri'] = callback_daemon.hooks_uri
257 extras['hooks_uri'] = callback_daemon.hooks_uri
268 extras['hooks_protocol'] = protocol
258 extras['hooks_protocol'] = protocol
269 extras['time'] = time.time()
259 extras['time'] = time.time()
270
260
271 # register txn_id
261 # register txn_id
272 extras['txn_id'] = txn_id
262 extras['txn_id'] = txn_id
273 log.debug('Prepared a callback daemon: %s at url `%s`',
263 log.debug('Prepared a callback daemon: %s at url `%s`',
274 callback_daemon.__class__.__name__, callback_daemon.hooks_uri)
264 callback_daemon.__class__.__name__, callback_daemon.hooks_uri)
275 return callback_daemon, extras
265 return callback_daemon, extras
276
266
277
267
278 class Hooks(object):
268 class Hooks(object):
279 """
269 """
280 Exposes the hooks for remote call backs
270 Exposes the hooks for remote call backs
281 """
271 """
282
272
283 def repo_size(self, extras):
273 def repo_size(self, extras):
284 log.debug("Called repo_size of %s object", self)
274 log.debug("Called repo_size of %s object", self)
285 return self._call_hook(hooks_base.repo_size, extras)
275 return self._call_hook(hooks_base.repo_size, extras)
286
276
287 def pre_pull(self, extras):
277 def pre_pull(self, extras):
288 log.debug("Called pre_pull of %s object", self)
278 log.debug("Called pre_pull of %s object", self)
289 return self._call_hook(hooks_base.pre_pull, extras)
279 return self._call_hook(hooks_base.pre_pull, extras)
290
280
291 def post_pull(self, extras):
281 def post_pull(self, extras):
292 log.debug("Called post_pull of %s object", self)
282 log.debug("Called post_pull of %s object", self)
293 return self._call_hook(hooks_base.post_pull, extras)
283 return self._call_hook(hooks_base.post_pull, extras)
294
284
295 def pre_push(self, extras):
285 def pre_push(self, extras):
296 log.debug("Called pre_push of %s object", self)
286 log.debug("Called pre_push of %s object", self)
297 return self._call_hook(hooks_base.pre_push, extras)
287 return self._call_hook(hooks_base.pre_push, extras)
298
288
299 def post_push(self, extras):
289 def post_push(self, extras):
300 log.debug("Called post_push of %s object", self)
290 log.debug("Called post_push of %s object", self)
301 return self._call_hook(hooks_base.post_push, extras)
291 return self._call_hook(hooks_base.post_push, extras)
302
292
303 def _call_hook(self, hook, extras):
293 def _call_hook(self, hook, extras):
304 extras = AttributeDict(extras)
294 extras = AttributeDict(extras)
305 server_url = extras['server_url']
295 server_url = extras['server_url']
306 request = bootstrap_request(application_url=server_url)
296 request = bootstrap_request(application_url=server_url)
307
297
308 bootstrap_config(request) # inject routes and other interfaces
298 bootstrap_config(request) # inject routes and other interfaces
309
299
310 # inject the user for usage in hooks
300 # inject the user for usage in hooks
311 request.user = AttributeDict({'username': extras.username,
301 request.user = AttributeDict({'username': extras.username,
312 'ip_addr': extras.ip,
302 'ip_addr': extras.ip,
313 'user_id': extras.user_id})
303 'user_id': extras.user_id})
314
304
315 extras.request = request
305 extras.request = request
316
306
317 try:
307 try:
318 result = hook(extras)
308 result = hook(extras)
319 if result is None:
309 if result is None:
320 raise Exception(
310 raise Exception(
321 'Failed to obtain hook result from func: {}'.format(hook))
311 'Failed to obtain hook result from func: {}'.format(hook))
322 except HTTPBranchProtected as handled_error:
312 except HTTPBranchProtected as handled_error:
323 # Those special cases doesn't need error reporting. It's a case of
313 # Those special cases doesn't need error reporting. It's a case of
324 # locked repo or protected branch
314 # locked repo or protected branch
325 result = AttributeDict({
315 result = AttributeDict({
326 'status': handled_error.code,
316 'status': handled_error.code,
327 'output': handled_error.explanation
317 'output': handled_error.explanation
328 })
318 })
329 except (HTTPLockedRC, Exception) as error:
319 except (HTTPLockedRC, Exception) as error:
330 # locked needs different handling since we need to also
320 # locked needs different handling since we need to also
331 # handle PULL operations
321 # handle PULL operations
332 exc_tb = ''
322 exc_tb = ''
333 if not isinstance(error, HTTPLockedRC):
323 if not isinstance(error, HTTPLockedRC):
334 exc_tb = traceback.format_exc()
324 exc_tb = traceback.format_exc()
335 log.exception('Exception when handling hook %s', hook)
325 log.exception('Exception when handling hook %s', hook)
336 error_args = error.args
326 error_args = error.args
337 return {
327 return {
338 'status': 128,
328 'status': 128,
339 'output': '',
329 'output': '',
340 'exception': type(error).__name__,
330 'exception': type(error).__name__,
341 'exception_traceback': exc_tb,
331 'exception_traceback': exc_tb,
342 'exception_args': error_args,
332 'exception_args': error_args,
343 }
333 }
344 finally:
334 finally:
345 meta.Session.remove()
335 meta.Session.remove()
346
336
347 log.debug('Got hook call response %s', result)
337 log.debug('Got hook call response %s', result)
348 return {
338 return {
349 'status': result.status,
339 'status': result.status,
350 'output': result.output,
340 'output': result.output,
351 }
341 }
352
342
353 def __enter__(self):
343 def __enter__(self):
354 return self
344 return self
355
345
356 def __exit__(self, exc_type, exc_val, exc_tb):
346 def __exit__(self, exc_type, exc_val, exc_tb):
357 pass
347 pass
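Taken together, these pieces are wired up through prepare_callback_daemon(); a hypothetical usage sketch follows (run_vcs_operation is a stand-in for the actual push/pull dispatch, which lives outside this file, and the extras dict is an illustrative subset):

extras = {'repository': 'some-repo', 'server_url': 'http://rhodecode.local',
          'username': 'admin', 'ip': '127.0.0.1', 'user_id': 2}

callback_daemon, extras = prepare_callback_daemon(
    extras, protocol='http', host='127.0.0.1',
    use_direct_calls=False, txn_id=None)

# extras now also carries hooks_uri, hooks_protocol, time and txn_id for the VCS layer
with callback_daemon:
    # inside the block a TCPServer thread serves HooksHttpHandler, so hooks
    # triggered by the VCS operation can call back over HTTP
    run_vcs_operation(extras)  # hypothetical stand-in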
@@ -1,1173 +1,1193 @@
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2
2
3 # Copyright (C) 2011-2020 RhodeCode GmbH
3 # Copyright (C) 2011-2020 RhodeCode GmbH
4 #
4 #
5 # This program is free software: you can redistribute it and/or modify
5 # This program is free software: you can redistribute it and/or modify
6 # it under the terms of the GNU Affero General Public License, version 3
6 # it under the terms of the GNU Affero General Public License, version 3
7 # (only), as published by the Free Software Foundation.
7 # (only), as published by the Free Software Foundation.
8 #
8 #
9 # This program is distributed in the hope that it will be useful,
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 # GNU General Public License for more details.
12 # GNU General Public License for more details.
13 #
13 #
14 # You should have received a copy of the GNU Affero General Public License
14 # You should have received a copy of the GNU Affero General Public License
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 #
16 #
17 # This program is dual-licensed. If you wish to learn more about the
17 # This program is dual-licensed. If you wish to learn more about the
18 # RhodeCode Enterprise Edition, including its added features, Support services,
18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20
20
21
21
22 """
22 """
23 Some simple helper functions
23 Some simple helper functions
24 """
24 """
25
25
26 import collections
26 import collections
27 import datetime
27 import datetime
28 import dateutil.relativedelta
28 import dateutil.relativedelta
29 import hashlib
29 import hashlib
30 import logging
30 import logging
31 import re
31 import re
32 import sys
32 import sys
33 import time
33 import time
34 import urllib
34 import urllib
35 import urlobject
35 import urlobject
36 import uuid
36 import uuid
37   37  import getpass
     38  +import socket
     39  +import random
38   40  from functools import update_wrapper, partial, wraps
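These two imports back the port-picking helper relocated into this module (see the hooks-daemon hunk above); assuming it keeps the call-site signature, a caller would use it as:

from rhodecode.lib.utils2 import get_available_port

port = get_available_port(20000, 65535)  # free TCP port for the hooks callback daemon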
39
41
40 import pygments.lexers
42 import pygments.lexers
41 import sqlalchemy
43 import sqlalchemy
42 import sqlalchemy.engine.url
44 import sqlalchemy.engine.url
43 import sqlalchemy.exc
45 import sqlalchemy.exc
44 import sqlalchemy.sql
46 import sqlalchemy.sql
45 import webob
47 import webob
46 import pyramid.threadlocal
48 import pyramid.threadlocal
47 from pyramid import compat
49 from pyramid import compat
48 from pyramid.settings import asbool
50 from pyramid.settings import asbool
49
51
50 import rhodecode
52 import rhodecode
51 from rhodecode.translation import _, _pluralize
53 from rhodecode.translation import _, _pluralize
52
54
53
55
54 def md5(s):
56 def md5(s):
55 return hashlib.md5(s).hexdigest()
57 return hashlib.md5(s).hexdigest()
56
58
57
59
58 def md5_safe(s):
60 def md5_safe(s):
59 return md5(safe_str(s))
61 return md5(safe_str(s))
60
62
61
63
62 def sha1(s):
64 def sha1(s):
63 return hashlib.sha1(s).hexdigest()
65 return hashlib.sha1(s).hexdigest()
64
66
65
67
66 def sha1_safe(s):
68 def sha1_safe(s):
67 return sha1(safe_str(s))
69 return sha1(safe_str(s))
68
70
69
71
70 def __get_lem(extra_mapping=None):
72 def __get_lem(extra_mapping=None):
71 """
73 """
72 Get language extension map based on what's inside pygments lexers
74 Get language extension map based on what's inside pygments lexers
73 """
75 """
74 d = collections.defaultdict(lambda: [])
76 d = collections.defaultdict(lambda: [])
75
77
76 def __clean(s):
78 def __clean(s):
77 s = s.lstrip('*')
79 s = s.lstrip('*')
78 s = s.lstrip('.')
80 s = s.lstrip('.')
79
81
80 if s.find('[') != -1:
82 if s.find('[') != -1:
81 exts = []
83 exts = []
82 start, stop = s.find('['), s.find(']')
84 start, stop = s.find('['), s.find(']')
83
85
84 for suffix in s[start + 1:stop]:
86 for suffix in s[start + 1:stop]:
85 exts.append(s[:s.find('[')] + suffix)
87 exts.append(s[:s.find('[')] + suffix)
86 return [e.lower() for e in exts]
88 return [e.lower() for e in exts]
87 else:
89 else:
88 return [s.lower()]
90 return [s.lower()]
89
91
90 for lx, t in sorted(pygments.lexers.LEXERS.items()):
92 for lx, t in sorted(pygments.lexers.LEXERS.items()):
91 m = map(__clean, t[-2])
93 m = map(__clean, t[-2])
92 if m:
94 if m:
93 m = reduce(lambda x, y: x + y, m)
95 m = reduce(lambda x, y: x + y, m)
94 for ext in m:
96 for ext in m:
95 desc = lx.replace('Lexer', '')
97 desc = lx.replace('Lexer', '')
96 d[ext].append(desc)
98 d[ext].append(desc)
97
99
98 data = dict(d)
100 data = dict(d)
99
101
100 extra_mapping = extra_mapping or {}
102 extra_mapping = extra_mapping or {}
101 if extra_mapping:
103 if extra_mapping:
102 for k, v in extra_mapping.items():
104 for k, v in extra_mapping.items():
103 if k not in data:
105 if k not in data:
104 # register new mapping2lexer
106 # register new mapping2lexer
105 data[k] = [v]
107 data[k] = [v]
106
108
107 return data
109 return data
108
110
109
111
110 def str2bool(_str):
112 def str2bool(_str):
111 """
113 """
112 returns True/False value from given string, it tries to translate the
114 returns True/False value from given string, it tries to translate the
113 string into boolean
115 string into boolean
114
116
115 :param _str: string value to translate into boolean
117 :param _str: string value to translate into boolean
116 :rtype: boolean
118 :rtype: boolean
117 :returns: boolean from given string
119 :returns: boolean from given string
118 """
120 """
119 if _str is None:
121 if _str is None:
120 return False
122 return False
121 if _str in (True, False):
123 if _str in (True, False):
122 return _str
124 return _str
123 _str = str(_str).strip().lower()
125 _str = str(_str).strip().lower()
124 return _str in ('t', 'true', 'y', 'yes', 'on', '1')
126 return _str in ('t', 'true', 'y', 'yes', 'on', '1')
125
127
126
128
127 def aslist(obj, sep=None, strip=True):
129 def aslist(obj, sep=None, strip=True):
128 """
130 """
129 Returns given string separated by sep as list
131 Returns given string separated by sep as list
130
132
131 :param obj:
133 :param obj:
132 :param sep:
134 :param sep:
133 :param strip:
135 :param strip:
134 """
136 """
135 if isinstance(obj, (basestring,)):
137 if isinstance(obj, (basestring,)):
136 lst = obj.split(sep)
138 lst = obj.split(sep)
137 if strip:
139 if strip:
138 lst = [v.strip() for v in lst]
140 lst = [v.strip() for v in lst]
139 return lst
141 return lst
140 elif isinstance(obj, (list, tuple)):
142 elif isinstance(obj, (list, tuple)):
141 return obj
143 return obj
142 elif obj is None:
144 elif obj is None:
143 return []
145 return []
144 else:
146 else:
145 return [obj]
147 return [obj]
146
148
147
149
148 def convert_line_endings(line, mode):
150 def convert_line_endings(line, mode):
149 """
151 """
150 Converts a given line "line end" accordingly to given mode
152 Converts a given line "line end" accordingly to given mode
151
153
152 Available modes are::
154 Available modes are::
153 0 - Unix
155 0 - Unix
154 1 - Mac
156 1 - Mac
155 2 - DOS
157 2 - DOS
156
158
157 :param line: given line to convert
159 :param line: given line to convert
158 :param mode: mode to convert to
160 :param mode: mode to convert to
159 :rtype: str
161 :rtype: str
160 :return: converted line according to mode
162 :return: converted line according to mode
161 """
163 """
162 if mode == 0:
164 if mode == 0:
163 line = line.replace('\r\n', '\n')
165 line = line.replace('\r\n', '\n')
164 line = line.replace('\r', '\n')
166 line = line.replace('\r', '\n')
165 elif mode == 1:
167 elif mode == 1:
166 line = line.replace('\r\n', '\r')
168 line = line.replace('\r\n', '\r')
167 line = line.replace('\n', '\r')
169 line = line.replace('\n', '\r')
168 elif mode == 2:
170 elif mode == 2:
169 line = re.sub('\r(?!\n)|(?<!\r)\n', '\r\n', line)
171 line = re.sub('\r(?!\n)|(?<!\r)\n', '\r\n', line)
170 return line
172 return line
171
173
172
174
173 def detect_mode(line, default):
175 def detect_mode(line, default):
174 """
176 """
175 Detects line break for given line, if line break couldn't be found
177 Detects line break for given line, if line break couldn't be found
176 given default value is returned
178 given default value is returned
177
179
178 :param line: str line
180 :param line: str line
179 :param default: default
181 :param default: default
180 :rtype: int
182 :rtype: int
181 :return: value of line end on of 0 - Unix, 1 - Mac, 2 - DOS
183 :return: value of line end on of 0 - Unix, 1 - Mac, 2 - DOS
182 """
184 """
183 if line.endswith('\r\n'):
185 if line.endswith('\r\n'):
184 return 2
186 return 2
185 elif line.endswith('\n'):
187 elif line.endswith('\n'):
186 return 0
188 return 0
187 elif line.endswith('\r'):
189 elif line.endswith('\r'):
188 return 1
190 return 1
189 else:
191 else:
190 return default
192 return default
191
193
192
194
193 def safe_int(val, default=None):
195 def safe_int(val, default=None):
194 """
196 """
195 Returns int() of val if val is not convertable to int use default
197 Returns int() of val if val is not convertable to int use default
196 instead
198 instead
197
199
198 :param val:
200 :param val:
199 :param default:
201 :param default:
200 """
202 """
201
203
202 try:
204 try:
203 val = int(val)
205 val = int(val)
204 except (ValueError, TypeError):
206 except (ValueError, TypeError):
205 val = default
207 val = default
206
208
207 return val
209 return val
208
210
209
211
210 def safe_unicode(str_, from_encoding=None, use_chardet=False):
212 def safe_unicode(str_, from_encoding=None, use_chardet=False):
211 """
213 """
212 safe unicode function. Does few trick to turn str_ into unicode
214 safe unicode function. Does few trick to turn str_ into unicode
213
215
214 In case of UnicodeDecode error, we try to return it with encoding detected
216 In case of UnicodeDecode error, we try to return it with encoding detected
215 by chardet library if it fails fallback to unicode with errors replaced
217 by chardet library if it fails fallback to unicode with errors replaced
216
218
217 :param str_: string to decode
219 :param str_: string to decode
218 :rtype: unicode
220 :rtype: unicode
219 :returns: unicode object
221 :returns: unicode object
220 """
222 """
221 if isinstance(str_, unicode):
223 if isinstance(str_, unicode):
222 return str_
224 return str_
223
225
224 if not from_encoding:
226 if not from_encoding:
225 DEFAULT_ENCODINGS = aslist(rhodecode.CONFIG.get('default_encoding',
227 DEFAULT_ENCODINGS = aslist(rhodecode.CONFIG.get('default_encoding',
226 'utf8'), sep=',')
228 'utf8'), sep=',')
227 from_encoding = DEFAULT_ENCODINGS
229 from_encoding = DEFAULT_ENCODINGS
228
230
229 if not isinstance(from_encoding, (list, tuple)):
231 if not isinstance(from_encoding, (list, tuple)):
230 from_encoding = [from_encoding]
232 from_encoding = [from_encoding]
231
233
232 try:
234 try:
233 return unicode(str_)
235 return unicode(str_)
234 except UnicodeDecodeError:
236 except UnicodeDecodeError:
235 pass
237 pass
236
238
237 for enc in from_encoding:
239 for enc in from_encoding:
238 try:
240 try:
239 return unicode(str_, enc)
241 return unicode(str_, enc)
240 except UnicodeDecodeError:
242 except UnicodeDecodeError:
241 pass
243 pass
242
244
243 if use_chardet:
245 if use_chardet:
244 try:
246 try:
245 import chardet
247 import chardet
246 encoding = chardet.detect(str_)['encoding']
248 encoding = chardet.detect(str_)['encoding']
247 if encoding is None:
249 if encoding is None:
248 raise Exception()
250 raise Exception()
249 return str_.decode(encoding)
251 return str_.decode(encoding)
250 except (ImportError, UnicodeDecodeError, Exception):
252 except (ImportError, UnicodeDecodeError, Exception):
251 return unicode(str_, from_encoding[0], 'replace')
253 return unicode(str_, from_encoding[0], 'replace')
252 else:
254 else:
253 return unicode(str_, from_encoding[0], 'replace')
255 return unicode(str_, from_encoding[0], 'replace')
254
256
255 def safe_str(unicode_, to_encoding=None, use_chardet=False):
257 def safe_str(unicode_, to_encoding=None, use_chardet=False):
256 """
258 """
257 safe str function. Does few trick to turn unicode_ into string
259 safe str function. Does few trick to turn unicode_ into string
258
260
259 In case of UnicodeEncodeError, we try to return it with encoding detected
261 In case of UnicodeEncodeError, we try to return it with encoding detected
260 by chardet library if it fails fallback to string with errors replaced
262 by chardet library if it fails fallback to string with errors replaced
261
263
262 :param unicode_: unicode to encode
264 :param unicode_: unicode to encode
263 :rtype: str
265 :rtype: str
264 :returns: str object
266 :returns: str object
265 """
267 """
266
268
267 # if it's not basestr cast to str
269 # if it's not basestr cast to str
268 if not isinstance(unicode_, compat.string_types):
270 if not isinstance(unicode_, compat.string_types):
269 return str(unicode_)
271 return str(unicode_)
270
272
271 if isinstance(unicode_, str):
273 if isinstance(unicode_, str):
272 return unicode_
274 return unicode_
273
275
274 if not to_encoding:
276 if not to_encoding:
275 DEFAULT_ENCODINGS = aslist(rhodecode.CONFIG.get('default_encoding',
277 DEFAULT_ENCODINGS = aslist(rhodecode.CONFIG.get('default_encoding',
276 'utf8'), sep=',')
278 'utf8'), sep=',')
277 to_encoding = DEFAULT_ENCODINGS
279 to_encoding = DEFAULT_ENCODINGS
278
280
279 if not isinstance(to_encoding, (list, tuple)):
281 if not isinstance(to_encoding, (list, tuple)):
280 to_encoding = [to_encoding]
282 to_encoding = [to_encoding]
281
283
282 for enc in to_encoding:
284 for enc in to_encoding:
283 try:
285 try:
284 return unicode_.encode(enc)
286 return unicode_.encode(enc)
285 except UnicodeEncodeError:
287 except UnicodeEncodeError:
286 pass
288 pass
287
289
288 if use_chardet:
290 if use_chardet:
289 try:
291 try:
290 import chardet
292 import chardet
291 encoding = chardet.detect(unicode_)['encoding']
293 encoding = chardet.detect(unicode_)['encoding']
292 if encoding is None:
294 if encoding is None:
293 raise UnicodeEncodeError()
295 raise UnicodeEncodeError()
294
296
295 return unicode_.encode(encoding)
297 return unicode_.encode(encoding)
296 except (ImportError, UnicodeEncodeError):
298 except (ImportError, UnicodeEncodeError):
297 return unicode_.encode(to_encoding[0], 'replace')
299 return unicode_.encode(to_encoding[0], 'replace')
298 else:
300 else:
299 return unicode_.encode(to_encoding[0], 'replace')
301 return unicode_.encode(to_encoding[0], 'replace')
300
302
301
303
302 def remove_suffix(s, suffix):
304 def remove_suffix(s, suffix):
303 if s.endswith(suffix):
305 if s.endswith(suffix):
304 s = s[:-1 * len(suffix)]
306 s = s[:-1 * len(suffix)]
305 return s
307 return s
306
308
307
309
308 def remove_prefix(s, prefix):
310 def remove_prefix(s, prefix):
309 if s.startswith(prefix):
311 if s.startswith(prefix):
310 s = s[len(prefix):]
312 s = s[len(prefix):]
311 return s
313 return s
312
314
313
315
314 def find_calling_context(ignore_modules=None):
316 def find_calling_context(ignore_modules=None):
315 """
317 """
316 Look through the calling stack and return the frame which called
318 Look through the calling stack and return the frame which called
317 this function and is part of core module ( ie. rhodecode.* )
319 this function and is part of core module ( ie. rhodecode.* )
318
320
319 :param ignore_modules: list of modules to ignore eg. ['rhodecode.lib']
321 :param ignore_modules: list of modules to ignore eg. ['rhodecode.lib']
320 """
322 """
321
323
322 ignore_modules = ignore_modules or []
324 ignore_modules = ignore_modules or []
323
325
324 f = sys._getframe(2)
326 f = sys._getframe(2)
325 while f.f_back is not None:
327 while f.f_back is not None:
326 name = f.f_globals.get('__name__')
328 name = f.f_globals.get('__name__')
327 if name and name.startswith(__name__.split('.')[0]):
329 if name and name.startswith(__name__.split('.')[0]):
328 if name not in ignore_modules:
330 if name not in ignore_modules:
329 return f
331 return f
330 f = f.f_back
332 f = f.f_back
331 return None
333 return None
332
334
333
335
334 def ping_connection(connection, branch):
336 def ping_connection(connection, branch):
335 if branch:
337 if branch:
336 # "branch" refers to a sub-connection of a connection,
338 # "branch" refers to a sub-connection of a connection,
337 # we don't want to bother pinging on these.
339 # we don't want to bother pinging on these.
338 return
340 return
339
341
340 # turn off "close with result". This flag is only used with
342 # turn off "close with result". This flag is only used with
341 # "connectionless" execution, otherwise will be False in any case
343 # "connectionless" execution, otherwise will be False in any case
342 save_should_close_with_result = connection.should_close_with_result
344 save_should_close_with_result = connection.should_close_with_result
343 connection.should_close_with_result = False
345 connection.should_close_with_result = False
344
346
345 try:
347 try:
346 # run a SELECT 1. use a core select() so that
348 # run a SELECT 1. use a core select() so that
347 # the SELECT of a scalar value without a table is
349 # the SELECT of a scalar value without a table is
348 # appropriately formatted for the backend
350 # appropriately formatted for the backend
349 connection.scalar(sqlalchemy.sql.select([1]))
351 connection.scalar(sqlalchemy.sql.select([1]))
350 except sqlalchemy.exc.DBAPIError as err:
352 except sqlalchemy.exc.DBAPIError as err:
351 # catch SQLAlchemy's DBAPIError, which is a wrapper
353 # catch SQLAlchemy's DBAPIError, which is a wrapper
352 # for the DBAPI's exception. It includes a .connection_invalidated
354 # for the DBAPI's exception. It includes a .connection_invalidated
353 # attribute which specifies if this connection is a "disconnect"
355 # attribute which specifies if this connection is a "disconnect"
354 # condition, which is based on inspection of the original exception
356 # condition, which is based on inspection of the original exception
355 # by the dialect in use.
357 # by the dialect in use.
356 if err.connection_invalidated:
358 if err.connection_invalidated:
357 # run the same SELECT again - the connection will re-validate
359 # run the same SELECT again - the connection will re-validate
358 # itself and establish a new connection. The disconnect detection
360 # itself and establish a new connection. The disconnect detection
359 # here also causes the whole connection pool to be invalidated
361 # here also causes the whole connection pool to be invalidated
360 # so that all stale connections are discarded.
362 # so that all stale connections are discarded.
361 connection.scalar(sqlalchemy.sql.select([1]))
363 connection.scalar(sqlalchemy.sql.select([1]))
362 else:
364 else:
363 raise
365 raise
364 finally:
366 finally:
365 # restore "close with result"
367 # restore "close with result"
366 connection.should_close_with_result = save_should_close_with_result
368 connection.should_close_with_result = save_should_close_with_result
367
369
368
370
369 def engine_from_config(configuration, prefix='sqlalchemy.', **kwargs):
371 def engine_from_config(configuration, prefix='sqlalchemy.', **kwargs):
370 """Custom engine_from_config functions."""
372 """Custom engine_from_config functions."""
371 log = logging.getLogger('sqlalchemy.engine')
373 log = logging.getLogger('sqlalchemy.engine')
372 use_ping_connection = asbool(configuration.pop('sqlalchemy.db1.ping_connection', None))
374 use_ping_connection = asbool(configuration.pop('sqlalchemy.db1.ping_connection', None))
373 debug = asbool(configuration.pop('sqlalchemy.db1.debug_query', None))
375 debug = asbool(configuration.pop('sqlalchemy.db1.debug_query', None))
374
376
375 engine = sqlalchemy.engine_from_config(configuration, prefix, **kwargs)
377 engine = sqlalchemy.engine_from_config(configuration, prefix, **kwargs)
376
378
377 def color_sql(sql):
379 def color_sql(sql):
378 color_seq = '\033[1;33m' # This is yellow: code 33
380 color_seq = '\033[1;33m' # This is yellow: code 33
379 normal = '\x1b[0m'
381 normal = '\x1b[0m'
380 return ''.join([color_seq, sql, normal])
382 return ''.join([color_seq, sql, normal])
381
383
382 if use_ping_connection:
384 if use_ping_connection:
383 log.debug('Adding ping_connection on the engine config.')
385 log.debug('Adding ping_connection on the engine config.')
384 sqlalchemy.event.listen(engine, "engine_connect", ping_connection)
386 sqlalchemy.event.listen(engine, "engine_connect", ping_connection)
385
387
386 if debug:
388 if debug:
387 # attach events only for debug configuration
389 # attach events only for debug configuration
388 def before_cursor_execute(conn, cursor, statement,
390 def before_cursor_execute(conn, cursor, statement,
389 parameters, context, executemany):
391 parameters, context, executemany):
390 setattr(conn, 'query_start_time', time.time())
392 setattr(conn, 'query_start_time', time.time())
391 log.info(color_sql(">>>>> STARTING QUERY >>>>>"))
393 log.info(color_sql(">>>>> STARTING QUERY >>>>>"))
392 calling_context = find_calling_context(ignore_modules=[
394 calling_context = find_calling_context(ignore_modules=[
393 'rhodecode.lib.caching_query',
395 'rhodecode.lib.caching_query',
394 'rhodecode.model.settings',
396 'rhodecode.model.settings',
395 ])
397 ])
396 if calling_context:
398 if calling_context:
397 log.info(color_sql('call context %s:%s' % (
399 log.info(color_sql('call context %s:%s' % (
398 calling_context.f_code.co_filename,
400 calling_context.f_code.co_filename,
399 calling_context.f_lineno,
401 calling_context.f_lineno,
400 )))
402 )))
401
403
402 def after_cursor_execute(conn, cursor, statement,
404 def after_cursor_execute(conn, cursor, statement,
403 parameters, context, executemany):
405 parameters, context, executemany):
404 delattr(conn, 'query_start_time')
406 delattr(conn, 'query_start_time')
405
407
406 sqlalchemy.event.listen(engine, "before_cursor_execute", before_cursor_execute)
408 sqlalchemy.event.listen(engine, "before_cursor_execute", before_cursor_execute)
407 sqlalchemy.event.listen(engine, "after_cursor_execute", after_cursor_execute)
409 sqlalchemy.event.listen(engine, "after_cursor_execute", after_cursor_execute)
408
410
409 return engine
411 return engine
410
412
411
413
412 def get_encryption_key(config):
414 def get_encryption_key(config):
413 secret = config.get('rhodecode.encrypted_values.secret')
415 secret = config.get('rhodecode.encrypted_values.secret')
414 default = config['beaker.session.secret']
416 default = config['beaker.session.secret']
415 return secret or default
417 return secret or default
416
418
417
419
418 def age(prevdate, now=None, show_short_version=False, show_suffix=True,
420 def age(prevdate, now=None, show_short_version=False, show_suffix=True,
419 short_format=False):
421 short_format=False):
420 """
422 """
421 Turns a datetime into an age string.
423 Turns a datetime into an age string.
422 If show_short_version is True, this generates a shorter string with
424 If show_short_version is True, this generates a shorter string with
423 an approximate age; ex. '1 day ago', rather than '1 day and 23 hours ago'.
425 an approximate age; ex. '1 day ago', rather than '1 day and 23 hours ago'.
424
426
425 * IMPORTANT*
427 * IMPORTANT*
426 Code of this function is written in special way so it's easier to
428 Code of this function is written in special way so it's easier to
427 backport it to javascript. If you mean to update it, please also update
429 backport it to javascript. If you mean to update it, please also update
428 `jquery.timeago-extension.js` file
430 `jquery.timeago-extension.js` file
429
431
430 :param prevdate: datetime object
432 :param prevdate: datetime object
431 :param now: get current time, if not define we use
433 :param now: get current time, if not define we use
432 `datetime.datetime.now()`
434 `datetime.datetime.now()`
433 :param show_short_version: if it should approximate the date and
435 :param show_short_version: if it should approximate the date and
434 return a shorter string
436 return a shorter string
435 :param show_suffix:
437 :param show_suffix:
436 :param short_format: show short format, eg 2D instead of 2 days
438 :param short_format: show short format, eg 2D instead of 2 days
437 :rtype: unicode
439 :rtype: unicode
438 :returns: unicode words describing age
440 :returns: unicode words describing age
439 """
441 """
440
442
441 def _get_relative_delta(now, prevdate):
443 def _get_relative_delta(now, prevdate):
442 base = dateutil.relativedelta.relativedelta(now, prevdate)
444 base = dateutil.relativedelta.relativedelta(now, prevdate)
443 return {
445 return {
444 'year': base.years,
446 'year': base.years,
445 'month': base.months,
447 'month': base.months,
446 'day': base.days,
448 'day': base.days,
447 'hour': base.hours,
449 'hour': base.hours,
448 'minute': base.minutes,
450 'minute': base.minutes,
449 'second': base.seconds,
451 'second': base.seconds,
450 }
452 }
451
453
452 def _is_leap_year(year):
454 def _is_leap_year(year):
453 return year % 4 == 0 and (year % 100 != 0 or year % 400 == 0)
455 return year % 4 == 0 and (year % 100 != 0 or year % 400 == 0)
454
456
455 def get_month(prevdate):
457 def get_month(prevdate):
456 return prevdate.month
458 return prevdate.month
457
459
458 def get_year(prevdate):
460 def get_year(prevdate):
459 return prevdate.year
461 return prevdate.year
460
462
461 now = now or datetime.datetime.now()
463 now = now or datetime.datetime.now()
462 order = ['year', 'month', 'day', 'hour', 'minute', 'second']
464 order = ['year', 'month', 'day', 'hour', 'minute', 'second']
463 deltas = {}
465 deltas = {}
464 future = False
466 future = False
465
467
466 if prevdate > now:
468 if prevdate > now:
467 now_old = now
469 now_old = now
468 now = prevdate
470 now = prevdate
469 prevdate = now_old
471 prevdate = now_old
470 future = True
472 future = True
471 if future:
473 if future:
472 prevdate = prevdate.replace(microsecond=0)
474 prevdate = prevdate.replace(microsecond=0)
473 # Get date parts deltas
475 # Get date parts deltas
474 for part in order:
476 for part in order:
475 rel_delta = _get_relative_delta(now, prevdate)
477 rel_delta = _get_relative_delta(now, prevdate)
476 deltas[part] = rel_delta[part]
478 deltas[part] = rel_delta[part]
477
479
478 # Fix negative offsets (there is 1 second between 10:59:59 and 11:00:00,
480 # Fix negative offsets (there is 1 second between 10:59:59 and 11:00:00,
479 # not 1 hour, -59 minutes and -59 seconds)
481 # not 1 hour, -59 minutes and -59 seconds)
480 offsets = [[5, 60], [4, 60], [3, 24]]
482 offsets = [[5, 60], [4, 60], [3, 24]]
481 for element in offsets: # seconds, minutes, hours
483 for element in offsets: # seconds, minutes, hours
482 num = element[0]
484 num = element[0]
483 length = element[1]
485 length = element[1]
484
486
485 part = order[num]
487 part = order[num]
486 carry_part = order[num - 1]
488 carry_part = order[num - 1]
487
489
488 if deltas[part] < 0:
490 if deltas[part] < 0:
489 deltas[part] += length
491 deltas[part] += length
490 deltas[carry_part] -= 1
492 deltas[carry_part] -= 1
491
493
492 # Same thing for days except that the increment depends on the (variable)
494 # Same thing for days except that the increment depends on the (variable)
493 # number of days in the month
495 # number of days in the month
494 month_lengths = [31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31]
496 month_lengths = [31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31]
495 if deltas['day'] < 0:
497 if deltas['day'] < 0:
496 if get_month(prevdate) == 2 and _is_leap_year(get_year(prevdate)):
498 if get_month(prevdate) == 2 and _is_leap_year(get_year(prevdate)):
497 deltas['day'] += 29
499 deltas['day'] += 29
498 else:
500 else:
499 deltas['day'] += month_lengths[get_month(prevdate) - 1]
501 deltas['day'] += month_lengths[get_month(prevdate) - 1]
500
502
501 deltas['month'] -= 1
503 deltas['month'] -= 1
502
504
503 if deltas['month'] < 0:
505 if deltas['month'] < 0:
504 deltas['month'] += 12
506 deltas['month'] += 12
505 deltas['year'] -= 1
507 deltas['year'] -= 1
506
508
507 # Format the result
509 # Format the result
508 if short_format:
510 if short_format:
509 fmt_funcs = {
511 fmt_funcs = {
510 'year': lambda d: u'%dy' % d,
512 'year': lambda d: u'%dy' % d,
511 'month': lambda d: u'%dm' % d,
513 'month': lambda d: u'%dm' % d,
512 'day': lambda d: u'%dd' % d,
514 'day': lambda d: u'%dd' % d,
513 'hour': lambda d: u'%dh' % d,
515 'hour': lambda d: u'%dh' % d,
514 'minute': lambda d: u'%dmin' % d,
516 'minute': lambda d: u'%dmin' % d,
515 'second': lambda d: u'%dsec' % d,
517 'second': lambda d: u'%dsec' % d,
516 }
518 }
517 else:
519 else:
518 fmt_funcs = {
520 fmt_funcs = {
519 'year': lambda d: _pluralize(u'${num} year', u'${num} years', d, mapping={'num': d}).interpolate(),
521 'year': lambda d: _pluralize(u'${num} year', u'${num} years', d, mapping={'num': d}).interpolate(),
520 'month': lambda d: _pluralize(u'${num} month', u'${num} months', d, mapping={'num': d}).interpolate(),
522 'month': lambda d: _pluralize(u'${num} month', u'${num} months', d, mapping={'num': d}).interpolate(),
521 'day': lambda d: _pluralize(u'${num} day', u'${num} days', d, mapping={'num': d}).interpolate(),
523 'day': lambda d: _pluralize(u'${num} day', u'${num} days', d, mapping={'num': d}).interpolate(),
522 'hour': lambda d: _pluralize(u'${num} hour', u'${num} hours', d, mapping={'num': d}).interpolate(),
524 'hour': lambda d: _pluralize(u'${num} hour', u'${num} hours', d, mapping={'num': d}).interpolate(),
523 'minute': lambda d: _pluralize(u'${num} minute', u'${num} minutes', d, mapping={'num': d}).interpolate(),
525 'minute': lambda d: _pluralize(u'${num} minute', u'${num} minutes', d, mapping={'num': d}).interpolate(),
524 'second': lambda d: _pluralize(u'${num} second', u'${num} seconds', d, mapping={'num': d}).interpolate(),
526 'second': lambda d: _pluralize(u'${num} second', u'${num} seconds', d, mapping={'num': d}).interpolate(),
525 }
527 }
526
528
527 i = 0
529 i = 0
528 for part in order:
530 for part in order:
529 value = deltas[part]
531 value = deltas[part]
530 if value != 0:
532 if value != 0:
531
533
532 if i < 5:
534 if i < 5:
533 sub_part = order[i + 1]
535 sub_part = order[i + 1]
534 sub_value = deltas[sub_part]
536 sub_value = deltas[sub_part]
535 else:
537 else:
536 sub_value = 0
538 sub_value = 0
537
539
538 if sub_value == 0 or show_short_version:
540 if sub_value == 0 or show_short_version:
539 _val = fmt_funcs[part](value)
541 _val = fmt_funcs[part](value)
540 if future:
542 if future:
541 if show_suffix:
543 if show_suffix:
542 return _(u'in ${ago}', mapping={'ago': _val})
544 return _(u'in ${ago}', mapping={'ago': _val})
543 else:
545 else:
544 return _(_val)
546 return _(_val)
545
547
546 else:
548 else:
547 if show_suffix:
549 if show_suffix:
548 return _(u'${ago} ago', mapping={'ago': _val})
550 return _(u'${ago} ago', mapping={'ago': _val})
549 else:
551 else:
550 return _(_val)
552 return _(_val)
551
553
552 val = fmt_funcs[part](value)
554 val = fmt_funcs[part](value)
553 val_detail = fmt_funcs[sub_part](sub_value)
555 val_detail = fmt_funcs[sub_part](sub_value)
554 mapping = {'val': val, 'detail': val_detail}
556 mapping = {'val': val, 'detail': val_detail}
555
557
556 if short_format:
558 if short_format:
557 datetime_tmpl = _(u'${val}, ${detail}', mapping=mapping)
559 datetime_tmpl = _(u'${val}, ${detail}', mapping=mapping)
558 if show_suffix:
560 if show_suffix:
559 datetime_tmpl = _(u'${val}, ${detail} ago', mapping=mapping)
561 datetime_tmpl = _(u'${val}, ${detail} ago', mapping=mapping)
560 if future:
562 if future:
561 datetime_tmpl = _(u'in ${val}, ${detail}', mapping=mapping)
563 datetime_tmpl = _(u'in ${val}, ${detail}', mapping=mapping)
562 else:
564 else:
563 datetime_tmpl = _(u'${val} and ${detail}', mapping=mapping)
565 datetime_tmpl = _(u'${val} and ${detail}', mapping=mapping)
564 if show_suffix:
566 if show_suffix:
565 datetime_tmpl = _(u'${val} and ${detail} ago', mapping=mapping)
567 datetime_tmpl = _(u'${val} and ${detail} ago', mapping=mapping)
566 if future:
568 if future:
567 datetime_tmpl = _(u'in ${val} and ${detail}', mapping=mapping)
569 datetime_tmpl = _(u'in ${val} and ${detail}', mapping=mapping)
568
570
569 return datetime_tmpl
571 return datetime_tmpl
570 i += 1
572 i += 1
571 return _(u'just now')
573 return _(u'just now')
572
574
573
575
574 def age_from_seconds(seconds):
576 def age_from_seconds(seconds):
575 seconds = safe_int(seconds) or 0
577 seconds = safe_int(seconds) or 0
576 prevdate = time_to_datetime(time.time() + seconds)
578 prevdate = time_to_datetime(time.time() + seconds)
577 return age(prevdate, show_suffix=False, show_short_version=True)
579 return age(prevdate, show_suffix=False, show_short_version=True)
578
580
579
581
580 def cleaned_uri(uri):
582 def cleaned_uri(uri):
581 """
583 """
582 Quotes '[' and ']' from uri if there is only one of them.
584 Quotes '[' and ']' from uri if there is only one of them.
583 according to RFC3986 we cannot use such chars in uri
585 according to RFC3986 we cannot use such chars in uri
584 :param uri:
586 :param uri:
585 :return: uri without this chars
587 :return: uri without this chars
586 """
588 """
587 return urllib.quote(uri, safe='@$:/')
589 return urllib.quote(uri, safe='@$:/')
588
590
589
591
590 def credentials_filter(uri):
592 def credentials_filter(uri):
591 """
593 """
592 Returns a url with removed credentials
594 Returns a url with removed credentials
593
595
594 :param uri:
596 :param uri:
595 """
597 """
596 import urlobject
598 import urlobject
597 if isinstance(uri, rhodecode.lib.encrypt.InvalidDecryptedValue):
599 if isinstance(uri, rhodecode.lib.encrypt.InvalidDecryptedValue):
598 return 'InvalidDecryptionKey'
600 return 'InvalidDecryptionKey'
599
601
600 url_obj = urlobject.URLObject(cleaned_uri(uri))
602 url_obj = urlobject.URLObject(cleaned_uri(uri))
601 url_obj = url_obj.without_password().without_username()
603 url_obj = url_obj.without_password().without_username()
602
604
603 return url_obj
605 return url_obj
604
606
605
607
606 def get_host_info(request):
608 def get_host_info(request):
607 """
609 """
608 Generate host info, to obtain full url e.g https://server.com
610 Generate host info, to obtain full url e.g https://server.com
609 use this
611 use this
610 `{scheme}://{netloc}`
612 `{scheme}://{netloc}`
611 """
613 """
612 if not request:
614 if not request:
613 return {}
615 return {}
614
616
615 qualified_home_url = request.route_url('home')
617 qualified_home_url = request.route_url('home')
616 parsed_url = urlobject.URLObject(qualified_home_url)
618 parsed_url = urlobject.URLObject(qualified_home_url)
617 decoded_path = safe_unicode(urllib.unquote(parsed_url.path.rstrip('/')))
619 decoded_path = safe_unicode(urllib.unquote(parsed_url.path.rstrip('/')))
618
620
619 return {
621 return {
620 'scheme': parsed_url.scheme,
622 'scheme': parsed_url.scheme,
621 'netloc': parsed_url.netloc+decoded_path,
623 'netloc': parsed_url.netloc+decoded_path,
622 'hostname': parsed_url.hostname,
624 'hostname': parsed_url.hostname,
623 }
625 }
624
626
625
627
626 def get_clone_url(request, uri_tmpl, repo_name, repo_id, repo_type, **override):
628 def get_clone_url(request, uri_tmpl, repo_name, repo_id, repo_type, **override):
627 qualified_home_url = request.route_url('home')
629 qualified_home_url = request.route_url('home')
628 parsed_url = urlobject.URLObject(qualified_home_url)
630 parsed_url = urlobject.URLObject(qualified_home_url)
629 decoded_path = safe_unicode(urllib.unquote(parsed_url.path.rstrip('/')))
631 decoded_path = safe_unicode(urllib.unquote(parsed_url.path.rstrip('/')))
630
632
631 args = {
633 args = {
632 'scheme': parsed_url.scheme,
634 'scheme': parsed_url.scheme,
633 'user': '',
635 'user': '',
634 'sys_user': getpass.getuser(),
636 'sys_user': getpass.getuser(),
635 # path if we use proxy-prefix
637 # path if we use proxy-prefix
636 'netloc': parsed_url.netloc+decoded_path,
638 'netloc': parsed_url.netloc+decoded_path,
637 'hostname': parsed_url.hostname,
639 'hostname': parsed_url.hostname,
638 'prefix': decoded_path,
640 'prefix': decoded_path,
639 'repo': repo_name,
641 'repo': repo_name,
640 'repoid': str(repo_id),
642 'repoid': str(repo_id),
641 'repo_type': repo_type
643 'repo_type': repo_type
642 }
644 }
643 args.update(override)
645 args.update(override)
644 args['user'] = urllib.quote(safe_str(args['user']))
646 args['user'] = urllib.quote(safe_str(args['user']))
645
647
646 for k, v in args.items():
648 for k, v in args.items():
647 uri_tmpl = uri_tmpl.replace('{%s}' % k, v)
649 uri_tmpl = uri_tmpl.replace('{%s}' % k, v)
648
650
649 # special case for SVN clone url
651 # special case for SVN clone url
650 if repo_type == 'svn':
652 if repo_type == 'svn':
651 uri_tmpl = uri_tmpl.replace('ssh://', 'svn+ssh://')
653 uri_tmpl = uri_tmpl.replace('ssh://', 'svn+ssh://')
652
654
653 # remove leading @ sign if it's present. Case of empty user
655 # remove leading @ sign if it's present. Case of empty user
654 url_obj = urlobject.URLObject(uri_tmpl)
656 url_obj = urlobject.URLObject(uri_tmpl)
655 url = url_obj.with_netloc(url_obj.netloc.lstrip('@'))
657 url = url_obj.with_netloc(url_obj.netloc.lstrip('@'))
656
658
657 return safe_unicode(url)
659 return safe_unicode(url)
658
660
659
661
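get_clone_url performs simple `{key}` placeholder substitution on the clone URI template; the standalone sketch below mirrors that substitution step only (it does not call the real function, which needs a Pyramid request), with made-up values:

    uri_tmpl = '{scheme}://{user}@{netloc}/{repo}'
    args = {'scheme': 'https', 'user': 'john',
            'netloc': 'code.example.com', 'repo': 'my/repo'}

    for k, v in args.items():
        uri_tmpl = uri_tmpl.replace('{%s}' % k, v)

    print(uri_tmpl)  # https://john@code.example.com/my/repo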
660 def get_commit_safe(repo, commit_id=None, commit_idx=None, pre_load=None,
662 def get_commit_safe(repo, commit_id=None, commit_idx=None, pre_load=None,
661 maybe_unreachable=False, reference_obj=None):
663 maybe_unreachable=False, reference_obj=None):
662 """
664 """
663 Safe version of get_commit; if the commit doesn't exist for a
665 Safe version of get_commit; if the commit doesn't exist for a
664 repository it returns a Dummy commit instead
666 repository it returns a Dummy commit instead
665
667
666 :param repo: repository instance
668 :param repo: repository instance
667 :param commit_id: commit id as str
669 :param commit_id: commit id as str
668 :param commit_idx: numeric commit index
670 :param commit_idx: numeric commit index
669 :param pre_load: optional list of commit attributes to load
671 :param pre_load: optional list of commit attributes to load
670 :param maybe_unreachable: translate unreachable commits on git repos
672 :param maybe_unreachable: translate unreachable commits on git repos
671 :param reference_obj: explicitly search via a reference obj in git. E.g "branch:123" would mean branch "123"
673 :param reference_obj: explicitly search via a reference obj in git. E.g "branch:123" would mean branch "123"
672 """
674 """
673 # TODO(skreft): remove these circular imports
675 # TODO(skreft): remove these circular imports
674 from rhodecode.lib.vcs.backends.base import BaseRepository, EmptyCommit
676 from rhodecode.lib.vcs.backends.base import BaseRepository, EmptyCommit
675 from rhodecode.lib.vcs.exceptions import RepositoryError
677 from rhodecode.lib.vcs.exceptions import RepositoryError
676 if not isinstance(repo, BaseRepository):
678 if not isinstance(repo, BaseRepository):
677 raise Exception('You must pass a Repository '
679 raise Exception('You must pass a Repository '
678 'object as first argument, got %s' % type(repo))
680 'object as first argument, got %s' % type(repo))
679
681
680 try:
682 try:
681 commit = repo.get_commit(
683 commit = repo.get_commit(
682 commit_id=commit_id, commit_idx=commit_idx, pre_load=pre_load,
684 commit_id=commit_id, commit_idx=commit_idx, pre_load=pre_load,
683 maybe_unreachable=maybe_unreachable, reference_obj=reference_obj)
685 maybe_unreachable=maybe_unreachable, reference_obj=reference_obj)
684 except (RepositoryError, LookupError):
686 except (RepositoryError, LookupError):
685 commit = EmptyCommit()
687 commit = EmptyCommit()
686 return commit
688 return commit
687
689
688
690
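A hedged usage sketch of get_commit_safe; `repo` stands for a VCS repository instance obtained elsewhere (hypothetical variable), and an unknown commit id falls back to an EmptyCommit instead of raising:

    from rhodecode.lib.utils2 import get_commit_safe

    # `repo` is assumed to be a rhodecode.lib.vcs BaseRepository instance
    commit = get_commit_safe(repo, commit_id='deadbeef' * 5)
    print(commit)  # either the real commit or an EmptyCommit placeholder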
689 def datetime_to_time(dt):
691 def datetime_to_time(dt):
690 if dt:
692 if dt:
691 return time.mktime(dt.timetuple())
693 return time.mktime(dt.timetuple())
692
694
693
695
694 def time_to_datetime(tm):
696 def time_to_datetime(tm):
695 if tm:
697 if tm:
696 if isinstance(tm, compat.string_types):
698 if isinstance(tm, compat.string_types):
697 try:
699 try:
698 tm = float(tm)
700 tm = float(tm)
699 except ValueError:
701 except ValueError:
700 return
702 return
701 return datetime.datetime.fromtimestamp(tm)
703 return datetime.datetime.fromtimestamp(tm)
702
704
703
705
704 def time_to_utcdatetime(tm):
706 def time_to_utcdatetime(tm):
705 if tm:
707 if tm:
706 if isinstance(tm, compat.string_types):
708 if isinstance(tm, compat.string_types):
707 try:
709 try:
708 tm = float(tm)
710 tm = float(tm)
709 except ValueError:
711 except ValueError:
710 return
712 return
711 return datetime.datetime.utcfromtimestamp(tm)
713 return datetime.datetime.utcfromtimestamp(tm)
712
714
713
715
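A quick sketch of the time helpers above; string input is coerced to float, and an invalid string silently returns None:

    import datetime
    from rhodecode.lib.utils2 import datetime_to_time, time_to_datetime

    ts = datetime_to_time(datetime.datetime(2020, 1, 1))
    print(time_to_datetime(ts))              # 2020-01-01 00:00:00
    print(time_to_datetime('not-a-number'))  # None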
714 MENTIONS_REGEX = re.compile(
716 MENTIONS_REGEX = re.compile(
715 # ^@ or @ without any special chars in front
717 # ^@ or @ without any special chars in front
716 r'(?:^@|[^a-zA-Z0-9\-\_\.]@)'
718 r'(?:^@|[^a-zA-Z0-9\-\_\.]@)'
717 # main body starts with letter, then can be . - _
719 # main body starts with letter, then can be . - _
718 r'([a-zA-Z0-9]{1}[a-zA-Z0-9\-\_\.]+)',
720 r'([a-zA-Z0-9]{1}[a-zA-Z0-9\-\_\.]+)',
719 re.VERBOSE | re.MULTILINE)
721 re.VERBOSE | re.MULTILINE)
720
722
721
723
722 def extract_mentioned_users(s):
724 def extract_mentioned_users(s):
723 """
725 """
724 Returns unique usernames that are @mentioned in the given string s
726 Returns unique usernames that are @mentioned in the given string s
725
727
726 :param s: string to get mentions
728 :param s: string to get mentions
727 """
729 """
728 usrs = set()
730 usrs = set()
729 for username in MENTIONS_REGEX.findall(s):
731 for username in MENTIONS_REGEX.findall(s):
730 usrs.add(username)
732 usrs.add(username)
731
733
732 return sorted(list(usrs), key=lambda k: k.lower())
734 return sorted(list(usrs), key=lambda k: k.lower())
733
735
734
736
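A minimal sketch of the @mention extraction; usernames are deduplicated and sorted case-insensitively:

    from rhodecode.lib.utils2 import extract_mentioned_users

    text = 'ping @john and @Anna, also @john again (cc @dev-team)'
    print(extract_mentioned_users(text))
    # expected: ['Anna', 'dev-team', 'john']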
735 class AttributeDictBase(dict):
737 class AttributeDictBase(dict):
736 def __getstate__(self):
738 def __getstate__(self):
737 odict = self.__dict__ # get attribute dictionary
739 odict = self.__dict__ # get attribute dictionary
738 return odict
740 return odict
739
741
740 def __setstate__(self, dict):
742 def __setstate__(self, dict):
741 self.__dict__ = dict
743 self.__dict__ = dict
742
744
743 __setattr__ = dict.__setitem__
745 __setattr__ = dict.__setitem__
744 __delattr__ = dict.__delitem__
746 __delattr__ = dict.__delitem__
745
747
746
748
747 class StrictAttributeDict(AttributeDictBase):
749 class StrictAttributeDict(AttributeDictBase):
748 """
750 """
749 Strict version of AttributeDict which raises an AttributeError when
751 Strict version of AttributeDict which raises an AttributeError when
750 the requested attribute is not set
752 the requested attribute is not set
751 """
753 """
752 def __getattr__(self, attr):
754 def __getattr__(self, attr):
753 try:
755 try:
754 return self[attr]
756 return self[attr]
755 except KeyError:
757 except KeyError:
756 raise AttributeError('%s object has no attribute %s' % (
758 raise AttributeError('%s object has no attribute %s' % (
757 self.__class__, attr))
759 self.__class__, attr))
758
760
759
761
760 class AttributeDict(AttributeDictBase):
762 class AttributeDict(AttributeDictBase):
761 def __getattr__(self, attr):
763 def __getattr__(self, attr):
762 return self.get(attr, None)
764 return self.get(attr, None)
763
765
764
766
765
767
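A short sketch contrasting the two attribute-dict flavours defined above, assuming they are imported from this module:

    from rhodecode.lib.utils2 import AttributeDict, StrictAttributeDict

    data = AttributeDict(name='docs', private=True)
    print(data.name)     # 'docs'
    print(data.missing)  # None, permissive lookup

    strict = StrictAttributeDict(name='docs')
    try:
        strict.missing
    except AttributeError as e:
        print('raised: %s' % e)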
766 class OrderedDefaultDict(collections.OrderedDict, collections.defaultdict):
768 class OrderedDefaultDict(collections.OrderedDict, collections.defaultdict):
767 def __init__(self, default_factory=None, *args, **kwargs):
769 def __init__(self, default_factory=None, *args, **kwargs):
768 # in python3 you can omit the args to super
770 # in python3 you can omit the args to super
769 super(OrderedDefaultDict, self).__init__(*args, **kwargs)
771 super(OrderedDefaultDict, self).__init__(*args, **kwargs)
770 self.default_factory = default_factory
772 self.default_factory = default_factory
771
773
772
774
773 def fix_PATH(os_=None):
775 def fix_PATH(os_=None):
774 """
776 """
775 Get the directory of the active python executable, and prepend it to the
777 Get the directory of the active python executable, and prepend it to the
776 PATH variable to fix issues with subprocess calls and different python versions
778 PATH variable to fix issues with subprocess calls and different python versions
777 """
779 """
778 if os_ is None:
780 if os_ is None:
779 import os
781 import os
780 else:
782 else:
781 os = os_
783 os = os_
782
784
783 cur_path = os.path.split(sys.executable)[0]
785 cur_path = os.path.split(sys.executable)[0]
784 if not os.environ['PATH'].startswith(cur_path):
786 if not os.environ['PATH'].startswith(cur_path):
785 os.environ['PATH'] = '%s:%s' % (cur_path, os.environ['PATH'])
787 os.environ['PATH'] = '%s:%s' % (cur_path, os.environ['PATH'])
786
788
787
789
788 def obfuscate_url_pw(engine):
790 def obfuscate_url_pw(engine):
789 _url = engine or ''
791 _url = engine or ''
790 try:
792 try:
791 _url = sqlalchemy.engine.url.make_url(engine)
793 _url = sqlalchemy.engine.url.make_url(engine)
792 if _url.password:
794 if _url.password:
793 _url.password = 'XXXXX'
795 _url.password = 'XXXXX'
794 except Exception:
796 except Exception:
795 pass
797 pass
796 return unicode(_url)
798 return unicode(_url)
797
799
798
800
799 def get_server_url(environ):
801 def get_server_url(environ):
800 req = webob.Request(environ)
802 req = webob.Request(environ)
801 return req.host_url + req.script_name
803 return req.host_url + req.script_name
802
804
803
805
804 def unique_id(hexlen=32):
806 def unique_id(hexlen=32):
805 alphabet = "23456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghjklmnpqrstuvwxyz"
807 alphabet = "23456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghjklmnpqrstuvwxyz"
806 return suuid(truncate_to=hexlen, alphabet=alphabet)
808 return suuid(truncate_to=hexlen, alphabet=alphabet)
807
809
808
810
809 def suuid(url=None, truncate_to=22, alphabet=None):
811 def suuid(url=None, truncate_to=22, alphabet=None):
810 """
812 """
811 Generate and return a short URL safe UUID.
813 Generate and return a short URL safe UUID.
812
814
813 If the url parameter is provided, set the namespace to the provided
815 If the url parameter is provided, set the namespace to the provided
814 URL and generate a UUID.
816 URL and generate a UUID.
815
817
816 :param url: url to get the uuid for
818 :param url: url to get the uuid for
817 :param truncate_to: truncate the basic 22 char UUID to a shorter version
819 :param truncate_to: truncate the basic 22 char UUID to a shorter version
818
820
819 The IDs won't be universally unique any longer, but the probability of
821 The IDs won't be universally unique any longer, but the probability of
820 a collision will still be very low.
822 a collision will still be very low.
821 """
823 """
822 # Define our alphabet.
824 # Define our alphabet.
823 _ALPHABET = alphabet or "23456789ABCDEFGHJKLMNPQRSTUVWXYZ"
825 _ALPHABET = alphabet or "23456789ABCDEFGHJKLMNPQRSTUVWXYZ"
824
826
825 # If no URL is given, generate a random UUID.
827 # If no URL is given, generate a random UUID.
826 if url is None:
828 if url is None:
827 unique_id = uuid.uuid4().int
829 unique_id = uuid.uuid4().int
828 else:
830 else:
829 unique_id = uuid.uuid3(uuid.NAMESPACE_URL, url).int
831 unique_id = uuid.uuid3(uuid.NAMESPACE_URL, url).int
830
832
831 alphabet_length = len(_ALPHABET)
833 alphabet_length = len(_ALPHABET)
832 output = []
834 output = []
833 while unique_id > 0:
835 while unique_id > 0:
834 digit = unique_id % alphabet_length
836 digit = unique_id % alphabet_length
835 output.append(_ALPHABET[digit])
837 output.append(_ALPHABET[digit])
836 unique_id = int(unique_id / alphabet_length)
838 unique_id = int(unique_id / alphabet_length)
837 return "".join(output)[:truncate_to]
839 return "".join(output)[:truncate_to]
838
840
839
841
840 def get_current_rhodecode_user(request=None):
842 def get_current_rhodecode_user(request=None):
841 """
843 """
842 Gets rhodecode user from request
844 Gets rhodecode user from request
843 """
845 """
844 pyramid_request = request or pyramid.threadlocal.get_current_request()
846 pyramid_request = request or pyramid.threadlocal.get_current_request()
845
847
846 # web case
848 # web case
847 if pyramid_request and hasattr(pyramid_request, 'user'):
849 if pyramid_request and hasattr(pyramid_request, 'user'):
848 return pyramid_request.user
850 return pyramid_request.user
849
851
850 # api case
852 # api case
851 if pyramid_request and hasattr(pyramid_request, 'rpc_user'):
853 if pyramid_request and hasattr(pyramid_request, 'rpc_user'):
852 return pyramid_request.rpc_user
854 return pyramid_request.rpc_user
853
855
854 return None
856 return None
855
857
856
858
857 def action_logger_generic(action, namespace=''):
859 def action_logger_generic(action, namespace=''):
858 """
860 """
859 A generic logger for actions useful to the system overview; tries to find
861 A generic logger for actions useful to the system overview; tries to find
860 an acting user for the context of the call, otherwise reports an unknown user
862 an acting user for the context of the call, otherwise reports an unknown user
861
863
862 :param action: logging message eg 'comment 5 deleted'
864 :param action: logging message eg 'comment 5 deleted'
863 :type action: string
865 :type action: string
864
866
865 :param namespace: namespace of the logging message eg. 'repo.comments'
867 :param namespace: namespace of the logging message eg. 'repo.comments'
866 :type namespace: string
868 :type namespace: string
867
869
868 """
870 """
869
871
870 logger_name = 'rhodecode.actions'
872 logger_name = 'rhodecode.actions'
871
873
872 if namespace:
874 if namespace:
873 logger_name += '.' + namespace
875 logger_name += '.' + namespace
874
876
875 log = logging.getLogger(logger_name)
877 log = logging.getLogger(logger_name)
876
878
877 # get a user if we can
879 # get a user if we can
878 user = get_current_rhodecode_user()
880 user = get_current_rhodecode_user()
879
881
880 logfunc = log.info
882 logfunc = log.info
881
883
882 if not user:
884 if not user:
883 user = '<unknown user>'
885 user = '<unknown user>'
884 logfunc = log.warning
886 logfunc = log.warning
885
887
886 logfunc('Logging action by {}: {}'.format(user, action))
888 logfunc('Logging action by {}: {}'.format(user, action))
887
889
888
890
889 def escape_split(text, sep=',', maxsplit=-1):
891 def escape_split(text, sep=',', maxsplit=-1):
890 r"""
892 r"""
891 Allows for escaping of the separator: e.g. arg='foo\, bar'
893 Allows for escaping of the separator: e.g. arg='foo\, bar'
892
894
893 It should be noted that, given the way bash et al. do command line parsing,
895 It should be noted that, given the way bash et al. do command line parsing,
894 those single quotes are required.
896 those single quotes are required.
895 """
897 """
896 escaped_sep = r'\%s' % sep
898 escaped_sep = r'\%s' % sep
897
899
898 if escaped_sep not in text:
900 if escaped_sep not in text:
899 return text.split(sep, maxsplit)
901 return text.split(sep, maxsplit)
900
902
901 before, _mid, after = text.partition(escaped_sep)
903 before, _mid, after = text.partition(escaped_sep)
902 startlist = before.split(sep, maxsplit) # a regular split is fine here
904 startlist = before.split(sep, maxsplit) # a regular split is fine here
903 unfinished = startlist[-1]
905 unfinished = startlist[-1]
904 startlist = startlist[:-1]
906 startlist = startlist[:-1]
905
907
906 # recurse because there may be more escaped separators
908 # recurse because there may be more escaped separators
907 endlist = escape_split(after, sep, maxsplit)
909 endlist = escape_split(after, sep, maxsplit)
908
910
909 # finish building the escaped value. we use endlist[0] because the first
911 # finish building the escaped value. we use endlist[0] because the first
910 # part of the string sent in recursion is the rest of the escaped value.
912 # part of the string sent in recursion is the rest of the escaped value.
911 unfinished += sep + endlist[0]
913 unfinished += sep + endlist[0]
912
914
913 return startlist + [unfinished] + endlist[1:] # put together all the parts
915 return startlist + [unfinished] + endlist[1:] # put together all the parts
914
916
915
917
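A minimal sketch of escape_split; an escaped separator (`\,`) is kept inside a single element while plain separators still split:

    from rhodecode.lib.utils2 import escape_split

    print(escape_split('foo, bar'))         # ['foo', ' bar']
    print(escape_split(r'foo\, bar, baz'))  # ['foo, bar', ' baz']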
916 class OptionalAttr(object):
918 class OptionalAttr(object):
917 """
919 """
918 Special Optional Option that defines other attribute. Example::
920 Special Optional Option that defines other attribute. Example::
919
921
920 def test(apiuser, userid=Optional(OAttr('apiuser'))):
922 def test(apiuser, userid=Optional(OAttr('apiuser'))):
921 user = Optional.extract(userid)
923 user = Optional.extract(userid)
922 # calls
924 # calls
923
925
924 """
926 """
925
927
926 def __init__(self, attr_name):
928 def __init__(self, attr_name):
927 self.attr_name = attr_name
929 self.attr_name = attr_name
928
930
929 def __repr__(self):
931 def __repr__(self):
930 return '<OptionalAttr:%s>' % self.attr_name
932 return '<OptionalAttr:%s>' % self.attr_name
931
933
932 def __call__(self):
934 def __call__(self):
933 return self
935 return self
934
936
935
937
936 # alias
938 # alias
937 OAttr = OptionalAttr
939 OAttr = OptionalAttr
938
940
939
941
940 class Optional(object):
942 class Optional(object):
941 """
943 """
942 Defines an optional parameter::
944 Defines an optional parameter::
943
945
944 param = param.getval() if isinstance(param, Optional) else param
946 param = param.getval() if isinstance(param, Optional) else param
945 param = param() if isinstance(param, Optional) else param
947 param = param() if isinstance(param, Optional) else param
946
948
947 is equivalent to::
949 is equivalent to::
948
950
949 param = Optional.extract(param)
951 param = Optional.extract(param)
950
952
951 """
953 """
952
954
953 def __init__(self, type_):
955 def __init__(self, type_):
954 self.type_ = type_
956 self.type_ = type_
955
957
956 def __repr__(self):
958 def __repr__(self):
957 return '<Optional:%s>' % self.type_.__repr__()
959 return '<Optional:%s>' % self.type_.__repr__()
958
960
959 def __call__(self):
961 def __call__(self):
960 return self.getval()
962 return self.getval()
961
963
962 def getval(self):
964 def getval(self):
963 """
965 """
964 returns value from this Optional instance
966 returns value from this Optional instance
965 """
967 """
966 if isinstance(self.type_, OAttr):
968 if isinstance(self.type_, OAttr):
967 # use params name
969 # use params name
968 return self.type_.attr_name
970 return self.type_.attr_name
969 return self.type_
971 return self.type_
970
972
971 @classmethod
973 @classmethod
972 def extract(cls, val):
974 def extract(cls, val):
973 """
975 """
974 Extracts value from Optional() instance
976 Extracts value from Optional() instance
975
977
976 :param val:
978 :param val:
977 :return: original value if it's not Optional instance else
979 :return: original value if it's not Optional instance else
978 value of instance
980 value of instance
979 """
981 """
980 if isinstance(val, cls):
982 if isinstance(val, cls):
981 return val.getval()
983 return val.getval()
982 return val
984 return val
983
985
984
986
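A compact sketch of how defaults built on Optional/OAttr are unwrapped; the function name below is hypothetical:

    from rhodecode.lib.utils2 import Optional, OAttr

    def get_user(apiuser, userid=Optional(OAttr('apiuser'))):
        # with no explicit userid the OAttr name 'apiuser' is returned
        return Optional.extract(userid)

    print(get_user('admin'))          # 'apiuser'
    print(get_user('admin', 'john'))  # 'john'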
985 def glob2re(pat):
987 def glob2re(pat):
986 """
988 """
987 Translate a shell PATTERN to a regular expression.
989 Translate a shell PATTERN to a regular expression.
988
990
989 There is no way to quote meta-characters.
991 There is no way to quote meta-characters.
990 """
992 """
991
993
992 i, n = 0, len(pat)
994 i, n = 0, len(pat)
993 res = ''
995 res = ''
994 while i < n:
996 while i < n:
995 c = pat[i]
997 c = pat[i]
996 i = i+1
998 i = i+1
997 if c == '*':
999 if c == '*':
998 #res = res + '.*'
1000 #res = res + '.*'
999 res = res + '[^/]*'
1001 res = res + '[^/]*'
1000 elif c == '?':
1002 elif c == '?':
1001 #res = res + '.'
1003 #res = res + '.'
1002 res = res + '[^/]'
1004 res = res + '[^/]'
1003 elif c == '[':
1005 elif c == '[':
1004 j = i
1006 j = i
1005 if j < n and pat[j] == '!':
1007 if j < n and pat[j] == '!':
1006 j = j+1
1008 j = j+1
1007 if j < n and pat[j] == ']':
1009 if j < n and pat[j] == ']':
1008 j = j+1
1010 j = j+1
1009 while j < n and pat[j] != ']':
1011 while j < n and pat[j] != ']':
1010 j = j+1
1012 j = j+1
1011 if j >= n:
1013 if j >= n:
1012 res = res + '\\['
1014 res = res + '\\['
1013 else:
1015 else:
1014 stuff = pat[i:j].replace('\\','\\\\')
1016 stuff = pat[i:j].replace('\\','\\\\')
1015 i = j+1
1017 i = j+1
1016 if stuff[0] == '!':
1018 if stuff[0] == '!':
1017 stuff = '^' + stuff[1:]
1019 stuff = '^' + stuff[1:]
1018 elif stuff[0] == '^':
1020 elif stuff[0] == '^':
1019 stuff = '\\' + stuff
1021 stuff = '\\' + stuff
1020 res = '%s[%s]' % (res, stuff)
1022 res = '%s[%s]' % (res, stuff)
1021 else:
1023 else:
1022 res = res + re.escape(c)
1024 res = res + re.escape(c)
1023 return res + r'\Z(?ms)'
1025 return res + r'\Z(?ms)'
1024
1026
1025
1027
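A small sketch of glob2re; `*` and `?` deliberately do not cross `/` boundaries, unlike `.*` in a plain regex:

    import re
    from rhodecode.lib.utils2 import glob2re

    pattern = glob2re('docs/*.rst')
    print(bool(re.match(pattern, 'docs/index.rst')))      # True
    print(bool(re.match(pattern, 'docs/sub/index.rst')))  # False, '*' stops at '/'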
1026 def parse_byte_string(size_str):
1028 def parse_byte_string(size_str):
1027 match = re.match(r'(\d+)(MB|KB)', size_str, re.IGNORECASE)
1029 match = re.match(r'(\d+)(MB|KB)', size_str, re.IGNORECASE)
1028 if not match:
1030 if not match:
1029 raise ValueError('Given size:%s is invalid, please make sure '
1031 raise ValueError('Given size:%s is invalid, please make sure '
1030 'to use format of <num>(MB|KB)' % size_str)
1032 'to use format of <num>(MB|KB)' % size_str)
1031
1033
1032 _parts = match.groups()
1034 _parts = match.groups()
1033 num, type_ = _parts
1035 num, type_ = _parts
1034 return long(num) * {'mb': 1024*1024, 'kb': 1024}[type_.lower()]
1036 return long(num) * {'mb': 1024*1024, 'kb': 1024}[type_.lower()]
1035
1037
1036
1038
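A quick sketch of parse_byte_string; only the <num>KB / <num>MB forms are accepted, and anything else raises ValueError:

    from rhodecode.lib.utils2 import parse_byte_string

    print(parse_byte_string('512KB'))  # 524288
    print(parse_byte_string('2mb'))    # 2097152, the unit is case-insensitive
    try:
        parse_byte_string('10GB')
    except ValueError as e:
        print('rejected: %s' % e)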
1037 class CachedProperty(object):
1039 class CachedProperty(object):
1038 """
1040 """
1039 Lazy Attributes. With option to invalidate the cache by running a method
1041 Lazy Attributes. With option to invalidate the cache by running a method
1040
1042
1041 >>> class Foo(object):
1043 >>> class Foo(object):
1042 ...
1044 ...
1043 ... @CachedProperty
1045 ... @CachedProperty
1044 ... def heavy_func(self):
1046 ... def heavy_func(self):
1045 ... return 'super-calculation'
1047 ... return 'super-calculation'
1046 ...
1048 ...
1047 ... foo = Foo()
1049 ... foo = Foo()
1048 ... foo.heavy_func # first computation
1050 ... foo.heavy_func # first computation
1049 ... foo.heavy_func # fetch from cache
1051 ... foo.heavy_func # fetch from cache
1050 ... foo._invalidate_prop_cache('heavy_func')
1052 ... foo._invalidate_prop_cache('heavy_func')
1051
1053
1052 # at this point accessing foo.heavy_func will be re-computed
1054 # at this point accessing foo.heavy_func will be re-computed
1053 """
1055 """
1054
1056
1055 def __init__(self, func, func_name=None):
1057 def __init__(self, func, func_name=None):
1056
1058
1057 if func_name is None:
1059 if func_name is None:
1058 func_name = func.__name__
1060 func_name = func.__name__
1059 self.data = (func, func_name)
1061 self.data = (func, func_name)
1060 update_wrapper(self, func)
1062 update_wrapper(self, func)
1061
1063
1062 def __get__(self, inst, class_):
1064 def __get__(self, inst, class_):
1063 if inst is None:
1065 if inst is None:
1064 return self
1066 return self
1065
1067
1066 func, func_name = self.data
1068 func, func_name = self.data
1067 value = func(inst)
1069 value = func(inst)
1068 inst.__dict__[func_name] = value
1070 inst.__dict__[func_name] = value
1069 if '_invalidate_prop_cache' not in inst.__dict__:
1071 if '_invalidate_prop_cache' not in inst.__dict__:
1070 inst.__dict__['_invalidate_prop_cache'] = partial(
1072 inst.__dict__['_invalidate_prop_cache'] = partial(
1071 self._invalidate_prop_cache, inst)
1073 self._invalidate_prop_cache, inst)
1072 return value
1074 return value
1073
1075
1074 def _invalidate_prop_cache(self, inst, name):
1076 def _invalidate_prop_cache(self, inst, name):
1075 inst.__dict__.pop(name, None)
1077 inst.__dict__.pop(name, None)
1076
1078
1077
1079
1078 def retry(func=None, exception=Exception, n_tries=5, delay=5, backoff=1, logger=True):
1080 def retry(func=None, exception=Exception, n_tries=5, delay=5, backoff=1, logger=True):
1079 """
1081 """
1080 Retry decorator with exponential backoff.
1082 Retry decorator with exponential backoff.
1081
1083
1082 Parameters
1084 Parameters
1083 ----------
1085 ----------
1084 func : typing.Callable, optional
1086 func : typing.Callable, optional
1085 Callable on which the decorator is applied, by default None
1087 Callable on which the decorator is applied, by default None
1086 exception : Exception or tuple of Exceptions, optional
1088 exception : Exception or tuple of Exceptions, optional
1087 Exception(s) that invoke retry, by default Exception
1089 Exception(s) that invoke retry, by default Exception
1088 n_tries : int, optional
1090 n_tries : int, optional
1089 Number of tries before giving up, by default 5
1091 Number of tries before giving up, by default 5
1090 delay : int, optional
1092 delay : int, optional
1091 Initial delay between retries in seconds, by default 5
1093 Initial delay between retries in seconds, by default 5
1092 backoff : int, optional
1094 backoff : int, optional
1093 Backoff multiplier e.g. value of 2 will double the delay, by default 1
1095 Backoff multiplier e.g. value of 2 will double the delay, by default 1
1094 logger : bool, optional
1096 logger : bool, optional
1095 Option to log (True) or print (False), by default True
1097 Option to log (True) or print (False), by default True
1096
1098
1097 Returns
1099 Returns
1098 -------
1100 -------
1099 typing.Callable
1101 typing.Callable
1100 Decorated callable that calls itself when exception(s) occur.
1102 Decorated callable that calls itself when exception(s) occur.
1101
1103
1102 Examples
1104 Examples
1103 --------
1105 --------
1104 >>> import random
1106 >>> import random
1105 >>> @retry(exception=Exception, n_tries=3)
1107 >>> @retry(exception=Exception, n_tries=3)
1106 ... def test_random(text):
1108 ... def test_random(text):
1107 ... x = random.random()
1109 ... x = random.random()
1108 ... if x < 0.5:
1110 ... if x < 0.5:
1109 ... raise Exception("Fail")
1111 ... raise Exception("Fail")
1110 ... else:
1112 ... else:
1111 ... print("Success: ", text)
1113 ... print("Success: ", text)
1112 >>> test_random("It works!")
1114 >>> test_random("It works!")
1113 """
1115 """
1114
1116
1115 if func is None:
1117 if func is None:
1116 return partial(
1118 return partial(
1117 retry,
1119 retry,
1118 exception=exception,
1120 exception=exception,
1119 n_tries=n_tries,
1121 n_tries=n_tries,
1120 delay=delay,
1122 delay=delay,
1121 backoff=backoff,
1123 backoff=backoff,
1122 logger=logger,
1124 logger=logger,
1123 )
1125 )
1124
1126
1125 @wraps(func)
1127 @wraps(func)
1126 def wrapper(*args, **kwargs):
1128 def wrapper(*args, **kwargs):
1127 _n_tries, n_delay = n_tries, delay
1129 _n_tries, n_delay = n_tries, delay
1128 log = logging.getLogger('rhodecode.retry')
1130 log = logging.getLogger('rhodecode.retry')
1129
1131
1130 while _n_tries > 1:
1132 while _n_tries > 1:
1131 try:
1133 try:
1132 return func(*args, **kwargs)
1134 return func(*args, **kwargs)
1133 except exception as e:
1135 except exception as e:
1134 e_details = repr(e)
1136 e_details = repr(e)
1135 msg = "Exception on calling func {func}: {e}, " \
1137 msg = "Exception on calling func {func}: {e}, " \
1136 "Retrying in {n_delay} seconds..."\
1138 "Retrying in {n_delay} seconds..."\
1137 .format(func=func, e=e_details, n_delay=n_delay)
1139 .format(func=func, e=e_details, n_delay=n_delay)
1138 if logger:
1140 if logger:
1139 log.warning(msg)
1141 log.warning(msg)
1140 else:
1142 else:
1141 print(msg)
1143 print(msg)
1142 time.sleep(n_delay)
1144 time.sleep(n_delay)
1143 _n_tries -= 1
1145 _n_tries -= 1
1144 n_delay *= backoff
1146 n_delay *= backoff
1145
1147
1146 return func(*args, **kwargs)
1148 return func(*args, **kwargs)
1147
1149
1148 return wrapper
1150 return wrapper
1149
1151
1150
1152
1151 def user_agent_normalizer(user_agent_raw, safe=True):
1153 def user_agent_normalizer(user_agent_raw, safe=True):
1152 log = logging.getLogger('rhodecode.user_agent_normalizer')
1154 log = logging.getLogger('rhodecode.user_agent_normalizer')
1153 ua = (user_agent_raw or '').strip().lower()
1155 ua = (user_agent_raw or '').strip().lower()
1154 ua = ua.replace('"', '')
1156 ua = ua.replace('"', '')
1155
1157
1156 try:
1158 try:
1157 if 'mercurial/proto-1.0' in ua:
1159 if 'mercurial/proto-1.0' in ua:
1158 ua = ua.replace('mercurial/proto-1.0', '')
1160 ua = ua.replace('mercurial/proto-1.0', '')
1159 ua = ua.replace('(', '').replace(')', '').strip()
1161 ua = ua.replace('(', '').replace(')', '').strip()
1160 ua = ua.replace('mercurial ', 'mercurial/')
1162 ua = ua.replace('mercurial ', 'mercurial/')
1161 elif ua.startswith('git'):
1163 elif ua.startswith('git'):
1162 parts = ua.split(' ')
1164 parts = ua.split(' ')
1163 if parts:
1165 if parts:
1164 ua = parts[0]
1166 ua = parts[0]
1165 ua = re.sub(r'\.windows\.\d', '', ua).strip()
1167 ua = re.sub(r'\.windows\.\d', '', ua).strip()
1166
1168
1167 return ua
1169 return ua
1168 except Exception:
1170 except Exception:
1169 log.exception('Failed to parse scm user-agent')
1171 log.exception('Failed to parse scm user-agent')
1170 if not safe:
1172 if not safe:
1171 raise
1173 raise
1172
1174
1173 return ua
1175 return ua
1176
1177
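A hedged sketch of how raw SCM user-agent strings are normalised; the input strings are illustrative, and the outputs follow from the branches above:

    from rhodecode.lib.utils2 import user_agent_normalizer

    print(user_agent_normalizer('mercurial/proto-1.0 (Mercurial 4.9)'))
    # -> 'mercurial/4.9'
    print(user_agent_normalizer('git/2.30.0.windows.1 extra-token'))
    # -> 'git/2.30.0'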
1178 def get_available_port(min_port=40000, max_port=55555):
1179 sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
1180 hostname = '127.0.0.1'
1181 pick_port = min_port
1182
1183 for _ in range(min_port, max_port):
1184 pick_port = random.randint(min_port, max_port)
1185 try:
1186 sock.bind((hostname, pick_port))
1187 sock.close()
1188 break
1189 except (OSError, socket.error):  # socket.error, which is not an OSError subclass on Python 2
1190 pass
1191
1192 del sock
1193 return pick_port
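A minimal sketch of the unified port picker added above; the returned port is free at probe time but may be taken again before it is actually used (the usual check-then-use caveat):

    from rhodecode.lib.utils2 import get_available_port

    port = get_available_port(min_port=40000, max_port=40100)
    print('starting test server on 127.0.0.1:%s' % port)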
@@ -1,295 +1,285 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2
2
3 # Copyright (C) 2010-2020 RhodeCode GmbH
3 # Copyright (C) 2010-2020 RhodeCode GmbH
4 #
4 #
5 # This program is free software: you can redistribute it and/or modify
5 # This program is free software: you can redistribute it and/or modify
6 # it under the terms of the GNU Affero General Public License, version 3
6 # it under the terms of the GNU Affero General Public License, version 3
7 # (only), as published by the Free Software Foundation.
7 # (only), as published by the Free Software Foundation.
8 #
8 #
9 # This program is distributed in the hope that it will be useful,
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 # GNU General Public License for more details.
12 # GNU General Public License for more details.
13 #
13 #
14 # You should have received a copy of the GNU Affero General Public License
14 # You should have received a copy of the GNU Affero General Public License
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 #
16 #
17 # This program is dual-licensed. If you wish to learn more about the
17 # This program is dual-licensed. If you wish to learn more about the
18 # RhodeCode Enterprise Edition, including its added features, Support services,
18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20
20
21 import json
21 import json
22 import platform
22 import platform
23 import socket
23 import socket
24 import random
24 import random
25 import pytest
25 import pytest
26
26
27 from rhodecode.lib.pyramid_utils import get_app_config
27 from rhodecode.lib.pyramid_utils import get_app_config
28 from rhodecode.tests.fixture import TestINI
28 from rhodecode.tests.fixture import TestINI
29 from rhodecode.tests.server_utils import RcVCSServer
29 from rhodecode.tests.server_utils import RcVCSServer
30
30
31
31
32 def _parse_json(value):
32 def _parse_json(value):
33 return json.loads(value) if value else None
33 return json.loads(value) if value else None
34
34
35
35
36 def pytest_addoption(parser):
36 def pytest_addoption(parser):
37 parser.addoption(
37 parser.addoption(
38 '--test-loglevel', dest='test_loglevel',
38 '--test-loglevel', dest='test_loglevel',
39 help="Set default Logging level for tests, critical(default), error, warn , info, debug")
39 help="Set default Logging level for tests, critical(default), error, warn , info, debug")
40 group = parser.getgroup('pylons')
40 group = parser.getgroup('pylons')
41 group.addoption(
41 group.addoption(
42 '--with-pylons', dest='pyramid_config',
42 '--with-pylons', dest='pyramid_config',
43 help="Set up a Pylons environment with the specified config file.")
43 help="Set up a Pylons environment with the specified config file.")
44 group.addoption(
44 group.addoption(
45 '--ini-config-override', action='store', type=_parse_json,
45 '--ini-config-override', action='store', type=_parse_json,
46 default=None, dest='pyramid_config_override', help=(
46 default=None, dest='pyramid_config_override', help=(
47 "Overrides the .ini file settings. Should be specified in JSON"
47 "Overrides the .ini file settings. Should be specified in JSON"
48 " format, e.g. '{\"section\": {\"parameter\": \"value\", ...}}'"
48 " format, e.g. '{\"section\": {\"parameter\": \"value\", ...}}'"
49 )
49 )
50 )
50 )
51 parser.addini(
51 parser.addini(
52 'pyramid_config',
52 'pyramid_config',
53 "Set up a Pyramid environment with the specified config file.")
53 "Set up a Pyramid environment with the specified config file.")
54
54
55 vcsgroup = parser.getgroup('vcs')
55 vcsgroup = parser.getgroup('vcs')
56 vcsgroup.addoption(
56 vcsgroup.addoption(
57 '--without-vcsserver', dest='with_vcsserver', action='store_false',
57 '--without-vcsserver', dest='with_vcsserver', action='store_false',
58 help="Do not start the VCSServer in a background process.")
58 help="Do not start the VCSServer in a background process.")
59 vcsgroup.addoption(
59 vcsgroup.addoption(
60 '--with-vcsserver-http', dest='vcsserver_config_http',
60 '--with-vcsserver-http', dest='vcsserver_config_http',
61 help="Start the HTTP VCSServer with the specified config file.")
61 help="Start the HTTP VCSServer with the specified config file.")
62 vcsgroup.addoption(
62 vcsgroup.addoption(
63 '--vcsserver-protocol', dest='vcsserver_protocol',
63 '--vcsserver-protocol', dest='vcsserver_protocol',
64 help="Start the VCSServer with HTTP protocol support.")
64 help="Start the VCSServer with HTTP protocol support.")
65 vcsgroup.addoption(
65 vcsgroup.addoption(
66 '--vcsserver-config-override', action='store', type=_parse_json,
66 '--vcsserver-config-override', action='store', type=_parse_json,
67 default=None, dest='vcsserver_config_override', help=(
67 default=None, dest='vcsserver_config_override', help=(
68 "Overrides the .ini file settings for the VCSServer. "
68 "Overrides the .ini file settings for the VCSServer. "
69 "Should be specified in JSON "
69 "Should be specified in JSON "
70 "format, e.g. '{\"section\": {\"parameter\": \"value\", ...}}'"
70 "format, e.g. '{\"section\": {\"parameter\": \"value\", ...}}'"
71 )
71 )
72 )
72 )
73 vcsgroup.addoption(
73 vcsgroup.addoption(
74 '--vcsserver-port', action='store', type=int,
74 '--vcsserver-port', action='store', type=int,
75 default=None, help=(
75 default=None, help=(
76 "Allows to set the port of the vcsserver. Useful when testing "
76 "Allows to set the port of the vcsserver. Useful when testing "
77 "against an already running server and random ports cause "
77 "against an already running server and random ports cause "
78 "trouble."))
78 "trouble."))
79 parser.addini(
79 parser.addini(
80 'vcsserver_config_http',
80 'vcsserver_config_http',
81 "Start the HTTP VCSServer with the specified config file.")
81 "Start the HTTP VCSServer with the specified config file.")
82 parser.addini(
82 parser.addini(
83 'vcsserver_protocol',
83 'vcsserver_protocol',
84 "Start the VCSServer with HTTP protocol support.")
84 "Start the VCSServer with HTTP protocol support.")
85
85
86
86
87 @pytest.fixture(scope='session')
87 @pytest.fixture(scope='session')
88 def vcsserver(request, vcsserver_port, vcsserver_factory):
88 def vcsserver(request, vcsserver_port, vcsserver_factory):
89 """
89 """
90 Session scope VCSServer.
90 Session scope VCSServer.
91
91
92 Tests which need the VCSServer have to rely on this fixture in order
92 Tests which need the VCSServer have to rely on this fixture in order
93 to ensure it will be running.
93 to ensure it will be running.
94
94
95 For specific needs, the fixture vcsserver_factory can be used. It allows
95 For specific needs, the fixture vcsserver_factory can be used. It allows
96 adjusting the configuration file for the test run.
96 adjusting the configuration file for the test run.
97
97
98 Command line args:
98 Command line args:
99
99
100 --without-vcsserver: Allows switching this fixture off. You have to
100 --without-vcsserver: Allows switching this fixture off. You have to
101 manually start the server.
101 manually start the server.
102
102
103 --vcsserver-port: Will expect the VCSServer to listen on this port.
103 --vcsserver-port: Will expect the VCSServer to listen on this port.
104 """
104 """
105
105
106 if not request.config.getoption('with_vcsserver'):
106 if not request.config.getoption('with_vcsserver'):
107 return None
107 return None
108
108
109 return vcsserver_factory(
109 return vcsserver_factory(
110 request, vcsserver_port=vcsserver_port)
110 request, vcsserver_port=vcsserver_port)
111
111
112
112
113 @pytest.fixture(scope='session')
113 @pytest.fixture(scope='session')
114 def vcsserver_factory(tmpdir_factory):
114 def vcsserver_factory(tmpdir_factory):
115 """
115 """
116 Use this if you need a running vcsserver with a special configuration.
116 Use this if you need a running vcsserver with a special configuration.
117 """
117 """
118
118
119 def factory(request, overrides=(), vcsserver_port=None,
119 def factory(request, overrides=(), vcsserver_port=None,
120 log_file=None):
120 log_file=None):
121
121
122 if vcsserver_port is None:
122 if vcsserver_port is None:
123 vcsserver_port = get_available_port()
123 vcsserver_port = get_available_port()
124
124
125 overrides = list(overrides)
125 overrides = list(overrides)
126 overrides.append({'server:main': {'port': vcsserver_port}})
126 overrides.append({'server:main': {'port': vcsserver_port}})
127
127
128 option_name = 'vcsserver_config_http'
128 option_name = 'vcsserver_config_http'
129 override_option_name = 'vcsserver_config_override'
129 override_option_name = 'vcsserver_config_override'
130 config_file = get_config(
130 config_file = get_config(
131 request.config, option_name=option_name,
131 request.config, option_name=option_name,
132 override_option_name=override_option_name, overrides=overrides,
132 override_option_name=override_option_name, overrides=overrides,
133 basetemp=tmpdir_factory.getbasetemp().strpath,
133 basetemp=tmpdir_factory.getbasetemp().strpath,
134 prefix='test_vcs_')
134 prefix='test_vcs_')
135
135
136 server = RcVCSServer(config_file, log_file)
136 server = RcVCSServer(config_file, log_file)
137 server.start()
137 server.start()
138
138
139 @request.addfinalizer
139 @request.addfinalizer
140 def cleanup():
140 def cleanup():
141 server.shutdown()
141 server.shutdown()
142
142
143 server.wait_until_ready()
143 server.wait_until_ready()
144 return server
144 return server
145
145
146 return factory
146 return factory
147
147
148
148
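A hedged sketch of how a test might use the vcsserver_factory fixture directly; the test name is hypothetical, and only arguments visible in the factory above are used:

    def test_custom_vcsserver(request, vcsserver_factory, available_port_factory):
        port = available_port_factory()
        # the factory injects the port into the generated INI and starts the server;
        # shutdown is registered as a finalizer on `request`
        server = vcsserver_factory(request, vcsserver_port=port)
        assert server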
149 def is_cygwin():
149 def is_cygwin():
150 return 'cygwin' in platform.system().lower()
150 return 'cygwin' in platform.system().lower()
151
151
152
152
153 def _use_log_level(config):
153 def _use_log_level(config):
154 level = config.getoption('test_loglevel') or 'critical'
154 level = config.getoption('test_loglevel') or 'critical'
155 return level.upper()
155 return level.upper()
156
156
157
157
158 @pytest.fixture(scope='session')
158 @pytest.fixture(scope='session')
159 def ini_config(request, tmpdir_factory, rcserver_port, vcsserver_port):
159 def ini_config(request, tmpdir_factory, rcserver_port, vcsserver_port):
160 option_name = 'pyramid_config'
160 option_name = 'pyramid_config'
161 log_level = _use_log_level(request.config)
161 log_level = _use_log_level(request.config)
162
162
163 overrides = [
163 overrides = [
164 {'server:main': {'port': rcserver_port}},
164 {'server:main': {'port': rcserver_port}},
165 {'app:main': {
165 {'app:main': {
166 'vcs.server': 'localhost:%s' % vcsserver_port,
166 'vcs.server': 'localhost:%s' % vcsserver_port,
167 # johbo: We will always start the VCSServer on our own based on the
167 # johbo: We will always start the VCSServer on our own based on the
168 # fixtures of the test cases. For the test run it must always be
168 # fixtures of the test cases. For the test run it must always be
169 # off in the INI file.
169 # off in the INI file.
170 'vcs.start_server': 'false',
170 'vcs.start_server': 'false',
171
171
172 'vcs.server.protocol': 'http',
172 'vcs.server.protocol': 'http',
173 'vcs.scm_app_implementation': 'http',
173 'vcs.scm_app_implementation': 'http',
174 'vcs.hooks.protocol': 'http',
174 'vcs.hooks.protocol': 'http',
175 'vcs.hooks.host': '127.0.0.1',
175 'vcs.hooks.host': '127.0.0.1',
176 }},
176 }},
177
177
178 {'handler_console': {
178 {'handler_console': {
179 'class ': 'StreamHandler',
179 'class ': 'StreamHandler',
180 'args ': '(sys.stderr,)',
180 'args ': '(sys.stderr,)',
181 'level': log_level,
181 'level': log_level,
182 }},
182 }},
183
183
184 ]
184 ]
185
185
186 filename = get_config(
186 filename = get_config(
187 request.config, option_name=option_name,
187 request.config, option_name=option_name,
188 override_option_name='{}_override'.format(option_name),
188 override_option_name='{}_override'.format(option_name),
189 overrides=overrides,
189 overrides=overrides,
190 basetemp=tmpdir_factory.getbasetemp().strpath,
190 basetemp=tmpdir_factory.getbasetemp().strpath,
191 prefix='test_rce_')
191 prefix='test_rce_')
192 return filename
192 return filename
193
193
194
194
195 @pytest.fixture(scope='session')
195 @pytest.fixture(scope='session')
196 def ini_settings(ini_config):
196 def ini_settings(ini_config):
197 ini_path = ini_config
197 ini_path = ini_config
198 return get_app_config(ini_path)
198 return get_app_config(ini_path)
199
199
200
200
201 def get_available_port(min_port=40000, max_port=55555):
201 def get_available_port(min_port=40000, max_port=55555):
202 sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
202 from rhodecode.lib.utils2 import get_available_port as _get_port
203 hostname = '127.0.0.1'
203 return _get_port(min_port, max_port)
204
205 for _ in range(min_port, max_port):
206 pick_port = random.randint(min_port, max_port)
207 try:
208 sock.bind((hostname, pick_port))
209 sock.close()
210 del sock
211 return pick_port
212 except OSError:
213 pass
214
204
215
205
216 @pytest.fixture(scope='session')
206 @pytest.fixture(scope='session')
217 def rcserver_port(request):
207 def rcserver_port(request):
218 port = get_available_port()
208 port = get_available_port()
219 print('Using rcserver port {}'.format(port))
209 print('Using rcserver port {}'.format(port))
220 return port
210 return port
221
211
222
212
223 @pytest.fixture(scope='session')
213 @pytest.fixture(scope='session')
224 def vcsserver_port(request):
214 def vcsserver_port(request):
225 port = request.config.getoption('--vcsserver-port')
215 port = request.config.getoption('--vcsserver-port')
226 if port is None:
216 if port is None:
227 port = get_available_port()
217 port = get_available_port()
228 print('Using vcsserver port {}'.format(port))
218 print('Using vcsserver port {}'.format(port))
229 return port
219 return port
230
220
231
221
232 @pytest.fixture(scope='session')
222 @pytest.fixture(scope='session')
233 def available_port_factory():
223 def available_port_factory():
234 """
224 """
235 Returns a callable which returns free port numbers.
225 Returns a callable which returns free port numbers.
236 """
226 """
237 return get_available_port
227 return get_available_port
238
228
239
229
240 @pytest.fixture()
230 @pytest.fixture()
241 def available_port(available_port_factory):
231 def available_port(available_port_factory):
242 """
232 """
243 Gives you one free port for the current test.
233 Gives you one free port for the current test.
244
234
245 Uses "available_port_factory" to retrieve the port.
235 Uses "available_port_factory" to retrieve the port.
246 """
236 """
247 return available_port_factory()
237 return available_port_factory()
248
238
249
239
250 @pytest.fixture(scope='session')
240 @pytest.fixture(scope='session')
251 def testini_factory(tmpdir_factory, ini_config):
241 def testini_factory(tmpdir_factory, ini_config):
252 """
242 """
253 Factory to create an INI file based on TestINI.
243 Factory to create an INI file based on TestINI.
254
244
255 It will make sure to place the INI file in the correct directory.
245 It will make sure to place the INI file in the correct directory.
256 """
246 """
257 basetemp = tmpdir_factory.getbasetemp().strpath
247 basetemp = tmpdir_factory.getbasetemp().strpath
258 return TestIniFactory(basetemp, ini_config)
248 return TestIniFactory(basetemp, ini_config)
259
249
260
250
261 class TestIniFactory(object):
251 class TestIniFactory(object):
262
252
263 def __init__(self, basetemp, template_ini):
253 def __init__(self, basetemp, template_ini):
264 self._basetemp = basetemp
254 self._basetemp = basetemp
265 self._template_ini = template_ini
255 self._template_ini = template_ini
266
256
267 def __call__(self, ini_params, new_file_prefix='test'):
257 def __call__(self, ini_params, new_file_prefix='test'):
268 ini_file = TestINI(
258 ini_file = TestINI(
269 self._template_ini, ini_params=ini_params,
259 self._template_ini, ini_params=ini_params,
270 new_file_prefix=new_file_prefix, dir=self._basetemp)
260 new_file_prefix=new_file_prefix, dir=self._basetemp)
271 result = ini_file.create()
261 result = ini_file.create()
272 return result
262 return result
273
263
274
264
275 def get_config(
265 def get_config(
276 config, option_name, override_option_name, overrides=None,
266 config, option_name, override_option_name, overrides=None,
277 basetemp=None, prefix='test'):
267 basetemp=None, prefix='test'):
278 """
268 """
279 Find a configuration file and apply overrides for the given `prefix`.
269 Find a configuration file and apply overrides for the given `prefix`.
280 """
270 """
281 config_file = (
271 config_file = (
282 config.getoption(option_name) or config.getini(option_name))
272 config.getoption(option_name) or config.getini(option_name))
283 if not config_file:
273 if not config_file:
284 pytest.exit(
274 pytest.exit(
285 "Configuration error, could not extract {}.".format(option_name))
275 "Configuration error, could not extract {}.".format(option_name))
286
276
287 overrides = overrides or []
277 overrides = overrides or []
288 config_override = config.getoption(override_option_name)
278 config_override = config.getoption(override_option_name)
289 if config_override:
279 if config_override:
290 overrides.append(config_override)
280 overrides.append(config_override)
291 temp_ini_file = TestINI(
281 temp_ini_file = TestINI(
292 config_file, ini_params=overrides, new_file_prefix=prefix,
282 config_file, ini_params=overrides, new_file_prefix=prefix,
293 dir=basetemp)
283 dir=basetemp)
294
284
295 return temp_ini_file.create()
285 return temp_ini_file.create()