# Copyright (C) 2010-2024 RhodeCode GmbH
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License, version 3
# (only), as published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# This program is dual-licensed. If you wish to learn more about the
# RhodeCode Enterprise Edition, including its added features, Support services,
# and proprietary license terms, please see https://rhodecode.com/licenses/

import threading
import time
import sys
import logging
import os.path
import subprocess
import tempfile
import urllib.request
import urllib.error
import urllib.parse
from lxml.html import fromstring, tostring
from lxml.cssselect import CSSSelector
from urllib.parse import unquote_plus
import webob

from webtest.app import TestResponse, TestApp

import pytest

from rhodecode.model.db import User, Repository
from rhodecode.model.meta import Session
from rhodecode.model.scm import ScmModel
from rhodecode.lib.vcs.backends.svn.repository import SubversionRepository
from rhodecode.lib.vcs.backends.base import EmptyCommit
from rhodecode.tests import login_user_session, console_printer
from rhodecode.authentication import AuthenticationPluginRegistry
from rhodecode.model.settings import SettingsModel

log = logging.getLogger(__name__)


def console_printer_utils(msg):
    """Print `msg` through `console_printer`, prefixed with a `test-utils` tag."""
    console_printer(f" :white_check_mark: [green]test-utils[/green] {msg}")


def get_rc_testdata():
    """
    Import and return the `rc_testdata` package.

    :raises ImportError: with an installation hint when the package is missing.
    """
    try:
        import rc_testdata
    except ImportError:
        raise ImportError('Failed to import rc_testdata, '
                          'please make sure this package is installed from requirements_test.txt')
    return rc_testdata


class CustomTestResponse(TestResponse):
    """
    `TestResponse` with a `mustcontain` that dumps the response to a temp
    file on failure, plus helpers to build an `AssertResponse` and to obtain
    the session belonging to the response's request.
    """

    def _save_output(self, out):
        """
        Write `out` into a new `rc-test-*.html` temporary file.

        The file is created with `delete=False`, so it stays on disk for
        inspection; the context manager guarantees the content is flushed
        and the handle closed before the path is handed out.

        :return: path of the written file.
        """
        with tempfile.NamedTemporaryFile(
                mode='w', delete=False, prefix='rc-test-', suffix='.html') as f:
            f.write(out)
        return f.name

    def _report_mismatch(self, headline, print_body):
        """
        Save the response to a temp file and print failure diagnostics.

        :param headline: first diagnostic line to print.
        :param print_body: when true, also print the whole response.
        :return: path of the file holding the response dump.
        """
        body = str(self)
        saved_as = self._save_output(body)
        console_printer_utils(headline)
        console_printer_utils(f"body output saved as `{saved_as}`")
        if print_body:
            console_printer_utils(body)
        return saved_as

    def mustcontain(self, *strings, **kw):
        """
        Assert that the response contains all the strings passed in as
        arguments.

        Equivalent to::

            assert string in res

        Accepted keyword arguments:

        :param no: a string, or an iterable of strings, that must NOT be
            present in the response.
        :param print_body: when true, the response is also printed on failure.
        :raises TypeError: on any other keyword argument.
        :raises IndexError: when an expected string is missing or a forbidden
            one is present; the message names the file with the response dump.
        """
        print_body = kw.pop('print_body', False)
        no = kw.pop('no', [])
        if isinstance(no, str):
            no = [no]
        if kw:
            raise TypeError(f"The only keyword argument allowed is 'no' got {kw}")

        # the response dump is only written when a check actually fails, so
        # passing assertions do not leave temporary files behind
        for s in strings:
            if s not in self:
                f = self._report_mismatch(f"Actual response (no {s!r}):", print_body)
                raise IndexError(f"Body does not contain string {s!r}, body output saved as {f}")

        for no_s in no:
            if no_s in self:
                f = self._report_mismatch(f"Actual response (has {no_s!r})", print_body)
                raise IndexError(f"Body contains bad string {no_s!r}, body output saved as {f}")

    def assert_response(self):
        """Return an `AssertResponse` helper wrapping this response."""
        return AssertResponse(self)

    def get_session_from_response(self):
        """
        This returns the session from a response object.
        """
        from rhodecode.lib.rc_beaker import session_factory_from_settings
        session = session_factory_from_settings(self.test_app._pyramid_settings)
        return session(self.request)


class TestRequest(webob.BaseRequest):
    """Request class used by `CustomTestApp`; responses are `CustomTestResponse`."""

    # for py.test, so it doesn't try to run this class by name starting with test...
    disabled = True
    ResponseClass = CustomTestResponse

    def add_response_callback(self, callback):
        """Accept and ignore a response callback."""
        pass

    @classmethod
    def blank(cls, path, environ=None, base_url=None, headers=None, POST=None, **kw):
        """
        Build a blank request for `path`, URL-quoting the path first when it
        contains non-ascii characters. All arguments are forwarded to the
        parent `blank`.
        """
        if not path.isascii():
            # our custom quote path if it contains non-ascii chars
            path = urllib.parse.quote(path)
        return super().blank(
            path, environ=environ, base_url=base_url, headers=headers, POST=POST, **kw)


class CustomTestApp(TestApp):
    """
    Custom app to make mustcontain more useful, and extract special methods
    """
    RequestClass = TestRequest

    # NOTE(review): this dict lives on the class and `login()` mutates it in
    # place, so the stored CSRF token is shared by every CustomTestApp
    # instance -- confirm that this sharing is intended.
    rc_login_data = {}
    rc_current_session = None

    def login(self, username=None, password=None):
        """
        Log in through `login_user_session` and remember the session and its
        CSRF token on the app.

        `username` and `password` are only passed on when both are truthy;
        otherwise `login_user_session` is called without credentials.

        :return: the `rhodecode_user` entry of the new session.
        """
        from rhodecode.lib import auth

        if username and password:
            session = login_user_session(self, username, password)
        else:
            session = login_user_session(self)

        self.rc_login_data['csrf_token'] = auth.get_csrf_token(session)
        self.rc_current_session = session
        return session['rhodecode_user']

    @property
    def csrf_token(self):
        """CSRF token stored by `login()`; `KeyError` if no login happened yet."""
        return self.rc_login_data['csrf_token']

    @property
    def _pyramid_registry(self):
        """Registry taken from the wrapped application's `config`."""
        return self.app.config.registry

    @property
    def _pyramid_settings(self):
        """Settings of the wrapped application's registry."""
        return self._pyramid_registry.settings

    def do_request(self, req, status=None, expect_errors=None):
        """Delegate to the parent `do_request`; hook for custom code."""
        # you can put custom code here
        return super().do_request(req, status, expect_errors)


def set_anonymous_access(enabled):
    """(Dis)allows anonymous access depending on parameter `enabled`"""
    user = User.get_default_user()
    user.active = enabled
    Session().add(user)
    Session().commit()
    time.sleep(1.5)  # must sleep for cache (1s to expire)
    log.info('anonymous access is now: %s', enabled)
    assert enabled == User.get_default_user().active, (
        'Cannot set anonymous access')


def check_xfail_backends(node, backend_alias):
    """
    Call `pytest.xfail` when `backend_alias` is listed in the closest
    `xfail_backends` marker of `node`. The marker's `reason` keyword, when
    given, replaces the default message.
    """
    # Using "xfail_backends" here intentionally, since this marks work
    # which is "to be done" soon.
    skip_marker = node.get_closest_marker('xfail_backends')
    if skip_marker and backend_alias in skip_marker.args:
        msg = "Support for backend %s to be developed." % (backend_alias, )
        msg = skip_marker.kwargs.get('reason', msg)
        pytest.xfail(msg)


def check_skip_backends(node, backend_alias):
    """
    Call `pytest.skip` when `backend_alias` is listed in the closest
    `skip_backends` marker of `node`. The marker's `reason` keyword, when
    given, replaces the default message.
    """
    # Using "skip_backends" here intentionally, since this marks work which is
    # not supported.
    skip_marker = node.get_closest_marker('skip_backends')
    if skip_marker and backend_alias in skip_marker.args:
        msg = "Feature not supported for backend %s." % (backend_alias, )
        msg = skip_marker.kwargs.get('reason', msg)
        pytest.skip(msg)


def extract_git_repo_from_dump(dump_name, repo_name):
    """Create git repo `repo_name` from dump `dump_name`."""
    repos_path = ScmModel().repos_path
    target_path = os.path.join(repos_path, repo_name)
    rc_testdata = get_rc_testdata()
    rc_testdata.extract_git_dump(dump_name, target_path)
    return target_path


def extract_hg_repo_from_dump(dump_name, repo_name):
    """Create hg repo `repo_name` from dump `dump_name`."""
    repos_path = ScmModel().repos_path
    target_path = os.path.join(repos_path, repo_name)
    rc_testdata = get_rc_testdata()
    rc_testdata.extract_hg_dump(dump_name, target_path)
    return target_path


def extract_svn_repo_from_dump(dump_name, repo_name):
    """Create a svn repo `repo_name` from dump `dump_name`."""
    repos_path = ScmModel().repos_path
    target_path = os.path.join(repos_path, repo_name)
    SubversionRepository(target_path, create=True)
    _load_svn_dump_into_repo(dump_name, target_path)
    return target_path


def assert_message_in_log(log_records, message, levelno, module):
    """
    Assert that `message` is among the records of `log_records` whose
    `module` and `levelno` match the given values.
    """
    messages = [
        r.message for r in log_records
        if r.module == module and r.levelno == levelno
    ]
    assert message in messages


def _load_svn_dump_into_repo(dump_name, repo_path):
    """
    Utility to populate a svn repository with a named dump

    Currently the dumps are in rc_testdata. They might later on be
    integrated with the main repository once they stabilize more.

    The dump is piped into `svnadmin load <repo_path>`; on a non-zero exit
    code both output streams are logged and an `Exception` is raised.
    """
    rc_testdata = get_rc_testdata()
    dump = rc_testdata.load_svn_dump(dump_name)
    load_dump = subprocess.Popen(
        ['svnadmin', 'load', repo_path],
        stdin=subprocess.PIPE, stdout=subprocess.PIPE,
        stderr=subprocess.PIPE)
    out, err = load_dump.communicate(dump)
    if load_dump.returncode != 0:
        log.error("Output of load_dump command: %s", out)
        log.error("Error output of load_dump command: %s", err)
        raise Exception(f'Failed to load dump "{dump_name}" into repository at path "{repo_path}".')


class AssertResponse(object):
    """
    Utility that helps to assert things about a given HTML response.
    """

    def __init__(self, response):
        self.response = response

    def get_imports(self):
        """Return the `(fromstring, tostring, CSSSelector)` helpers from lxml."""
        return fromstring, tostring, CSSSelector

    def one_element_exists(self, css_selector):
        """Assert that exactly one element matches `css_selector`."""
        self.get_element(css_selector)

    def no_element_exists(self, css_selector):
        """Assert that no element matches `css_selector`."""
        assert not self._get_elements(css_selector)

    def element_equals_to(self, css_selector, expected_content):
        """
        Assert that `expected_content` occurs in the serialized markup of
        the single element matching `css_selector`.
        """
        element = self.get_element(css_selector)
        element_text = self._element_to_string(element)
        assert expected_content in element_text

    def element_contains(self, css_selector, expected_content):
        """
        Assert that `expected_content` occurs in the text content of the
        single element matching `css_selector`.
        """
        element = self.get_element(css_selector)
        assert expected_content in element.text_content()

    def element_value_contains(self, css_selector, expected_content):
        """
        Assert that `expected_content` occurs in the `value` of the single
        element matching `css_selector`.
        """
        element = self.get_element(css_selector)
        assert expected_content in element.value

    def contains_one_link(self, link_text, href):
        """
        Assert that exactly one `<a href>` has the stripped text `link_text`
        and that its target equals `href` (compared through `_Url`).
        """
        elements = [
            e for e in self._get_elements('a[href]')
            if e.text_content().strip() == link_text]
        assert len(elements) == 1, "Did not find link or found multiple links"
        self._ensure_url_equal(elements[0].attrib.get('href'), href)

    def contains_one_anchor(self, anchor_id):
        """Assert that exactly one element matches the selector `#<anchor_id>`."""
        elements = self._get_elements('#' + anchor_id)
        assert len(elements) == 1, f'cannot find 1 element {anchor_id}'

    def _ensure_url_equal(self, found, expected):
        assert _Url(found) == _Url(expected)

    def get_element(self, css_selector):
        """Return the single element matching `css_selector`, asserting there is exactly one."""
        elements = self._get_elements(css_selector)
        assert len(elements) == 1, f'cannot find 1 element {css_selector}'
        return elements[0]

    def get_elements(self, css_selector):
        """Return all elements matching `css_selector`."""
        return self._get_elements(css_selector)

    def _get_elements(self, css_selector):
        """Parse the response body and return the elements matching `css_selector`."""
        fromstring, _, CSSSelector = self.get_imports()
        doc = fromstring(self.response.body)
        sel = CSSSelector(css_selector)
        return sel(doc)

    def _element_to_string(self, element):
        """Serialize `element` to a unicode string."""
        _, tostring, _ = self.get_imports()
        return tostring(element, encoding='unicode')


class _Url(object):
    """
    A url object that can be compared with other url objects
    without regard to the vagaries of encoding, escaping, and ordering
    of parameters in query strings.

    Inspired by
    http://stackoverflow.com/questions/5371992/comparing-two-urls-in-python
    """

    def __init__(self, url):
        parts = urllib.parse.urlparse(url)
        # a frozenset makes the query comparison order-independent and keeps
        # the parts hashable
        _query = frozenset(urllib.parse.parse_qsl(parts.query))
        _path = unquote_plus(parts.path)
        parts = parts._replace(query=_query, path=_path)
        self.parts = parts

    def __eq__(self, other):
        if not isinstance(other, _Url):
            # let Python try the reflected comparison / fall back to identity
            # instead of failing with AttributeError
            return NotImplemented
        return self.parts == other.parts

    def __hash__(self):
        return hash(self.parts)


def run_test_concurrently(times, raise_catched_exc=True):
    """
    Add this decorator to small pieces of code that you want to test
    concurrently

    ex:

    @run_test_concurrently(25)
    def my_test_function():
        ...

    The decorated function is run in `times` threads, all started before any
    is joined. Exceptions raised in the threads are collected; if there are
    any, the wrapper raises a single `Exception` listing them. With
    `raise_catched_exc` true, each exception is additionally re-raised
    inside its own thread.
    """
    def test_concurrently_decorator(test_func):
        def wrapper(*args, **kwargs):
            exceptions = []

            def call_test_func():
                try:
                    test_func(*args, **kwargs)
                except Exception as e:
                    exceptions.append(e)
                    if raise_catched_exc:
                        raise

            threads = [
                threading.Thread(target=call_test_func)
                for _ in range(times)
            ]
            for t in threads:
                t.start()
            for t in threads:
                t.join()
            if exceptions:
                raise Exception(
                    'test_concurrently intercepted %s exceptions: %s' % (
                        len(exceptions), exceptions))
        return wrapper
    return test_concurrently_decorator


def wait_for_url(url, timeout=10):
    """
    Wait until URL becomes reachable.

    It polls the URL until the timeout is reached or it became reachable.
    It will call `pytest.fail` in case the URL is not reachable.
    """
    timeout = time.time() + timeout
    last = 0
    wait = 0.1

    while timeout > last:
        last = time.time()
        if is_url_reachable(url, log_exc=False):
            break
        elif (last + wait) > time.time():
            # Go to sleep because not enough time has passed since last check.
            time.sleep(wait)
    else:
        # loop ended without `break`: the deadline passed
        pytest.fail(f"Timeout while waiting for URL {url}")


def is_url_reachable(url: str, log_exc: bool = False) -> bool:
    """
    Return True when `url` can be opened, False on `urllib.error.URLError`.

    :param log_exc: when true, the failure is logged with its traceback.
    """
    try:
        # the context manager closes the response instead of leaking it
        with urllib.request.urlopen(url):
            pass
    except urllib.error.URLError:
        if log_exc:
            log.exception('URL `%s` reach error', url)
        return False
    return True


def repo_on_filesystem(repo_name):
    """
    Return True when `vcs.get_vcs_instance` yields an instance for
    `repo_name` located under `TESTS_TMP_PATH`.
    """
    from rhodecode.lib import vcs
    from rhodecode.tests import TESTS_TMP_PATH
    repo = vcs.get_vcs_instance(
        os.path.join(TESTS_TMP_PATH, repo_name), create=False)
    return repo is not None


def commit_change(
        repo, filename: bytes, content: bytes, message, vcs_type, parent=None, branch=None, newfile=False):
    """
    Commit `content` for `filename` into the repository named `repo` as the
    test admin user.

    With `newfile` true the file is added through `ScmModel.create_nodes`,
    using an `EmptyCommit` of `vcs_type` as parent when `parent` is not
    given; otherwise the change goes through `ScmModel.commit_change` with
    `parent` and `branch` passed on as-is.

    :return: whatever the used `ScmModel` method returns.
    """
    from rhodecode.tests import TEST_USER_ADMIN_LOGIN

    repo = Repository.get_by_repo_name(repo)
    _commit = parent
    if not parent:
        _commit = EmptyCommit(alias=vcs_type)

    # NOTE(review): the author strings below end in a trailing space, which
    # looks like an `<email>` part went missing -- confirm the intended value.
    if newfile:
        nodes = {
            filename: {
                'content': content
            }
        }
        commit = ScmModel().create_nodes(
            user=TEST_USER_ADMIN_LOGIN, repo=repo,
            message=message,
            nodes=nodes,
            parent_commit=_commit,
            author=f'{TEST_USER_ADMIN_LOGIN} ',
        )
    else:
        commit = ScmModel().commit_change(
            repo=repo.scm_instance(), repo_name=repo.repo_name,
            commit=parent, user=TEST_USER_ADMIN_LOGIN,
            author=f'{TEST_USER_ADMIN_LOGIN} ',
            message=message,
            content=content,
            f_path=filename,
            branch=branch
        )
    return commit


def permission_update_data_generator(csrf_token, default=None, grant=None, revoke=None):
    """
    Build the list of `(field, value)` form pairs for a permission update.

    :param csrf_token: value for the `csrf_token` field.
    :param default: value for the `u_perm_1` field; required.
    :param grant: iterable of `(obj_id, perm, obj_name, obj_type)` tuples,
        emitted as `perm_new_member_*_new<N>` fields numbered from 1.
    :param revoke: iterable of `(obj_id, obj_type)` tuples, emitted as
        `perm_del_member_*_<obj_id>` fields.
    :raises ValueError: when `default` is falsy.
    """
    if not default:
        raise ValueError('Permission for default user must be given')
    form_data = [(
        'csrf_token', csrf_token
    )]
    # add default
    form_data.extend([
        ('u_perm_1', default)
    ])

    if grant:
        for cnt, (obj_id, perm, obj_name, obj_type) in enumerate(grant, 1):
            form_data.extend([
                (f'perm_new_member_perm_new{cnt}', perm),
                (f'perm_new_member_id_new{cnt}', obj_id),
                (f'perm_new_member_name_new{cnt}', obj_name),
                (f'perm_new_member_type_new{cnt}', obj_type),
            ])
    if revoke:
        for obj_id, obj_type in revoke:
            form_data.extend([
                (f'perm_del_member_id_{obj_id}', obj_id),
                (f'perm_del_member_type_{obj_id}', obj_type),
            ])
    return form_data


class AuthPluginManager:
    """Helper to switch the enabled authentication plugins in tests."""

    def cleanup(self):
        """Enable only the `egg:rhodecode-enterprise-ce#rhodecode` plugin."""
        self._enable_plugins(['egg:rhodecode-enterprise-ce#rhodecode'])

    def enable(self, plugins_list, override=None):
        """Enable `plugins_list`; see `_enable_plugins` for `override`."""
        return self._enable_plugins(plugins_list, override)

    @classmethod
    def _enable_plugins(cls, plugins_list, override: object = None):
        """
        Store the settings that enable `plugins_list` and verify the result.

        For every plugin an `auth_<name>_enabled` (True) and an
        `auth_<name>_cache_ttl` (0) setting is written, where `<name>` is the
        part of the plugin id after `#`.

        :param plugins_list: list of plugin ids.
        :param override: optional mapping of plugin id -> dict of extra
            settings merged on top of the defaults.
        """
        override = override or {}
        params = {
            'auth_plugins': ','.join(plugins_list),
        }

        # helper translate some names to others, to fix settings code
        name_map = {
            'token': 'authtoken'
        }
        log.debug('enable_auth_plugins: enabling following auth-plugins: %s', plugins_list)

        for module in plugins_list:
            plugin_name = module.partition('#')[-1]
            plugin_name = name_map.get(plugin_name, plugin_name)
            enabled_plugin = f'auth_{plugin_name}_enabled'
            cache_ttl = f'auth_{plugin_name}_cache_ttl'

            # default params that are needed for each plugin,
            # `enabled` and `cache_ttl`
            params.update({
                enabled_plugin: True,
                cache_ttl: 0
            })
            # per-plugin overrides win over the defaults set above
            params.update(override.get(module, {}))

        for k, v in params.items():
            setting = SettingsModel().create_or_update_setting(k, v)
            Session().add(setting)
        Session().commit()

        AuthenticationPluginRegistry.invalidate_auth_plugins_cache(hard=True)

        enabled_plugins = SettingsModel().get_auth_plugins()
        assert plugins_list == enabled_plugins