Show More
utils.py
465 lines
| 14.6 KiB
| text/x-python
|
PythonLexer
r1 | # -*- coding: utf-8 -*- | |||
r3363 | # Copyright (C) 2010-2019 RhodeCode GmbH | |||
r1 | # | |||
# This program is free software: you can redistribute it and/or modify | ||||
# it under the terms of the GNU Affero General Public License, version 3 | ||||
# (only), as published by the Free Software Foundation. | ||||
# | ||||
# This program is distributed in the hope that it will be useful, | ||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of | ||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | ||||
# GNU General Public License for more details. | ||||
# | ||||
# You should have received a copy of the GNU Affero General Public License | ||||
# along with this program. If not, see <http://www.gnu.org/licenses/>. | ||||
# | ||||
# This program is dual-licensed. If you wish to learn more about the | ||||
# RhodeCode Enterprise Edition, including its added features, Support services, | ||||
# and proprietary license terms, please see https://rhodecode.com/licenses/ | ||||
import threading | ||||
import time | ||||
import logging | ||||
import os.path | ||||
Martin Bornhold
|
r1007 | import subprocess32 | ||
r1256 | import tempfile | |||
r1 | import urllib2 | |||
r1525 | from lxml.html import fromstring, tostring | |||
from lxml.cssselect import CSSSelector | ||||
r1 | from urlparse import urlparse, parse_qsl | |||
from urllib import unquote_plus | ||||
r1906 | import webob | |||
r1 | ||||
r1906 | from webtest.app import TestResponse, TestApp, string_types | |||
from webtest.compat import print_stderr | ||||
r1256 | ||||
r1 | import pytest | |||
import rc_testdata | ||||
r1256 | from rhodecode.model.db import User, Repository | |||
r1 | from rhodecode.model.meta import Session | |||
from rhodecode.model.scm import ScmModel | ||||
from rhodecode.lib.vcs.backends.svn.repository import SubversionRepository | ||||
r1259 | from rhodecode.lib.vcs.backends.base import EmptyCommit | |||
r2374 | from rhodecode.tests import login_user_session | |||
r1 | ||||
log = logging.getLogger(__name__) | ||||
class CustomTestResponse(TestResponse):
    """
    Test response with additional assertion helpers.

    When a `mustcontain` check fails, the rendered response is written to a
    temporary file and the path of that file is included in the raised error.
    """

    def _save_output(self, out):
        """
        Write `out` into a new temporary file and return the path of the file.

        The file is created with ``delete=False``, so it is kept on disk
        after this call for later inspection.
        """
        # The context manager flushes and closes the handle before we return;
        # without it every call leaks an open file descriptor.
        with tempfile.NamedTemporaryFile(
                delete=False, prefix='rc-test-', suffix='.html') as f:
            f.write(out)
        return f.name

    def mustcontain(self, *strings, **kw):
        """
        Assert that the response contains all of the strings passed
        in as arguments.

        Equivalent to::

            assert string in res

        :param strings: strings which must be found in the response.
        :param no: optional keyword, a string or a list of strings which
            must NOT be found in the response.
        :raises TypeError: if a keyword argument other than `no` is given.
        :raises IndexError: if an expected string is missing or an unwanted
            string is present.
        """
        no = kw.pop('no', [])
        if isinstance(no, string_types):
            no = [no]

        if kw:
            raise TypeError(
                "The only keyword argument allowed is 'no' got %s" % kw)

        # The response is dumped to disk only when a check fails, so that
        # passing assertions do not leave a temporary file behind.
        for s in strings:
            if s not in self:
                print_stderr("Actual response (no %r):" % s)
                print_stderr(str(self))
                raise IndexError(
                    "Body does not contain string %r, output saved as %s" % (
                        s, self._save_output(str(self))))

        for no_s in no:
            if no_s in self:
                print_stderr("Actual response (has %r)" % no_s)
                print_stderr(str(self))
                raise IndexError(
                    "Body contains bad string %r, output saved as %s" % (
                        no_s, self._save_output(str(self))))

    def assert_response(self):
        """Return an `AssertResponse` helper wrapping this response."""
        return AssertResponse(self)

    def get_session_from_response(self):
        """
        This returns the session from a response object.

        A session factory is built from the pyramid settings of the test app
        and called with the request of this response.
        """
        from pyramid_beaker import session_factory_from_settings
        session = session_factory_from_settings(self.test_app._pyramid_settings)
        return session(self.request)
r1774 | ||||
r1256 | ||||
class TestRequest(webob.BaseRequest):
    """
    Request class for the test application.

    It sets `ResponseClass` to `CustomTestResponse` and accepts response
    callbacks without doing anything with them.
    """

    # for py.test
    # NOTE(review): presumably stops py.test from collecting this
    # `Test*`-named class as a test case - confirm.
    disabled = True
    ResponseClass = CustomTestResponse

    def add_response_callback(self, callback):
        """Accept `callback` and ignore it; always returns None."""
        return None
r1256 | ||||
class CustomTestApp(TestApp):
    """
    Custom app to make mustcontain more Useful, and extract special methods
    """
    RequestClass = TestRequest

    # NOTE: this dict lives on the class and `login` mutates it in place, so
    # the stored csrf token is shared by every instance of this class.
    rc_login_data = {}
    rc_current_session = None

    def login(self, username=None, password=None):
        """
        Log in through `login_user_session` and remember the session.

        The default login of `login_user_session` is used unless both
        `username` and `password` are given. The csrf token of the session
        is stored in `rc_login_data` and the session in `rc_current_session`.

        :return: the `rhodecode_user` entry of the session.
        """
        from rhodecode.lib import auth

        credentials = (username, password) if username and password else ()
        session = login_user_session(self, *credentials)

        self.rc_login_data['csrf_token'] = auth.get_csrf_token(session)
        self.rc_current_session = session
        return session['rhodecode_user']

    @property
    def csrf_token(self):
        """The csrf token stored by the last call to `login`."""
        return self.rc_login_data['csrf_token']

    @property
    def _pyramid_registry(self):
        """The registry found on the config of the wrapped application."""
        return self.app.config.registry

    @property
    def _pyramid_settings(self):
        """The settings found on the pyramid registry."""
        return self._pyramid_registry.settings
r1256 | ||||
def set_anonymous_access(enabled):
    """
    (Dis)allows anonymous access depending on parameter `enabled`

    The `active` flag of the default user is set to `enabled` and committed.
    Afterwards the value is read back and asserted to match.
    """
    default_user = User.get_default_user()
    default_user.active = enabled
    Session().add(default_user)
    Session().commit()

    # must sleep for cache (1s to expire)
    time.sleep(1.5)
    log.info('anonymous access is now: %s', enabled)

    current_state = User.get_default_user().active
    assert enabled == current_state, 'Cannot set anonymous access'
def check_xfail_backends(node, backend_alias):
    """
    Call `pytest.xfail` if `node` is marked `xfail_backends` for the backend.

    Nothing happens when the closest `xfail_backends` marker is missing or
    does not list `backend_alias` in its positional arguments. A `reason`
    keyword on the marker replaces the default message.
    """
    # Using "xfail_backends" here intentionally, since this marks work
    # which is "to be done" soon.
    marker = node.get_closest_marker('xfail_backends')
    if not marker or backend_alias not in marker.args:
        return

    default_reason = "Support for backend %s to be developed." % (
        backend_alias, )
    pytest.xfail(marker.kwargs.get('reason', default_reason))
def check_skip_backends(node, backend_alias):
    """
    Call `pytest.skip` if `node` is marked `skip_backends` for the backend.

    Nothing happens when the closest `skip_backends` marker is missing or
    does not list `backend_alias` in its positional arguments. A `reason`
    keyword on the marker replaces the default message.
    """
    # Using "skip_backends" here intentionally, since this marks work which is
    # not supported.
    marker = node.get_closest_marker('skip_backends')
    if not marker or backend_alias not in marker.args:
        return

    default_reason = "Feature not supported for backend %s." % (
        backend_alias, )
    pytest.skip(marker.kwargs.get('reason', default_reason))
def extract_git_repo_from_dump(dump_name, repo_name):
    """
    Create git repo `repo_name` from dump `dump_name`.

    :return: path of the repository, i.e. `repo_name` joined onto the
        `repos_path` of `ScmModel`.
    """
    destination = os.path.join(ScmModel().repos_path, repo_name)
    rc_testdata.extract_git_dump(dump_name, destination)
    return destination
def extract_hg_repo_from_dump(dump_name, repo_name):
    """
    Create hg repo `repo_name` from dump `dump_name`.

    :return: path of the repository, i.e. `repo_name` joined onto the
        `repos_path` of `ScmModel`.
    """
    destination = os.path.join(ScmModel().repos_path, repo_name)
    rc_testdata.extract_hg_dump(dump_name, destination)
    return destination
def extract_svn_repo_from_dump(dump_name, repo_name):
    """
    Create a svn repo `repo_name` from dump `dump_name`.

    A new repository is created at the target path first, then the dump is
    loaded into it.

    :return: path of the repository, i.e. `repo_name` joined onto the
        `repos_path` of `ScmModel`.
    """
    destination = os.path.join(ScmModel().repos_path, repo_name)
    SubversionRepository(destination, create=True)
    _load_svn_dump_into_repo(dump_name, destination)
    return destination
def assert_message_in_log(log_records, message, levelno, module):
    """
    Assert that `message` was logged by `module` with level `levelno`.

    Only the records whose `module` and `levelno` both match are considered;
    `message` has to be equal to the `message` of one of them.
    """
    matching_messages = []
    for record in log_records:
        if record.module == module and record.levelno == levelno:
            matching_messages.append(record.message)
    assert message in matching_messages
def _load_svn_dump_into_repo(dump_name, repo_path):
    """
    Utility to populate a svn repository with a named dump

    Currently the dumps are in rc_testdata. They might later on be
    integrated with the main repository once they stabilize more.

    The dump is fed to `svnadmin load <repo_path>` on stdin. If the command
    exits with a non-zero code, its output is logged and an `Exception` is
    raised.
    """
    dump = rc_testdata.load_svn_dump(dump_name)

    pipe = subprocess32.PIPE
    process = subprocess32.Popen(
        ['svnadmin', 'load', repo_path],
        stdin=pipe, stdout=pipe, stderr=pipe)
    stdout, stderr = process.communicate(dump)

    if process.returncode == 0:
        return

    log.error("Output of load_dump command: %s", stdout)
    log.error("Error output of load_dump command: %s", stderr)
    raise Exception(
        'Failed to load dump "%s" into repository at path "%s".'
        % (dump_name, repo_path))
class AssertResponse(object):
    """
    Utility that helps to assert things about a given HTML response.

    Every check parses `response.body` again, using the functions returned
    by `get_imports`.
    """

    def __init__(self, response):
        self.response = response

    def get_imports(self):
        """Return the tuple `(fromstring, tostring, CSSSelector)`."""
        return fromstring, tostring, CSSSelector

    def one_element_exists(self, css_selector):
        """Assert that exactly one element matches `css_selector`."""
        self.get_element(css_selector)

    def no_element_exists(self, css_selector):
        """Assert that no element matches `css_selector`."""
        assert not self._get_elements(css_selector)

    def element_equals_to(self, css_selector, expected_content):
        """
        Assert that `expected_content` is found in the serialized element.

        Despite the name this is a containment check, not an equality check.
        """
        serialized = self._element_to_string(self.get_element(css_selector))
        assert expected_content in serialized

    def element_contains(self, css_selector, expected_content):
        """Assert that the text content of the element has the content."""
        text = self.get_element(css_selector).text_content()
        assert expected_content in text

    def element_value_contains(self, css_selector, expected_content):
        """Assert that the `value` of the element has the content."""
        value = self.get_element(css_selector).value
        assert expected_content in value

    def contains_one_link(self, link_text, href):
        """
        Assert that exactly one `a[href]` element has the stripped text
        `link_text` and that its `href` is equivalent to `href`.
        """
        parse_html, _, selector_cls = self.get_imports()
        document = parse_html(self.response.body)
        matching = []
        for anchor in selector_cls('a[href]')(document):
            if anchor.text_content().strip() == link_text:
                matching.append(anchor)
        assert len(matching) == 1, "Did not find link or found multiple links"
        self._ensure_url_equal(matching[0].attrib.get('href'), href)

    def contains_one_anchor(self, anchor_id):
        """Assert that exactly one element matches `#<anchor_id>`."""
        parse_html, _, selector_cls = self.get_imports()
        document = parse_html(self.response.body)
        found = selector_cls('#' + anchor_id)(document)
        assert len(found) == 1, 'cannot find 1 element {}'.format(anchor_id)

    def _ensure_url_equal(self, found, expected):
        """Assert that both urls are equal when compared as `_Url`."""
        assert _Url(found) == _Url(expected)

    def get_element(self, css_selector):
        """Return the single element matching `css_selector`."""
        found = self._get_elements(css_selector)
        assert len(found) == 1, 'cannot find 1 element {}'.format(css_selector)
        return found[0]

    def get_elements(self, css_selector):
        """Return all elements matching `css_selector`."""
        return self._get_elements(css_selector)

    def _get_elements(self, css_selector):
        """Parse the response body and run `css_selector` against it."""
        parse_html, _, selector_cls = self.get_imports()
        document = parse_html(self.response.body)
        return selector_cls(css_selector)(document)

    def _element_to_string(self, element):
        """Serialize `element` with the `tostring` function."""
        _, serialize, _ = self.get_imports()
        return serialize(element)
class _Url(object):
    """
    A url object that can be compared with other url objects
    without regard to the vagaries of encoding, escaping, and ordering
    of parameters in query strings.

    Inspired by
    http://stackoverflow.com/questions/5371992/comparing-two-urls-in-python
    """

    def __init__(self, url):
        parts = urlparse(url)
        # A frozenset makes the order of the query parameters irrelevant and
        # keeps the parts hashable.
        _query = frozenset(parse_qsl(parts.query))
        _path = unquote_plus(parts.path)
        self.parts = parts._replace(query=_query, path=_path)

    def __eq__(self, other):
        # Let Python fall back to its default comparison for foreign types
        # instead of failing with an AttributeError on `other.parts`.
        if not isinstance(other, _Url):
            return NotImplemented
        return self.parts == other.parts

    def __ne__(self, other):
        # Python 2 does not derive `!=` from `__eq__`; without this method
        # two equivalent urls would still compare as "not equal".
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    def __hash__(self):
        return hash(self.parts)
def run_test_concurrently(times, raise_catched_exc=True):
    """
    Add this decorator to small pieces of code that you want to test
    concurrently

    ex:

    @run_test_concurrently(25)
    def my_test_function():
        ...

    The decorated function is called once in each of `times` threads, all
    with the same arguments. The call returns once every thread is finished.

    :param times: number of threads to start.
    :param raise_catched_exc: if true, an exception caught in a worker
        thread is re-raised inside that thread after it has been recorded.
    :raises Exception: if at least one of the calls raised an exception;
        the message lists all recorded exceptions.
    """
    import functools

    def test_concurrently_decorator(test_func):
        # Keep name, docstring and attributes of the decorated function.
        @functools.wraps(test_func)
        def wrapper(*args, **kwargs):
            exceptions = []

            def call_test_func():
                try:
                    test_func(*args, **kwargs)
                except Exception as e:
                    exceptions.append(e)
                    if raise_catched_exc:
                        raise

            threads = [
                threading.Thread(target=call_test_func)
                for _ in range(times)]
            for t in threads:
                t.start()
            for t in threads:
                t.join()

            if exceptions:
                raise Exception(
                    'test_concurrently intercepted %s exceptions: %s' % (
                        len(exceptions), exceptions))
        return wrapper
    return test_concurrently_decorator
def wait_for_url(url, timeout=10):
    """
    Wait until URL becomes reachable.

    The URL is polled until it is reachable or until the timeout (in
    seconds) has passed. In the latter case `pytest.fail` is called.
    """
    deadline = time.time() + timeout
    poll_interval = 0.1
    checked_at = 0

    while deadline > checked_at:
        checked_at = time.time()
        if is_url_reachable(url):
            return
        if (checked_at + poll_interval) > time.time():
            # Go to sleep because not enough time has passed since last check.
            time.sleep(poll_interval)

    pytest.fail("Timeout while waiting for URL {}".format(url))
def is_url_reachable(url):
    """
    Tell whether `url` can be opened.

    :return: False if opening the url raises `urllib2.URLError`, True if it
        can be opened. Other exceptions are not handled here.
    """
    try:
        response = urllib2.urlopen(url)
    except urllib2.URLError:
        return False
    # Only reachability matters here, so release the connection right away
    # instead of leaving the response open.
    response.close()
    return True
r41 | ||||
Martin Bornhold
|
def repo_on_filesystem(repo_name):
    """
    Tell whether a vcs instance can be obtained for `repo_name`.

    The repository is looked up below `TESTS_TMP_PATH`, without creating it.

    :return: True if `vcs.get_vcs_instance` returned something other
        than None.
    """
    from rhodecode.lib import vcs
    from rhodecode.tests import TESTS_TMP_PATH

    repo_path = os.path.join(TESTS_TMP_PATH, repo_name)
    return vcs.get_vcs_instance(repo_path, create=False) is not None
r1259 | ||||
def commit_change(
        repo, filename, content, message, vcs_type, parent=None, newfile=False):
    """
    Commit `content` for `filename` into the repository named `repo`.

    With `newfile` the file is added through `ScmModel.create_nodes`,
    otherwise it is changed through `ScmModel.commit_change`. The test admin
    login is used both as user and as author.

    :param repo: name of the repository, looked up in the database.
    :param vcs_type: alias for the `EmptyCommit` which is used as parent
        commit when no `parent` is given.
    :return: the result of the called `ScmModel` method.
    """
    from rhodecode.tests import TEST_USER_ADMIN_LOGIN

    repo = Repository.get_by_repo_name(repo)
    parent_commit = parent if parent else EmptyCommit(alias=vcs_type)

    if not newfile:
        # NOTE(review): `parent` is passed on as given here, the EmptyCommit
        # fallback is used for new files only - confirm this is intended.
        return ScmModel().commit_change(
            repo=repo.scm_instance(), repo_name=repo.repo_name,
            commit=parent, user=TEST_USER_ADMIN_LOGIN,
            author=TEST_USER_ADMIN_LOGIN,
            message=message,
            content=content,
            f_path=filename
        )

    nodes = {filename: {'content': content}}
    return ScmModel().create_nodes(
        user=TEST_USER_ADMIN_LOGIN, repo=repo,
        message=message,
        nodes=nodes,
        parent_commit=parent_commit,
        author=TEST_USER_ADMIN_LOGIN,
    )
r2827 | ||||
def permission_update_data_generator(csrf_token, default=None, grant=None, revoke=None):
    """
    Build the list of form fields for a permission update request.

    :param csrf_token: value for the `csrf_token` field.
    :param default: value for the `u_perm_1` field, required.
    :param grant: optional iterable of `(obj_id, perm, obj_name, obj_type)`
        tuples; each one adds four `perm_new_member_*` fields, numbered
        from 1.
    :param revoke: optional iterable of `(obj_id, obj_type)` tuples; each
        one adds two `perm_del_member_*` fields keyed by `obj_id`.
    :return: list of `(field_name, value)` tuples.
    :raises ValueError: if `default` is missing or empty.
    """
    if not default:
        raise ValueError('Permission for default user must be given')

    form_data = [
        ('csrf_token', csrf_token),
        # add default
        ('u_perm_1', default),
    ]

    for cnt, entry in enumerate(grant or (), 1):
        obj_id, perm, obj_name, obj_type = entry
        suffix = 'new{}'.format(cnt)
        form_data += [
            ('perm_new_member_perm_' + suffix, perm),
            ('perm_new_member_id_' + suffix, obj_id),
            ('perm_new_member_name_' + suffix, obj_name),
            ('perm_new_member_type_' + suffix, obj_type),
        ]

    for obj_id, obj_type in revoke or ():
        form_data.append(('perm_del_member_id_{}'.format(obj_id), obj_id))
        form_data.append(('perm_del_member_type_{}'.format(obj_id), obj_type))

    return form_data