|
|
# Copyright (C) 2010-2023 RhodeCode GmbH
|
|
|
#
|
|
|
# This program is free software: you can redistribute it and/or modify
|
|
|
# it under the terms of the GNU Affero General Public License, version 3
|
|
|
# (only), as published by the Free Software Foundation.
|
|
|
#
|
|
|
# This program is distributed in the hope that it will be useful,
|
|
|
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
|
# GNU General Public License for more details.
|
|
|
#
|
|
|
# You should have received a copy of the GNU Affero General Public License
|
|
|
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|
|
#
|
|
|
# This program is dual-licensed. If you wish to learn more about the
|
|
|
# RhodeCode Enterprise Edition, including its added features, Support services,
|
|
|
# and proprietary license terms, please see https://rhodecode.com/licenses/
|
|
|
|
|
|
import threading
|
|
|
import time
|
|
|
import sys
|
|
|
import logging
|
|
|
import os.path
|
|
|
import subprocess
|
|
|
import tempfile
|
|
|
import urllib.request
|
|
|
import urllib.error
|
|
|
import urllib.parse
|
|
|
from lxml.html import fromstring, tostring
|
|
|
from lxml.cssselect import CSSSelector
|
|
|
from urllib.parse import unquote_plus
|
|
|
import webob
|
|
|
|
|
|
from webtest.app import TestResponse, TestApp
|
|
|
|
|
|
|
|
|
import pytest
|
|
|
|
|
|
from rhodecode.model.db import User, Repository
|
|
|
from rhodecode.model.meta import Session
|
|
|
from rhodecode.model.scm import ScmModel
|
|
|
from rhodecode.lib.vcs.backends.svn.repository import SubversionRepository
|
|
|
from rhodecode.lib.vcs.backends.base import EmptyCommit
|
|
|
from rhodecode.tests import login_user_session, console_printer
|
|
|
from rhodecode.authentication import AuthenticationPluginRegistry
|
|
|
from rhodecode.model.settings import SettingsModel
|
|
|
|
|
|
# Module-level logger used by the helpers in this file.
log = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
def console_printer_utils(msg):
    """Print ``msg`` through ``console_printer`` with a green ``test-utils`` tag in front."""
    tag = " :white_check_mark: [green]test-utils[/green]"
    console_printer(f"{tag} {msg}")
|
|
|
|
|
|
|
|
|
def get_rc_testdata():
    """
    Import and return the optional ``rc_testdata`` package.

    :raises ImportError: with installation hints when the package is missing.
    """
    try:
        import rc_testdata
    except ImportError:
        raise ImportError('Failed to import rc_testdata, '
                          'please make sure this package is installed from requirements_test.txt')
    else:
        return rc_testdata
|
|
|
|
|
|
|
|
|
class CustomTestResponse(TestResponse):
    """
    ``webtest`` response with RhodeCode-specific assertion helpers.

    When a :meth:`mustcontain` check fails, the full response is written to a
    temporary ``.html`` file. The file's path is reported so the failing page
    can be inspected after the test run.
    """

    def _save_output(self, out):
        """
        Write ``out`` to a new ``rc-test-*.html`` temporary file and return its path.

        The file is kept on disk (``delete=False``). It is closed before the
        path is returned, so its content is complete once the path is reported.
        """
        # Explicit utf-8: response bodies may contain non-ASCII text that the
        # default locale encoding is not guaranteed to handle.
        with tempfile.NamedTemporaryFile(
                mode='w', encoding='utf-8', delete=False,
                prefix='rc-test-', suffix='.html') as f:
            f.write(out)
        return f.name

    def mustcontain(self, *strings, **kw):
        """
        Assert that the response contains all the strings passed
        in as arguments.

        Equivalent to::

            assert string in res

        :param strings: strings that must all be present in the response.
        :param no: a string, or an iterable of strings, that must NOT be
            present in the response.
        :param print_body: if true, also print the whole response on failure.
        :raises TypeError: for any other keyword argument.
        :raises IndexError: when a check fails. The message includes the path
            of the file the response was saved to.
        """
        print_body = kw.pop('print_body', False)
        no = kw.pop('no', [])
        if isinstance(no, str):
            no = [no]
        if kw:
            raise TypeError(f"The only keyword argument allowed is 'no' got {kw}")

        for s in strings:
            if s not in self:
                self._mustcontain_failed(
                    f"Actual response (no {s!r}):",
                    f"Body does not contain string {s!r}",
                    print_body)

        for no_s in no:
            if no_s in self:
                self._mustcontain_failed(
                    f"Actual response (has {no_s!r})",
                    f"Body contains bad string {no_s!r}",
                    print_body)

    def _mustcontain_failed(self, header, reason, print_body):
        """Save the response, print diagnostics and raise ``IndexError`` for ``reason``."""
        # The response is only saved when a check fails. Passing checks do not
        # leave an orphaned (delete=False) temporary file behind.
        body = str(self)
        path = self._save_output(body)
        console_printer_utils(header)
        console_printer_utils(f"body output saved as `{path}`")
        if print_body:
            console_printer_utils(body)
        raise IndexError(f"{reason}, body output saved as {path}")

    def assert_response(self):
        """Return an :class:`AssertResponse` helper wrapping this response."""
        return AssertResponse(self)

    def get_session_from_response(self):
        """
        This returns the session from a response object.

        The beaker session factory is built from the test app's pyramid
        settings and applied to this response's request.
        """
        from rhodecode.lib.rc_beaker import session_factory_from_settings
        session = session_factory_from_settings(self.test_app._pyramid_settings)
        return session(self.request)
|
|
|
|
|
|
|
|
|
class TestRequest(webob.BaseRequest):
    """
    Request class used by :class:`CustomTestApp`.

    Uses :class:`CustomTestResponse` as its ``ResponseClass``. Paths given to
    :meth:`blank` that contain non-ASCII characters are percent-quoted.
    """

    # Not a test case: stop pytest from collecting this class just because its
    # name starts with "Test" (pytest honours ``__test__ = False``).
    __test__ = False
    # Original opt-out flag, kept unchanged.
    disabled = True

    ResponseClass = CustomTestResponse

    def add_response_callback(self, callback):
        """No-op: accept and ignore ``callback``."""
        # NOTE(review): presumably provided for code expecting pyramid's
        # request API; test requests do not run response callbacks.
        pass

    @classmethod
    def blank(cls, path, environ=None, base_url=None,
              headers=None, POST=None, **kw):
        """Create a blank request for ``path``, percent-quoting it if it contains non-ASCII chars."""
        if not path.isascii():
            # our custom quote path if it contains non-ascii chars
            path = urllib.parse.quote(path)

        return super().blank(
            path, environ=environ, base_url=base_url, headers=headers, POST=POST, **kw)
|
|
|
|
|
|
|
|
|
class CustomTestApp(TestApp):
    """
    ``webtest.TestApp`` that builds :class:`TestRequest` requests and adds
    RhodeCode login and pyramid-settings helpers.
    """
    RequestClass = TestRequest
    # NOTE: class attribute. This dict is shared by every CustomTestApp
    # instance; login() updates it in place instead of rebinding it.
    rc_login_data = {}
    rc_current_session = None

    def login(self, username=None, password=None):
        """
        Log in via ``login_user_session`` and remember the resulting session.

        Credentials are forwarded only when both ``username`` and ``password``
        are given. Otherwise ``login_user_session`` is called with the app only.
        The session's CSRF token is stored for :attr:`csrf_token`.

        :returns: ``session['rhodecode_user']``.
        """
        from rhodecode.lib import auth

        credentials = (username, password) if username and password else ()
        session = login_user_session(self, *credentials)

        self.rc_login_data['csrf_token'] = auth.get_csrf_token(session)
        self.rc_current_session = session
        return session['rhodecode_user']

    @property
    def csrf_token(self):
        """CSRF token captured by the last :meth:`login` (``KeyError`` if none yet)."""
        return self.rc_login_data['csrf_token']

    @property
    def _pyramid_registry(self):
        """Registry of the pyramid config attached to the wrapped app."""
        return self.app.config.registry

    @property
    def _pyramid_settings(self):
        """Settings mapping of :attr:`_pyramid_registry`."""
        return self._pyramid_registry.settings

    def do_request(self, req, status=None, expect_errors=None):
        """Hook point for custom request handling; currently delegates to ``TestApp``."""
        return super(CustomTestApp, self).do_request(req, status, expect_errors)
|
|
|
|
|
|
|
|
|
def set_anonymous_access(enabled):
    """
    Activate or deactivate the default (anonymous) user.

    :param enabled: new value for the default user's ``active`` flag.
    :raises AssertionError: if the flag read back afterwards differs from ``enabled``.
    """
    default_user = User.get_default_user()
    default_user.active = enabled
    Session().add(default_user)
    Session().commit()
    # give the cache (1s expiry) time to expire before reading the flag back
    time.sleep(1.5)
    log.info('anonymous access is now: %s', enabled)

    is_active = User.get_default_user().active
    assert enabled == is_active, 'Cannot set anonymous access'
|
|
|
|
|
|
|
|
|
def check_xfail_backends(node, backend_alias):
    """
    Call ``pytest.xfail`` when ``node`` carries an ``xfail_backends`` marker
    that lists ``backend_alias``.

    The marker's ``reason`` keyword, when present, replaces the default message.
    """
    # "xfail_backends" (not "skip") on purpose: it marks work that is
    # "to be done" soon.
    marker = node.get_closest_marker('xfail_backends')
    if not marker or backend_alias not in marker.args:
        return
    default_reason = "Support for backend %s to be developed." % (backend_alias, )
    pytest.xfail(marker.kwargs.get('reason', default_reason))
|
|
|
|
|
|
|
|
|
def check_skip_backends(node, backend_alias):
    """
    Call ``pytest.skip`` when ``node`` carries a ``skip_backends`` marker
    that lists ``backend_alias``.

    The marker's ``reason`` keyword, when present, replaces the default message.
    """
    # "skip_backends" (not "xfail") on purpose: it marks features that are
    # not supported for that backend at all.
    marker = node.get_closest_marker('skip_backends')
    if not marker or backend_alias not in marker.args:
        return
    default_reason = "Feature not supported for backend %s." % (backend_alias, )
    pytest.skip(marker.kwargs.get('reason', default_reason))
|
|
|
|
|
|
|
|
|
def extract_git_repo_from_dump(dump_name, repo_name):
    """
    Create git repo `repo_name` from dump `dump_name`.

    The dump is extracted by ``rc_testdata.extract_git_dump`` into
    ``repo_name`` below ``ScmModel().repos_path``.

    :returns: the path of the created repository.
    """
    target_path = os.path.join(ScmModel().repos_path, repo_name)
    get_rc_testdata().extract_git_dump(dump_name, target_path)
    return target_path
|
|
|
|
|
|
|
|
|
def extract_hg_repo_from_dump(dump_name, repo_name):
    """
    Create hg repo `repo_name` from dump `dump_name`.

    The dump is extracted by ``rc_testdata.extract_hg_dump`` into
    ``repo_name`` below ``ScmModel().repos_path``.

    :returns: the path of the created repository.
    """
    target_path = os.path.join(ScmModel().repos_path, repo_name)
    get_rc_testdata().extract_hg_dump(dump_name, target_path)
    return target_path
|
|
|
|
|
|
|
|
|
def extract_svn_repo_from_dump(dump_name, repo_name):
    """
    Create a svn repo `repo_name` from dump `dump_name`.

    An empty repository is created below ``ScmModel().repos_path`` first. The
    dump is then loaded into it with ``svnadmin load``.

    :returns: the path of the created repository.
    """
    target_path = os.path.join(ScmModel().repos_path, repo_name)
    # create the empty repository the dump is loaded into
    SubversionRepository(target_path, create=True)
    _load_svn_dump_into_repo(dump_name, target_path)
    return target_path
|
|
|
|
|
|
|
|
|
def assert_message_in_log(log_records, message, levelno, module):
    """
    Assert that ``message`` appears among the records emitted by ``module`` at ``levelno``.

    Records are compared through their ``message``, ``module`` and
    ``levelno`` attributes.
    """
    found = []
    for record in log_records:
        if record.module == module and record.levelno == levelno:
            found.append(record.message)
    assert message in found
|
|
|
|
|
|
|
|
|
def _load_svn_dump_into_repo(dump_name, repo_path):
    """
    Populate the svn repository at ``repo_path`` with the named dump.

    The dump content comes from ``rc_testdata.load_svn_dump`` and is piped
    into ``svnadmin load``. The dumps currently live in rc_testdata; they
    might later be integrated with the main repository once they stabilize
    more.

    :raises Exception: if ``svnadmin load`` exits with a non-zero status.
        Its stdout/stderr are logged first.
    """
    svn_dump = get_rc_testdata().load_svn_dump(dump_name)
    svnadmin = subprocess.Popen(
        ['svnadmin', 'load', repo_path],
        stdin=subprocess.PIPE,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )
    stdout, stderr = svnadmin.communicate(svn_dump)
    if svnadmin.returncode == 0:
        return

    log.error("Output of load_dump command: %s", stdout)
    log.error("Error output of load_dump command: %s", stderr)
    raise Exception(f'Failed to load dump "{dump_name}" into repository at path "{repo_path}".')
|
|
|
|
|
|
|
|
|
class AssertResponse(object):
    """
    Utility that helps to assert things about a given HTML response.

    Every check parses ``response.body`` with lxml and selects nodes with CSS
    selectors.
    """

    def __init__(self, response):
        self.response = response

    def get_imports(self):
        """Return the ``(fromstring, tostring, CSSSelector)`` helpers used for parsing."""
        return fromstring, tostring, CSSSelector

    def one_element_exists(self, css_selector):
        """Assert that exactly one element matches ``css_selector``."""
        self.get_element(css_selector)

    def no_element_exists(self, css_selector):
        """Assert that nothing matches ``css_selector``."""
        assert not self._get_elements(css_selector)

    def element_equals_to(self, css_selector, expected_content):
        """
        Assert that ``expected_content`` occurs in the serialized HTML of the
        single element matching ``css_selector``.

        NOTE: despite the name this is a containment check, not equality.
        """
        serialized = self._element_to_string(self.get_element(css_selector))
        assert expected_content in serialized

    def element_contains(self, css_selector, expected_content):
        """Assert that the single matching element's text content contains ``expected_content``."""
        text = self.get_element(css_selector).text_content()
        assert expected_content in text

    def element_value_contains(self, css_selector, expected_content):
        """Assert that the single matching element's ``value`` contains ``expected_content``."""
        value = self.get_element(css_selector).value
        assert expected_content in value

    def contains_one_link(self, link_text, href):
        """
        Assert that exactly one ``<a href>`` has the (stripped) text
        ``link_text`` and that its href matches ``href`` as a URL.
        """
        links = [
            anchor for anchor in self._get_elements('a[href]')
            if anchor.text_content().strip() == link_text]
        assert len(links) == 1, "Did not find link or found multiple links"
        self._ensure_url_equal(links[0].attrib.get('href'), href)

    def contains_one_anchor(self, anchor_id):
        """Assert that exactly one element has the id ``anchor_id``."""
        matches = self._get_elements('#' + anchor_id)
        assert len(matches) == 1, f'cannot find 1 element {anchor_id}'

    def _ensure_url_equal(self, found, expected):
        # compare as parsed URLs so query order / escaping differences are ignored
        assert _Url(found) == _Url(expected)

    def get_element(self, css_selector):
        """Return the single element matching ``css_selector``; assert there is exactly one."""
        matches = self._get_elements(css_selector)
        assert len(matches) == 1, f'cannot find 1 element {css_selector}'
        return matches[0]

    def get_elements(self, css_selector):
        """Return all elements matching ``css_selector``."""
        return self._get_elements(css_selector)

    def _get_elements(self, css_selector):
        parse, _, selector_cls = self.get_imports()
        document = parse(self.response.body)
        return selector_cls(css_selector)(document)

    def _element_to_string(self, element):
        _, serialize, _ = self.get_imports()
        return serialize(element, encoding='unicode')
|
|
|
|
|
|
|
|
|
class _Url(object):
    """
    A URL object that can be compared with other URL objects
    without regard to the vagaries of encoding, escaping, and ordering
    of parameters in query strings.

    The path is compared after ``unquote_plus`` decoding. The query string is
    compared as an unordered set of ``(key, value)`` pairs; with
    ``parse_qsl``'s defaults, parameters with blank values are dropped.

    Inspired by
    http://stackoverflow.com/questions/5371992/comparing-two-urls-in-python
    """

    def __init__(self, url):
        parts = urllib.parse.urlparse(url)
        # frozenset: parameter order is irrelevant and the result stays hashable
        query = frozenset(urllib.parse.parse_qsl(parts.query))
        path = unquote_plus(parts.path)
        self.parts = parts._replace(query=query, path=path)

    def __eq__(self, other):
        # Let Python fall back to its default comparison for foreign types
        # instead of raising AttributeError on ``other.parts``.
        if not isinstance(other, _Url):
            return NotImplemented
        return self.parts == other.parts

    def __hash__(self):
        return hash(self.parts)

    def __repr__(self):
        # readable output when an ``assert _Url(a) == _Url(b)`` fails
        return f'{type(self).__name__}({self.parts!r})'
|
|
|
|
|
|
|
|
|
def run_test_concurrently(times, raise_catched_exc=True):
    """
    Decorator factory that runs the decorated callable in ``times`` threads at once.

    Add this decorator to small pieces of code that you want to test
    concurrently, e.g.::

        @run_test_concurrently(25)
        def my_test_function():
            ...

    :param times: number of threads to start. Each thread calls the function
        once with the arguments given to the wrapper.
    :param raise_catched_exc: when true, an exception is also re-raised inside
        its worker thread, so it is reported by ``threading.excepthook``.
        Either way, all collected exceptions are raised together in the
        calling thread once every thread has finished.
    :raises Exception: if any of the concurrent calls raised.
    """
    import functools

    def test_concurrently_decorator(test_func):
        # wraps() keeps the name/docstring and exposes ``__wrapped__`` so
        # tools like pytest can inspect the original signature.
        @functools.wraps(test_func)
        def wrapper(*args, **kwargs):
            exceptions = []

            def call_test_func():
                try:
                    test_func(*args, **kwargs)
                except Exception as e:
                    exceptions.append(e)
                    if raise_catched_exc:
                        raise

            threads = [threading.Thread(target=call_test_func) for _ in range(times)]
            for t in threads:
                t.start()
            for t in threads:
                t.join()
            if exceptions:
                raise Exception(
                    'test_concurrently intercepted %s exceptions: %s' % (
                        len(exceptions), exceptions))
        return wrapper
    return test_concurrently_decorator
|
|
|
|
|
|
|
|
|
def wait_for_url(url, timeout=10):
    """
    Wait until URL becomes reachable.

    Polls ``url`` (at most about every 0.1 seconds) until it becomes reachable
    or ``timeout`` seconds have passed. At least one attempt is always made.
    Calls ``pytest.fail`` if the URL is still not reachable after the timeout.
    """
    # Monotonic clock: the deadline is not affected by wall-clock adjustments.
    deadline = time.monotonic() + timeout
    wait = 0.1

    while True:
        attempt_started = time.monotonic()
        if is_url_reachable(url, log_exc=False):
            return
        if attempt_started >= deadline:
            break
        if (attempt_started + wait) > time.monotonic():
            # Go to sleep because not enough time has passed since last check.
            time.sleep(wait)

    pytest.fail(f"Timeout while waiting for URL {url}")
|
|
|
|
|
|
|
|
|
def is_url_reachable(url: str, log_exc: bool = False) -> bool:
    """
    Return True if ``url`` can be opened with ``urllib.request.urlopen``.

    Any ``OSError`` raised while opening the URL is treated as "not
    reachable" and returns False. This covers ``urllib.error.URLError``
    (including ``HTTPError``), refused or reset connections and socket
    timeouts. With ``log_exc`` the exception is logged with its traceback.
    """
    try:
        # Close the response right away: only reachability matters, and an
        # unclosed response leaks its connection on every poll.
        with urllib.request.urlopen(url):
            pass
    except OSError:
        # URLError/HTTPError subclass OSError. Catching OSError also covers
        # e.g. ConnectionResetError / http.client.RemoteDisconnected, which
        # can be raised while a server is still starting up.
        if log_exc:
            log.exception('URL `%s` reach error', url)
        return False
    return True
|
|
|
|
|
|
|
|
|
def repo_on_filesystem(repo_name):
    """
    Return True when ``vcs.get_vcs_instance`` finds a repository named
    ``repo_name`` under ``TESTS_TMP_PATH``.
    """
    from rhodecode.lib import vcs
    from rhodecode.tests import TESTS_TMP_PATH

    repo_path = os.path.join(TESTS_TMP_PATH, repo_name)
    return vcs.get_vcs_instance(repo_path, create=False) is not None
|
|
|
|
|
|
|
|
|
def commit_change(
        repo, filename: bytes, content: bytes, message, vcs_type,
        parent=None, branch=None, newfile=False):
    """
    Commit ``content`` to ``filename`` in the repository named ``repo`` as the test admin.

    With ``newfile`` the file is added via ``ScmModel().create_nodes`` on top
    of ``parent``, or of an empty commit of ``vcs_type`` when ``parent`` is
    falsy. Otherwise ``ScmModel().commit_change`` is used with ``parent`` and
    ``branch``.

    :returns: whatever the ScmModel call returns.
    """
    from rhodecode.tests import TEST_USER_ADMIN_LOGIN

    db_repo = Repository.get_by_repo_name(repo)
    base_commit = parent if parent else EmptyCommit(alias=vcs_type)
    author = f'{TEST_USER_ADMIN_LOGIN} <admin@rhodecode.com>'

    if not newfile:
        # NOTE: ``parent`` itself (possibly None) is passed here, not ``base_commit``.
        return ScmModel().commit_change(
            repo=db_repo.scm_instance(), repo_name=db_repo.repo_name,
            commit=parent, user=TEST_USER_ADMIN_LOGIN,
            author=author,
            message=message,
            content=content,
            f_path=filename,
            branch=branch,
        )

    return ScmModel().create_nodes(
        user=TEST_USER_ADMIN_LOGIN, repo=db_repo,
        message=message,
        nodes={filename: {'content': content}},
        parent_commit=base_commit,
        author=author,
    )
|
|
|
|
|
|
|
|
|
def permission_update_data_generator(csrf_token, default=None, grant=None, revoke=None):
    """
    Build the ``(field, value)`` form data for a permission-update POST.

    :param csrf_token: token sent as ``csrf_token``.
    :param default: permission for the default user (``u_perm_1``); required.
    :param grant: iterable of ``(obj_id, perm, obj_name, obj_type)``. Each
        entry becomes a numbered ``perm_new_member_*_newN`` group, starting at 1.
    :param revoke: iterable of ``(obj_id, obj_type)`` producing
        ``perm_del_member_*_<obj_id>`` fields.
    :raises ValueError: if ``default`` is falsy.
    """
    if not default:
        raise ValueError('Permission for default user must be given')

    form_data = [('csrf_token', csrf_token), ('u_perm_1', default)]

    for idx, (obj_id, perm, obj_name, obj_type) in enumerate(grant or [], start=1):
        form_data += [
            (f'perm_new_member_perm_new{idx}', perm),
            (f'perm_new_member_id_new{idx}', obj_id),
            (f'perm_new_member_name_new{idx}', obj_name),
            (f'perm_new_member_type_new{idx}', obj_type),
        ]

    for obj_id, obj_type in revoke or []:
        form_data += [
            (f'perm_del_member_id_{obj_id}', obj_id),
            (f'perm_del_member_type_{obj_id}', obj_type),
        ]

    return form_data
|
|
|
|
|
|
|
|
|
|
|
|
class AuthPluginManager:
    """Test helper that switches which authentication plugins are enabled."""

    def cleanup(self):
        """Re-enable only the ``egg:rhodecode-enterprise-ce#rhodecode`` plugin."""
        self._enable_plugins(['egg:rhodecode-enterprise-ce#rhodecode'])

    def enable(self, plugins_list, override=None):
        """Enable exactly ``plugins_list``; see :meth:`_enable_plugins` for ``override``."""
        return self._enable_plugins(plugins_list, override)

    @classmethod
    def _enable_plugins(cls, plugins_list, override=None):
        """
        Store settings enabling exactly ``plugins_list`` and verify the result.

        For each plugin id, the part after ``#`` names its settings
        (``auth_<name>_enabled`` = True, ``auth_<name>_cache_ttl`` = 0).

        :param plugins_list: plugin ids, e.g. ``'egg:rhodecode-enterprise-ce#token'``.
        :param override: optional mapping of plugin id to a dict of extra
            settings, applied after that plugin's defaults.
        :raises AssertionError: if the stored plugin list differs afterwards.
        """
        override = override or {}
        params = {
            'auth_plugins': ','.join(plugins_list),
        }

        # helper translate some names to others, to fix settings code
        name_map = {
            'token': 'authtoken'
        }
        log.debug('enable_auth_plugins: enabling following auth-plugins: %s', plugins_list)

        for module in plugins_list:
            plugin_name = module.partition('#')[-1]
            plugin_name = name_map.get(plugin_name, plugin_name)

            # default params that are needed for each plugin,
            # `enabled` and `cache_ttl`
            params[f'auth_{plugin_name}_enabled'] = True
            params[f'auth_{plugin_name}_cache_ttl'] = 0
            # per-plugin overrides win over the defaults above
            params.update(override.get(module, {}))

        for key, value in params.items():
            setting = SettingsModel().create_or_update_setting(key, value)
            Session().add(setting)
        Session().commit()

        AuthenticationPluginRegistry.invalidate_auth_plugins_cache(hard=True)

        enabled_plugins = SettingsModel().get_auth_plugins()
        assert plugins_list == enabled_plugins
|
|
|
|