utils.py
550 lines
| 17.8 KiB
| text/x-python
|
PythonLexer
r5088 | # Copyright (C) 2010-2023 RhodeCode GmbH | |||
r1 | # | |||
# This program is free software: you can redistribute it and/or modify | ||||
# it under the terms of the GNU Affero General Public License, version 3 | ||||
# (only), as published by the Free Software Foundation. | ||||
# | ||||
# This program is distributed in the hope that it will be useful, | ||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of | ||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | ||||
# GNU General Public License for more details. | ||||
# | ||||
# You should have received a copy of the GNU Affero General Public License | ||||
# along with this program. If not, see <http://www.gnu.org/licenses/>. | ||||
# | ||||
# This program is dual-licensed. If you wish to learn more about the | ||||
# RhodeCode Enterprise Edition, including its added features, Support services, | ||||
# and proprietary license terms, please see https://rhodecode.com/licenses/ | ||||
import threading | ||||
import time | ||||
r5190 | import sys | |||
r1 | import logging | |||
import os.path | ||||
r4926 | import subprocess | |||
r1256 | import tempfile | |||
r4994 | import urllib.request | |||
import urllib.error | ||||
import urllib.parse | ||||
r1525 | from lxml.html import fromstring, tostring | |||
from lxml.cssselect import CSSSelector | ||||
r4914 | from urllib.parse import unquote_plus | |||
r1906 | import webob | |||
r1 | ||||
r4919 | from webtest.app import TestResponse, TestApp | |||
r5190 | ||||
r1256 | ||||
r1 | import pytest | |||
r4986 | ||||
r1256 | from rhodecode.model.db import User, Repository | |||
r1 | from rhodecode.model.meta import Session | |||
from rhodecode.model.scm import ScmModel | ||||
from rhodecode.lib.vcs.backends.svn.repository import SubversionRepository | ||||
r1259 | from rhodecode.lib.vcs.backends.base import EmptyCommit | |||
r5607 | from rhodecode.tests import login_user_session, console_printer | |||
from rhodecode.authentication import AuthenticationPluginRegistry | ||||
from rhodecode.model.settings import SettingsModel | ||||
r1 | ||||
log = logging.getLogger(__name__)


def console_printer_utils(msg):
    """
    Print ``msg`` through the shared test ``console_printer``, prefixed with a
    green "test-utils" tag so the origin of the line is obvious in test output.
    """
    tagged_line = f" :white_check_mark: [green]test-utils[/green] {msg}"
    console_printer(tagged_line)
def get_rc_testdata():
    """
    Import and return the optional ``rc_testdata`` package.

    The import is done lazily so this module can be imported even when the
    test-data package is missing; only helpers that need the dumps fail.

    :returns: the imported ``rc_testdata`` module.
    :raises ImportError: when the package cannot be imported. The original
        import error is chained as ``__cause__`` so the real reason stays visible.
    """
    try:
        import rc_testdata
    except ImportError as exc:
        raise ImportError('Failed to import rc_testdata, '
                          'please make sure this package is installed from requirements_test.txt') from exc
    return rc_testdata
r5190 | ||||
class CustomTestResponse(TestResponse):
    """
    webtest response with a more informative ``mustcontain`` and RhodeCode
    specific helpers (``assert_response``, ``get_session_from_response``).
    """

    def _save_output(self, out):
        """
        Write ``out`` to a new ``rc-test-*.html`` temporary file and return its path.

        The file is kept (``delete=False``) so it can be inspected after a
        failing test.
        """
        # Close the handle before handing out the path, so the content is
        # flushed to disk and no file descriptor leaks. Write utf-8 explicitly
        # so non-ascii bodies do not depend on the locale encoding.
        with tempfile.NamedTemporaryFile(
                mode='w', encoding='utf-8', delete=False,
                prefix='rc-test-', suffix='.html') as f:
            f.write(out)
        return f.name

    def mustcontain(self, *strings, **kw):
        """
        Assert that the response contains all the strings passed
        in as arguments.

        Equivalent to::

            assert string in res

        :param strings: strings that must all be present in the response.
        :param no: (keyword) a string or an iterable of strings that must NOT
            be present in the response.
        :param print_body: (keyword) also print the whole body on failure.
        :raises TypeError: for any other keyword argument.
        :raises IndexError: when a check fails. The body is saved to a
            temporary file whose path is part of the message.
        """
        print_body = kw.pop('print_body', False)

        no = kw.pop('no', [])
        if isinstance(no, str):
            no = [no]
        if kw:
            raise TypeError(f"The only keyword argument allowed is 'no' got {kw}")

        # The body is only dumped to disk when a check fails (each failure
        # raises immediately, so at most one file is written per call).
        for s in strings:
            if s not in self:
                f = self._save_output(str(self))
                console_printer_utils(f"Actual response (no {s!r}):")
                console_printer_utils(f"body output saved as `{f}`")
                if print_body:
                    console_printer_utils(str(self))
                raise IndexError(f"Body does not contain string {s!r}, body output saved as {f}")

        for no_s in no:
            if no_s in self:
                f = self._save_output(str(self))
                console_printer_utils(f"Actual response (has {no_s!r})")
                console_printer_utils(f"body output saved as `{f}`")
                if print_body:
                    console_printer_utils(str(self))
                raise IndexError(f"Body contains bad string {no_s!r}, body output saved as {f}")

    def assert_response(self):
        """Return an ``AssertResponse`` helper bound to this response."""
        return AssertResponse(self)

    def get_session_from_response(self):
        """
        This returns the session from a response object.

        The session factory is built from the pyramid settings of the test app
        that produced this response, then applied to this response's request.
        """
        from rhodecode.lib.rc_beaker import session_factory_from_settings
        session = session_factory_from_settings(self.test_app._pyramid_settings)
        return session(self.request)
r1774 | ||||
r1256 | ||||
class TestRequest(webob.BaseRequest):
    """
    webob request used by ``CustomTestApp``.

    Responses are built as ``CustomTestResponse``. Paths containing non-ascii
    characters are percent-quoted before being handed to webob.
    """

    # for py.test, so it doesn't try to run this test by name starting with test...
    # NOTE(review): pytest's documented opt-out is ``__test__ = False``; confirm
    # this attribute really has the intended effect.
    disabled = True
    ResponseClass = CustomTestResponse

    def add_response_callback(self, callback):
        # Response callbacks are accepted and silently ignored for test requests.
        return None

    @classmethod
    def blank(cls, path, environ=None, base_url=None,
              headers=None, POST=None, **kw):
        """Create a blank request, quoting ``path`` first when it is not pure ascii."""
        request_path = path if path.isascii() else urllib.parse.quote(path)
        return super().blank(
            request_path, environ=environ, base_url=base_url,
            headers=headers, POST=POST, **kw)
r1256 | ||||
class CustomTestApp(TestApp):
    """
    Custom app to make mustcontain more useful, and extract special methods
    (login helpers, csrf token and pyramid registry/settings access).
    """
    RequestClass = TestRequest
    # NOTE(review): class-level dict, so login data (csrf token) is shared by
    # every CustomTestApp instance; confirm that sharing is intended.
    rc_login_data = {}
    rc_current_session = None

    def login(self, username=None, password=None):
        """
        Log in through ``login_user_session`` and remember the session and csrf token.

        Explicit credentials are only used when both ``username`` and
        ``password`` are given; otherwise ``login_user_session`` defaults apply.

        :returns: the ``rhodecode_user`` entry of the resulting session.
        """
        from rhodecode.lib import auth

        credentials = (username, password) if (username and password) else ()
        session = login_user_session(self, *credentials)

        self.rc_login_data['csrf_token'] = auth.get_csrf_token(session)
        self.rc_current_session = session
        return session['rhodecode_user']

    @property
    def csrf_token(self):
        """csrf token stored by the last ``login`` call (KeyError if none happened)."""
        return self.rc_login_data['csrf_token']

    @property
    def _pyramid_registry(self):
        """Registry of the wrapped pyramid application."""
        return self.app.config.registry

    @property
    def _pyramid_settings(self):
        """Settings dict of the wrapped pyramid application."""
        return self._pyramid_registry.settings

    def do_request(self, req, status=None, expect_errors=None):
        # you can put custom code here
        return super().do_request(req, status, expect_errors)
r1256 | ||||
def set_anonymous_access(enabled):
    """
    (Dis)allows anonymous access depending on parameter `enabled`.

    Toggles the ``active`` flag of the default user, commits, waits for the
    cached value to expire and then asserts the new state is visible.
    """
    anonymous = User.get_default_user()
    anonymous.active = enabled
    Session().add(anonymous)
    Session().commit()

    # must sleep for cache (1s to expire)
    time.sleep(1.5)
    log.info('anonymous access is now: %s', enabled)

    is_active_now = User.get_default_user().active
    assert enabled == is_active_now, 'Cannot set anonymous access'
def check_xfail_backends(node, backend_alias):
    """
    Mark the current test as xfail when ``backend_alias`` is listed in the
    closest ``xfail_backends`` marker of ``node``.

    "xfail_backends" is used intentionally: it marks work which is "to be
    done" soon. A ``reason=`` keyword on the marker replaces the default message.
    """
    marker = node.get_closest_marker('xfail_backends')
    if not marker or backend_alias not in marker.args:
        return

    default_reason = f"Support for backend {backend_alias} to be developed."
    pytest.xfail(marker.kwargs.get('reason', default_reason))
def check_skip_backends(node, backend_alias):
    """
    Skip the current test when ``backend_alias`` is listed in the closest
    ``skip_backends`` marker of ``node``.

    "skip_backends" is used intentionally: it marks work which is not
    supported. A ``reason=`` keyword on the marker replaces the default message.
    """
    marker = node.get_closest_marker('skip_backends')
    if not marker or backend_alias not in marker.args:
        return

    default_reason = f"Feature not supported for backend {backend_alias}."
    pytest.skip(marker.kwargs.get('reason', default_reason))
def extract_git_repo_from_dump(dump_name, repo_name):
    """
    Create git repo `repo_name` from dump `dump_name`.

    The repository is extracted below ``ScmModel().repos_path``.

    :returns: full filesystem path of the created repository.
    """
    destination = os.path.join(ScmModel().repos_path, repo_name)
    get_rc_testdata().extract_git_dump(dump_name, destination)
    return destination
def extract_hg_repo_from_dump(dump_name, repo_name):
    """
    Create hg repo `repo_name` from dump `dump_name`.

    The repository is extracted below ``ScmModel().repos_path``.

    :returns: full filesystem path of the created repository.
    """
    destination = os.path.join(ScmModel().repos_path, repo_name)
    get_rc_testdata().extract_hg_dump(dump_name, destination)
    return destination
def extract_svn_repo_from_dump(dump_name, repo_name):
    """
    Create a svn repo `repo_name` from dump `dump_name`.

    An empty repository is created below ``ScmModel().repos_path`` first, and
    the dump is then loaded into it.

    :returns: full filesystem path of the created repository.
    """
    destination = os.path.join(ScmModel().repos_path, repo_name)
    SubversionRepository(destination, create=True)
    _load_svn_dump_into_repo(dump_name, destination)
    return destination
def assert_message_in_log(log_records, message, levelno, module):
    """
    Assert that ``log_records`` holds a record from ``module`` at level
    ``levelno`` whose ``message`` equals ``message``.
    """
    found = any(
        record.module == module and record.levelno == levelno
        and record.message == message
        for record in log_records
    )
    assert found
def _load_svn_dump_into_repo(dump_name, repo_path):
    """
    Utility to populate a svn repository with a named dump.

    Currently the dumps are in rc_testdata. They might later on be
    integrated with the main repository once they stabilize more.

    The dump is piped into ``svnadmin load <repo_path>``.

    :raises Exception: when ``svnadmin`` exits with a non-zero code. Its
        stdout and stderr are logged first.
    """
    dump = get_rc_testdata().load_svn_dump(dump_name)
    with subprocess.Popen(
            ['svnadmin', 'load', repo_path],
            stdin=subprocess.PIPE, stdout=subprocess.PIPE,
            stderr=subprocess.PIPE) as svnadmin:
        stdout_data, stderr_data = svnadmin.communicate(dump)

    if svnadmin.returncode == 0:
        return

    log.error("Output of load_dump command: %s", stdout_data)
    log.error("Error output of load_dump command: %s", stderr_data)
    raise Exception(f'Failed to load dump "{dump_name}" into repository at path "{repo_path}".')
r1 | ||||
class AssertResponse(object):
    """
    Utility that helps to assert things about a given HTML response.

    Every query parses ``response.body`` again with lxml and evaluates a CSS
    selector against the resulting document.
    """

    def __init__(self, response):
        self.response = response

    def get_imports(self):
        """Return the (fromstring, tostring, CSSSelector) lxml helpers used here."""
        return fromstring, tostring, CSSSelector

    def one_element_exists(self, css_selector):
        # get_element already asserts that exactly one element matches
        self.get_element(css_selector)

    def no_element_exists(self, css_selector):
        matches = self._get_elements(css_selector)
        assert not matches

    def element_equals_to(self, css_selector, expected_content):
        # NOTE: despite the name, this is a substring check against the
        # element's serialized HTML
        serialized = self._element_to_string(self.get_element(css_selector))
        assert expected_content in serialized

    def element_contains(self, css_selector, expected_content):
        text = self.get_element(css_selector).text_content()
        assert expected_content in text

    def element_value_contains(self, css_selector, expected_content):
        value = self.get_element(css_selector).value
        assert expected_content in value

    def contains_one_link(self, link_text, href):
        """Assert exactly one ``<a href>`` has text ``link_text`` and points to ``href``."""
        parse, _serialize, selector_cls = self.get_imports()
        document = parse(self.response.body)
        select_links = selector_cls('a[href]')
        matching = [
            anchor for anchor in select_links(document)
            if anchor.text_content().strip() == link_text
        ]
        assert len(matching) == 1, "Did not find link or found multiple links"
        self._ensure_url_equal(matching[0].attrib.get('href'), href)

    def contains_one_anchor(self, anchor_id):
        """Assert exactly one element has the id ``anchor_id``."""
        parse, _serialize, selector_cls = self.get_imports()
        document = parse(self.response.body)
        select_anchor = selector_cls('#' + anchor_id)
        matching = select_anchor(document)
        assert len(matching) == 1, f'cannot find 1 element {anchor_id}'

    def _ensure_url_equal(self, found, expected):
        # compare ignoring escaping differences and query parameter order
        assert _Url(found) == _Url(expected)

    def get_element(self, css_selector):
        """Return the single element matching ``css_selector`` (asserts exactly one)."""
        matching = self._get_elements(css_selector)
        assert len(matching) == 1, f'cannot find 1 element {css_selector}'
        return matching[0]

    def get_elements(self, css_selector):
        """Return all elements matching ``css_selector``."""
        return self._get_elements(css_selector)

    def _get_elements(self, css_selector):
        parse, _serialize, selector_cls = self.get_imports()
        document = parse(self.response.body)
        select = selector_cls(css_selector)
        return select(document)

    def _element_to_string(self, element):
        _parse, serialize, _selector_cls = self.get_imports()
        return serialize(element, encoding='unicode')
class _Url(object): | ||||
""" | ||||
A url object that can be compared with other url orbjects | ||||
without regard to the vagaries of encoding, escaping, and ordering | ||||
of parameters in query strings. | ||||
Inspired by | ||||
http://stackoverflow.com/questions/5371992/comparing-two-urls-in-python | ||||
""" | ||||
def __init__(self, url): | ||||
r4919 | parts = urllib.parse.urlparse(url) | |||
_query = frozenset(urllib.parse.parse_qsl(parts.query)) | ||||
r1 | _path = unquote_plus(parts.path) | |||
parts = parts._replace(query=_query, path=_path) | ||||
self.parts = parts | ||||
def __eq__(self, other): | ||||
return self.parts == other.parts | ||||
def __hash__(self): | ||||
return hash(self.parts) | ||||
def run_test_concurrently(times, raise_catched_exc=True):
    """
    Add this decorator to small pieces of code that you want to test
    concurrently.

    The decorated function is run ``times`` times, each call in its own
    thread, and the wrapper waits for all of them. Exceptions raised by the
    calls are collected. If any occurred, a single ``Exception`` listing them
    is raised in the calling thread, chained to the first collected error.

    ex:

        @run_test_concurrently(25)
        def my_test_function():
            ...

    :param times: number of threads (and calls) to start.
    :param raise_catched_exc: also re-raise each error inside its own thread,
        so the thread's excepthook reports it. The calling thread raises the
        aggregated exception either way.
    """
    import functools

    def test_concurrently_decorator(test_func):
        # keep name/docstring/signature of the wrapped function (pytest and
        # tracebacks then show the real test instead of "wrapper")
        @functools.wraps(test_func)
        def wrapper(*args, **kwargs):
            exceptions = []

            def call_test_func():
                try:
                    test_func(*args, **kwargs)
                except Exception as e:
                    exceptions.append(e)
                    if raise_catched_exc:
                        raise

            threads = [threading.Thread(target=call_test_func) for _ in range(times)]
            for t in threads:
                t.start()
            for t in threads:
                t.join()

            if exceptions:
                raise Exception(
                    'test_concurrently intercepted %s exceptions: %s' % (
                        len(exceptions), exceptions)) from exceptions[0]
        return wrapper
    return test_concurrently_decorator
def wait_for_url(url, timeout=10):
    """
    Wait until URL becomes reachable.

    It polls the URL until the timeout is reached or it became reachable.
    It will call `pytest.fail` in case the URL is not reachable.

    :param url: URL to poll with ``is_url_reachable``.
    :param timeout: seconds to keep polling before failing.
    """
    deadline = time.time() + timeout
    poll_interval = 0.1
    checked_at = 0

    while checked_at < deadline:
        checked_at = time.time()
        if is_url_reachable(url, log_exc=False):
            return
        if time.time() < checked_at + poll_interval:
            # Go to sleep because not enough time has passed since last check.
            time.sleep(poll_interval)

    pytest.fail(f"Timeout while waiting for URL {url}")
r1 | ||||
def is_url_reachable(url: str, log_exc: bool = False, timeout: float = 10.0) -> bool:
    """
    Return True when ``url`` can be opened with ``urllib``, False otherwise.

    HTTP error statuses (``HTTPError`` is a ``URLError``), connection failures
    and socket timeouts all count as "not reachable".

    :param url: URL to probe.
    :param log_exc: log the failure (with traceback) when the URL is not reachable.
    :param timeout: seconds to wait for the connection. Without a timeout, a
        stalled server could block the caller forever.
    """
    try:
        # only reachability matters; close the response (and its socket) right away
        with urllib.request.urlopen(url, timeout=timeout):
            pass
    except (urllib.error.URLError, OSError):
        # OSError also covers resets/timeouts that urllib does not wrap in URLError
        if log_exc:
            log.exception(f'URL `{url}` reach error')
        return False
    return True
r41 | ||||
def repo_on_filesystem(repo_name):
    """
    Return True when ``vcs.get_vcs_instance`` finds a repository at
    ``TESTS_TMP_PATH/<repo_name>`` (without creating one).
    """
    from rhodecode.lib import vcs
    from rhodecode.tests import TESTS_TMP_PATH

    candidate_path = os.path.join(TESTS_TMP_PATH, repo_name)
    return vcs.get_vcs_instance(candidate_path, create=False) is not None
r1259 | ||||
def commit_change(
        repo, filename: bytes, content: bytes, message, vcs_type,
        parent=None, branch=None, newfile=False):
    """
    Commit ``content`` to ``filename`` in repository ``repo`` as the test admin.

    :param repo: repository name, looked up with ``Repository.get_by_repo_name``.
    :param parent: parent commit. When falsy, an ``EmptyCommit`` of ``vcs_type``
        is used as the parent for new files.
    :param branch: branch passed through when changing an existing file.
    :param newfile: create ``filename`` via ``ScmModel.create_nodes`` instead of
        changing an existing file via ``ScmModel.commit_change``.
    :returns: whatever the called ``ScmModel`` method returns.
    """
    from rhodecode.tests import TEST_USER_ADMIN_LOGIN

    db_repo = Repository.get_by_repo_name(repo)
    base_commit = parent or EmptyCommit(alias=vcs_type)
    author = f'{TEST_USER_ADMIN_LOGIN} <admin@rhodecode.com>'

    if newfile:
        return ScmModel().create_nodes(
            user=TEST_USER_ADMIN_LOGIN,
            repo=db_repo,
            message=message,
            nodes={filename: {'content': content}},
            parent_commit=base_commit,
            author=author,
        )

    # NOTE: the existing-file path passes ``parent`` as given (possibly None),
    # not the EmptyCommit fallback.
    return ScmModel().commit_change(
        repo=db_repo.scm_instance(),
        repo_name=db_repo.repo_name,
        commit=parent,
        user=TEST_USER_ADMIN_LOGIN,
        author=author,
        message=message,
        content=content,
        f_path=filename,
        branch=branch,
    )
r2827 | ||||
def permission_update_data_generator(csrf_token, default=None, grant=None, revoke=None):
    """
    Build the ordered list of form fields for a permission-update POST.

    :param csrf_token: value for the ``csrf_token`` field.
    :param default: permission for the default user (field ``u_perm_1``); required.
    :param grant: iterable of ``(obj_id, perm, obj_name, obj_type)`` tuples,
        numbered from 1 in the ``perm_new_member_*_new<N>`` fields.
    :param revoke: iterable of ``(obj_id, obj_type)`` tuples for the
        ``perm_del_member_*_<obj_id>`` fields.
    :raises ValueError: when ``default`` is not given.
    :returns: list of ``(field_name, value)`` tuples.
    """
    if not default:
        raise ValueError('Permission for default user must be given')

    form_data = [('csrf_token', csrf_token), ('u_perm_1', default)]

    for idx, (obj_id, perm, obj_name, obj_type) in enumerate(grant or [], 1):
        form_data += [
            (f'perm_new_member_perm_new{idx}', perm),
            (f'perm_new_member_id_new{idx}', obj_id),
            (f'perm_new_member_name_new{idx}', obj_name),
            (f'perm_new_member_type_new{idx}', obj_type),
        ]

    for obj_id, obj_type in revoke or []:
        form_data += [
            (f'perm_del_member_id_{obj_id}', obj_id),
            (f'perm_del_member_type_{obj_id}', obj_type),
        ]

    return form_data
r5607 | ||||
class AuthPluginManager:
    """
    Test helper that switches the enabled authentication plugins by writing
    the ``auth_*`` settings through ``SettingsModel``.
    """

    def cleanup(self):
        # back to only the builtin rhodecode plugin
        self._enable_plugins(['egg:rhodecode-enterprise-ce#rhodecode'])

    def enable(self, plugins_list, override=None):
        """Enable exactly ``plugins_list``; see ``_enable_plugins`` for ``override``."""
        return self._enable_plugins(plugins_list, override)

    @staticmethod
    def _build_settings_params(plugins_list, override=None):
        """
        Build the ordered ``{setting_name: value}`` mapping for ``plugins_list``.

        Each plugin id has the form ``<package>#<name>``. For every plugin,
        ``auth_<name>_enabled=True`` and ``auth_<name>_cache_ttl=0`` are set,
        followed by the entries of ``override[<plugin id>]`` (if any), which
        win over the defaults.

        :param plugins_list: sequence of plugin ids.
        :param override: optional dict mapping plugin id -> dict of extra settings.
        """
        override = override or {}
        params = {
            'auth_plugins': ','.join(plugins_list),
        }
        # helper translate some names to others, to fix settings code
        name_map = {
            'token': 'authtoken'
        }

        for module in plugins_list:
            plugin_name = module.partition('#')[-1]
            plugin_name = name_map.get(plugin_name, plugin_name)

            # default params that are needed for each plugin,
            # `enabled` and `cache_ttl`
            params[f'auth_{plugin_name}_enabled'] = True
            params[f'auth_{plugin_name}_cache_ttl'] = 0
            params.update(override.get(module, {}))
        return params

    @classmethod
    def _enable_plugins(cls, plugins_list, override=None):
        """
        Persist the settings for ``plugins_list``, reset the plugin cache and
        assert that exactly those plugins are now reported as enabled.
        """
        params = cls._build_settings_params(plugins_list, override)
        log.debug('enable_auth_plugins: enabling following auth-plugins: %s', plugins_list)

        for key, value in params.items():
            setting = SettingsModel().create_or_update_setting(key, value)
            Session().add(setting)
        Session().commit()

        AuthenticationPluginRegistry.invalidate_auth_plugins_cache(hard=True)
        enabled_plugins = SettingsModel().get_auth_plugins()
        assert plugins_list == enabled_plugins