# Copyright (C) 2010-2024 RhodeCode GmbH
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License, version 3
# (only), as published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# This program is dual-licensed. If you wish to learn more about the
# RhodeCode Enterprise Edition, including its added features, Support services,
# and proprietary license terms, please see https://rhodecode.com/licenses/

"""
Pytest hooks and fixtures for running tests against several database backends.

Tests that request the ``db_backend`` fixture are parametrized over the
backends requested via ``--dbs`` (optionally narrowed per test with
``@pytest.mark.dbs(...)``). The fixture returns a :class:`DBBackend` helper
that runs shell commands to create, populate, upgrade and drop a test database.
"""

import os
import subprocess
import sys
import tempfile

import pytest
from sqlalchemy.engine import url

from rhodecode.lib.str_utils import safe_str, safe_bytes
from rhodecode.tests.fixtures.rc_fixture import TestINI


def _get_dbs_from_metafunc(metafunc):
    """Return the backends a test function supports.

    Uses the arguments of the closest ``pytest.mark.dbs`` marker when the test
    has one, otherwise falls back to the ``--dbs`` command line option.
    """
    dbs_mark = metafunc.definition.get_closest_marker("dbs")
    if dbs_mark:
        # Supported backends by this test function, created from pytest.mark.dbs
        return dbs_mark.args
    return metafunc.config.getoption("--dbs")


def pytest_generate_tests(metafunc):
    """Parametrize ``db_backend_name`` for tests that use the ``db_backend`` fixture.

    The parameter values are the backends that are both requested on the
    command line (``--dbs``) and supported by the test (``pytest.mark.dbs``).
    """
    if "db_backend" not in metafunc.fixturenames:
        return

    requested_backends = set(metafunc.config.getoption("--dbs"))
    supported_backends = _get_dbs_from_metafunc(metafunc)
    # Keep the declared order (dropping duplicates) instead of iterating a set:
    # set order of strings depends on hash randomization, so test ids/order
    # would vary between runs/processes (pytest-xdist requires every worker to
    # collect the same tests in the same order).
    backends = list(dict.fromkeys(b for b in supported_backends if b in requested_backends))

    # TODO: johbo: Disabling a backend did not work out with
    # parametrization, find better way to achieve this.
    if not backends:
        metafunc.function._skip = True
    metafunc.parametrize("db_backend_name", backends)


def pytest_collection_modifyitems(session, config, items):
    """Remove collected items whose test function was flagged with ``_skip``."""
    # Not every collected item wraps a Python object (e.g. doctest items have
    # no ``obj``), so resolve both attributes defensively.
    items[:] = [item for item in items if not getattr(getattr(item, "obj", None), "_skip", False)]


@pytest.fixture()
def db_backend(request, db_backend_name, ini_config, tmpdir_factory):
    """Return a :class:`DBBackend` instance for the parametrized backend name.

    A ``--<backend>-connection-string`` option, when set, overrides the
    backend's default connection string.
    """
    basetemp = tmpdir_factory.getbasetemp().strpath
    klass = _get_backend(db_backend_name)
    option_name = "--{}-connection-string".format(db_backend_name)
    connection_string = request.config.getoption(option_name) or None
    return klass(config_file=ini_config, basetemp=basetemp, connection_string=connection_string)


def _get_backend(backend_type):
    """Map a backend name to its :class:`DBBackend` subclass (``KeyError`` if unknown)."""
    return {
        "sqlite": SQLiteDBBackend,
        "postgres": PostgresDBBackend,
        "mysql": MySQLDBBackend,
        "": EmptyDBBackend,
    }[backend_type]


class DBBackend(object):
    """Base helper that drives a test database through shell commands.

    Subclasses set ``_type`` and implement ``get_default_connection_string``,
    ``setup_db``, ``teardown_db`` and ``import_dump``.
    """

    # Directory of this file; dumps are looked up in ``<_store>/<_type>``.
    _store = os.path.dirname(os.path.abspath(__file__))
    _type = None
    # Default ini overrides. Treated as read-only: see ``_build_ini_params``.
    _base_ini_config = [{"app:main": {"vcs.start_server": "false", "startup.import_repos": "false"}}]
    # Ini override carrying the DB URL; subclasses replace it in ``setup_db``.
    _db_url = [{"app:main": {"sqlalchemy.db1.url": ""}}]
    _base_db_name = "rhodecode_test_db_backend"
    # Environment applied to every executed command.
    std_env = {"RC_TEST": "0"}

    def __init__(self, config_file, db_name=None, basetemp=None, connection_string=None):
        self.fixture_store = os.path.join(self._store, self._type)
        self.db_name = db_name or self._base_db_name
        self._base_ini_file = config_file
        self.stderr = ""
        self.stdout = ""
        self._basetemp = basetemp or tempfile.gettempdir()
        self._repos_location = os.path.join(self._basetemp, "rc_test_repos")
        # Must run after ``db_name`` is set: the setter formats ``{db_name}``.
        self.connection_string = connection_string

    @property
    def connection_string(self):
        return self._connection_string

    @connection_string.setter
    def connection_string(self, new_connection_string):
        """Store the connection string and cache its user/password/host/port.

        A falsy value selects ``get_default_connection_string()``; otherwise a
        ``{db_name}`` placeholder in the given string is substituted.
        """
        if not new_connection_string:
            new_connection_string = self.get_default_connection_string()
        else:
            new_connection_string = new_connection_string.format(db_name=self.db_name)
        url_parts = url.make_url(new_connection_string)
        self._connection_string = new_connection_string
        self.user = url_parts.username
        self.password = url_parts.password
        self.host = url_parts.host
        self.port = url_parts.port

    def get_default_connection_string(self):
        raise NotImplementedError("default connection_string is required.")

    def execute(self, cmd, env=None, *args):
        """
        Runs command on the system with given ``args``.

        The command runs through the shell with ``std_env`` and ``env`` layered
        over the current environment. Returns ``(stdout, stderr)`` as bytes;
        the ``Popen`` object is kept on ``self.p`` for
        :meth:`assert_returncode_success`.
        """
        command = cmd + " " + " ".join(args)
        # Log once, before running, so a hanging command is still identifiable.
        sys.stdout.write(f"COMMAND:{command}\n")

        # Tell child Python processes to use UTF-8 for their stdio.
        _env = os.environ.copy()
        _env["PYTHONIOENCODING"] = "UTF-8"
        _env.update(self.std_env)
        if env:
            _env.update(env)

        # shell=True: the commands rely on shell features (env assignments,
        # redirection, quoting); they appear to be built from test config only.
        self.p = subprocess.Popen(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=_env)
        self.stdout, self.stderr = self.p.communicate()
        sys.stdout.write(safe_str(self.stdout))
        return self.stdout, self.stderr

    def assert_returncode_success(self):
        """Fail, printing captured stderr, if the last executed command exited non-zero."""
        from rich import print as pprint

        if self.p.returncode != 0:
            pprint(safe_str(self.stderr))
            raise AssertionError(f"non 0 retcode:{self.p.returncode}")

    def assert_correct_output(self, stdout, version):
        """Assert ``stdout`` (bytes) reports completion of upgrade step ``version``."""
        assert b"UPGRADE FOR STEP %b COMPLETED" % safe_bytes(version) in stdout

    def _build_ini_params(self, ini_params=None):
        """Return a NEW list of ini overrides: ``ini_params`` (or the base config) + ``_db_url``.

        Building a fresh list avoids mutating the class-level
        ``_base_ini_config`` (shared by every backend instance) or the
        caller's list, which previously grew on every call.
        """
        base = ini_params if ini_params else self._base_ini_config
        return [*base, *self._db_url]

    def setup_rhodecode_db(self, ini_params=None, env=None):
        """Run ``rc-setup-app`` against an ini file built from the overrides.

        :param ini_params: list of ini override dicts; defaults to
            ``_base_ini_config``. ``_db_url`` is always appended.
        :param env: extra environment variables for the command.
        :returns: ``(stdout, stderr)`` of the command.
        """
        params = self._build_ini_params(ini_params)
        # NOTE(review): TestINI presumably writes a temporary ini with the
        # overrides applied and removes it on exit (destroy=True) — confirm.
        with TestINI(self._base_ini_file, params, self._type, destroy=True) as _ini_file:
            os.makedirs(self._repos_location, exist_ok=True)
            return self.execute(
                "rc-setup-app {0} --user=marcink "
                "--email=marcin@rhodeocode.com --password={1} "
                "--repos={2} --force-yes".format(_ini_file, "qweqwe", self._repos_location),
                env=env,
            )

    def upgrade_database(self, ini_params=None):
        """Run ``rc-upgrade-db`` against an ini file built from the overrides.

        :param ini_params: list of ini override dicts; defaults to
            ``_base_ini_config``. ``_db_url`` is always appended.
        :returns: ``(stdout, stderr)`` of the command.
        """
        params = self._build_ini_params(ini_params)
        test_ini = TestINI(self._base_ini_file, params, self._type, destroy=True)
        with test_ini as ini_file:
            os.makedirs(self._repos_location, exist_ok=True)
            return self.execute("rc-upgrade-db {0} --force-yes".format(ini_file))

    def setup_db(self):
        raise NotImplementedError

    def teardown_db(self):
        raise NotImplementedError

    def import_dump(self, dumpname):
        raise NotImplementedError


class EmptyDBBackend(DBBackend):
    """No-op backend: every database operation does nothing and always succeeds."""

    _type = ""

    def setup_db(self):
        pass

    def teardown_db(self):
        pass

    def import_dump(self, dumpname):
        pass

    def assert_returncode_success(self):
        assert True


class SQLiteDBBackend(DBBackend):
    """Backend using a ``<db_name>.sqlite`` file inside the base temp directory."""

    _type = "sqlite"

    def get_default_connection_string(self):
        return "sqlite:///{}/{}.sqlite".format(self._basetemp, self.db_name)

    def setup_db(self):
        # dump schema for tests
        # cp -v $TEST_DB_NAME
        self._db_url = [{"app:main": {"sqlalchemy.db1.url": self.connection_string}}]

    def import_dump(self, dumpname):
        """Copy the dump file from ``fixture_store`` over the test database file."""
        dump = os.path.join(self.fixture_store, dumpname)
        target = os.path.join(self._basetemp, "{0.db_name}.sqlite".format(self))
        return self.execute(f"cp -v {dump} {target}")

    def teardown_db(self):
        target_db = os.path.join(self._basetemp, self.db_name)
        return self.execute(f"rm -rf {target_db}.sqlite")


class MySQLDBBackend(DBBackend):
    """Backend that manages the test database with the ``mysql`` command line client."""

    _type = "mysql"

    def get_default_connection_string(self):
        return "mysql://root:qweqwe@127.0.0.1/{}".format(self.db_name)

    def setup_db(self):
        # dump schema for tests
        # mysqldump -uroot -pqweqwe $TEST_DB_NAME
        self._db_url = [{"app:main": {"sqlalchemy.db1.url": self.connection_string}}]
        return self.execute(
            "mysql -v -u{} -p{} -e 'create database '{}';'".format(self.user, self.password, self.db_name)
        )

    def import_dump(self, dumpname):
        dump = os.path.join(self.fixture_store, dumpname)
        return self.execute("mysql -u{} -p{} {} < {}".format(self.user, self.password, self.db_name, dump))

    def teardown_db(self):
        return self.execute(
            "mysql -v -u{} -p{} -e 'drop database '{}';'".format(self.user, self.password, self.db_name)
        )


class PostgresDBBackend(DBBackend):
    """Backend that manages the test database with the ``psql`` command line client."""

    _type = "postgres"

    def get_default_connection_string(self):
        return "postgresql://postgres:qweqwe@localhost/{}".format(self.db_name)

    def _psql(self):
        """Return the ``psql`` invocation prefix (credentials, host, port).

        Host and port come from the parsed connection string, so a custom
        ``--postgres-connection-string`` also applies to the ``psql`` calls.
        Host falls back to ``localhost``, which the previous code hard-coded.
        """
        host = self.host or "localhost"
        port = f" -p {self.port}" if self.port else ""
        return f"PGPASSWORD={self.password} psql -U {self.user} -h {host}{port}"

    def setup_db(self):
        # dump schema for tests
        # pg_dump -U postgres -h localhost $TEST_DB_NAME
        self._db_url = [{"app:main": {"sqlalchemy.db1.url": self.connection_string}}]
        cmd = f"{self._psql()} -c 'create database '{self.db_name}';'"
        return self.execute(cmd)

    def teardown_db(self):
        cmd = f"{self._psql()} -c 'drop database if exists '{self.db_name}';'"
        return self.execute(cmd)

    def import_dump(self, dumpname):
        """Load the dump file from ``fixture_store`` in a single transaction (``-1``)."""
        dump = os.path.join(self.fixture_store, dumpname)
        cmd = f"{self._psql()} -d {self.db_name} -1 -f {dump}"
        return self.execute(cmd)