# Copyright (C) 2010-2023 RhodeCode GmbH
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License, version 3
# (only), as published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# This program is dual-licensed. If you wish to learn more about the
# RhodeCode Enterprise Edition, including its added features, Support services,
# and proprietary license terms, please see https://rhodecode.com/licenses/

from subprocess import Popen, PIPE
import os
import sys
import tempfile

import pytest
from sqlalchemy.engine import url

from rhodecode.lib.str_utils import safe_str, safe_bytes
from rhodecode.tests.fixture import TestINI


def _get_dbs_from_metafunc(metafunc):
    """
    Return the database backends supported by the test being generated.

    If the test carries a ``@pytest.mark.dbs(...)`` marker, its positional
    arguments are returned; otherwise the value of the ``--dbs`` command
    line option is returned.
    """
    dbs_mark = metafunc.definition.get_closest_marker('dbs')
    if dbs_mark:
        # Supported backends by this test function, created from pytest.mark.dbs
        return dbs_mark.args
    return metafunc.config.getoption('--dbs')


def pytest_generate_tests(metafunc):
    """
    Parametrize ``db_backend_name`` for tests that use the ``db_backend`` fixture.

    The parameter values are the backends that are both requested through
    ``--dbs`` and supported by the test (see ``_get_dbs_from_metafunc``).
    When no backend remains, the test function is flagged with ``_skip`` so
    that ``pytest_collection_modifyitems`` drops it.
    """
    # Support test generation based on --dbs parameter
    if 'db_backend' not in metafunc.fixturenames:
        return

    requested_backends = set(metafunc.config.getoption('--dbs'))
    supported_backends = _get_dbs_from_metafunc(metafunc)
    # Filter in declaration order (dropping duplicates) rather than iterating
    # a set: the iteration order of a set of strings depends on hash
    # randomization, which would make the generated test ids/order differ
    # between runs and between pytest-xdist workers.
    backends = [
        backend for backend in dict.fromkeys(supported_backends)
        if backend in requested_backends]

    # TODO: johbo: Disabling a backend did not work out with
    # parametrization, find better way to achieve this.
    if not backends:
        metafunc.function._skip = True

    metafunc.parametrize('db_backend_name', backends)


def pytest_collection_modifyitems(session, config, items):
    """Remove collected items whose test function was flagged with ``_skip``."""
    # Not every collected item exposes ``.obj`` (e.g. non-python items), so
    # look it up defensively instead of raising AttributeError.
    items[:] = [
        item for item in items
        if not getattr(getattr(item, 'obj', None), '_skip', False)]


@pytest.fixture()
def db_backend(
        request, db_backend_name, ini_config, tmpdir_factory):
    """
    Return a ``DBBackend`` instance for the parametrized ``db_backend_name``.

    A connection string can be supplied with the
    ``--<backend>-connection-string`` command line option; when it is not
    given, the backend's default connection string is used.
    """
    basetemp = tmpdir_factory.getbasetemp().strpath
    klass = _get_backend(db_backend_name)

    option_name = f'--{db_backend_name}-connection-string'
    connection_string = request.config.getoption(option_name) or None

    return klass(
        config_file=ini_config, basetemp=basetemp,
        connection_string=connection_string)


def _get_backend(backend_type):
    """Map a backend name to its ``DBBackend`` subclass; unknown names raise KeyError."""
    return {
        'sqlite': SQLiteDBBackend,
        'postgres': PostgresDBBackend,
        'mysql': MySQLDBBackend,
        '': EmptyDBBackend
    }[backend_type]


class DBBackend(object):
    """
    Base helper that prepares and drives a database for migration tests.

    Subclasses set ``_type`` and implement ``get_default_connection_string``,
    ``setup_db``, ``teardown_db`` and ``import_dump``.
    """

    # Directory of this module; backend fixtures live in a sub-directory
    # named after ``_type`` (see ``fixture_store``).
    _store = os.path.dirname(os.path.abspath(__file__))
    _type = None
    # Ini overrides used when callers pass none. Treated as read-only:
    # methods build a new list instead of mutating this shared class value.
    _base_ini_config = [{'app:main': {'vcs.start_server': 'false',
                                      'startup.import_repos': 'false',
                                      'is_test': 'False'}}]
    # Database url override; subclasses' ``setup_db`` replace it on the
    # instance with the real connection string.
    _db_url = [{'app:main': {'sqlalchemy.db1.url': ''}}]
    _base_db_name = 'rhodecode_test_db_backend'

    def __init__(
            self, config_file, db_name=None, basetemp=None,
            connection_string=None):
        """
        :param config_file: base ini file handed to ``TestINI`` when
            rendering per-command ini files.
        :param db_name: database name; defaults to ``_base_db_name``.
        :param basetemp: base directory for repositories, large-file stores
            and (for sqlite) database files; defaults to the system temp dir.
        :param connection_string: optional connection string, which may
            contain a ``{db_name}`` placeholder; when falsy the backend's
            default connection string is used.
        """
        from rhodecode.lib.vcs.backends.hg import largefiles_store
        from rhodecode.lib.vcs.backends.git import lfs_store

        self.fixture_store = os.path.join(self._store, self._type)
        self.db_name = db_name or self._base_db_name
        self._base_ini_file = config_file
        self.stderr = ''
        self.stdout = ''
        self._basetemp = basetemp or tempfile.gettempdir()
        self._repos_location = os.path.join(self._basetemp, 'rc_test_repos')
        self._repos_hg_largefiles_store = largefiles_store(self._basetemp)
        self._repos_git_lfs_store = lfs_store(self._basetemp)
        self.connection_string = connection_string

    @property
    def connection_string(self):
        return self._connection_string

    @connection_string.setter
    def connection_string(self, new_connection_string):
        """
        Store the connection string and expose its user/password/host.

        A falsy value selects ``get_default_connection_string()``; any other
        value has its ``{db_name}`` placeholder filled in.
        """
        if not new_connection_string:
            new_connection_string = self.get_default_connection_string()
        else:
            new_connection_string = new_connection_string.format(
                db_name=self.db_name)
        url_parts = url.make_url(new_connection_string)
        self._connection_string = new_connection_string
        self.user = url_parts.username
        self.password = url_parts.password
        self.host = url_parts.host

    def get_default_connection_string(self):
        raise NotImplementedError('default connection_string is required.')

    def execute(self, cmd, env=None, *args):
        """
        Run ``cmd`` (joined with ``args``) through the shell.

        :param cmd: shell command line to run.
        :param env: optional extra environment variables, applied on top of
            a copy of ``os.environ``.
        :returns: ``(stdout, stderr)`` as bytes. They are also stored on
            ``self.stdout`` / ``self.stderr``, and the process on ``self.p``.
        """
        command = cmd + ' ' + ' '.join(args)
        sys.stdout.write(f'COMMAND:{command}\n')

        # Tell Python to use UTF-8 encoding for stdout
        _env = os.environ.copy()
        _env['PYTHONIOENCODING'] = 'UTF-8'
        if env:
            _env.update(env)

        self.p = Popen(command, shell=True, stdout=PIPE, stderr=PIPE, env=_env)
        self.stdout, self.stderr = self.p.communicate()
        sys.stdout.write(safe_str(self.stdout))
        return self.stdout, self.stderr

    def assert_returncode_success(self):
        """Raise ``AssertionError`` (after printing stderr) if the last command failed."""
        from rich import print as pprint
        if self.p.returncode != 0:
            pprint(safe_str(self.stderr))
            raise AssertionError(f'non 0 retcode:{self.p.returncode}')

    def assert_correct_output(self, stdout, version):
        """Assert that ``stdout`` (bytes) reports the upgrade step ``version`` completed."""
        assert b'UPGRADE FOR STEP %b COMPLETED' % safe_bytes(version) in stdout

    def _build_ini_params(self, ini_params=None):
        """
        Return a new list of ini overrides: ``ini_params`` (or the base
        config when empty) followed by this backend's database url override.

        A copy is always made so that neither the caller's list nor the
        shared ``_base_ini_config`` class attribute is mutated; mutating it
        would accumulate url overrides across calls and backend instances.
        """
        params = list(ini_params or self._base_ini_config)
        params.extend(self._db_url)
        return params

    def setup_rhodecode_db(self, ini_params=None, env=None):
        """
        Run ``rc-setup-app`` against a temporary ini file for this backend.

        Creates the repository and large-file store directories first.
        Returns the ``(stdout, stderr)`` pair from ``execute``.
        """
        ini_params = self._build_ini_params(ini_params)
        with TestINI(self._base_ini_file, ini_params, self._type,
                     destroy=True) as _ini_file:
            for path in (self._repos_location,
                         self._repos_hg_largefiles_store,
                         self._repos_git_lfs_store):
                os.makedirs(path, exist_ok=True)

            return self.execute(
                "rc-setup-app {0} --user=marcink "
                "--email=marcin@rhodeocode.com --password={1} "
                "--repos={2} --force-yes".format(
                    _ini_file, 'qweqwe', self._repos_location), env=env)

    def upgrade_database(self, ini_params=None):
        """
        Run ``rc-upgrade-db`` against a temporary ini file for this backend.

        Returns the ``(stdout, stderr)`` pair from ``execute``.
        """
        ini_params = self._build_ini_params(ini_params)
        test_ini = TestINI(
            self._base_ini_file, ini_params, self._type, destroy=True)
        with test_ini as ini_file:
            os.makedirs(self._repos_location, exist_ok=True)
            return self.execute(
                "rc-upgrade-db {0} --force-yes".format(ini_file))

    def setup_db(self):
        raise NotImplementedError

    def teardown_db(self):
        raise NotImplementedError

    def import_dump(self, dumpname):
        raise NotImplementedError


class EmptyDBBackend(DBBackend):
    """No-op backend selected by the empty backend name ``''``."""
    _type = ''

    def setup_db(self):
        pass

    def teardown_db(self):
        pass

    def import_dump(self, dumpname):
        pass

    def assert_returncode_success(self):
        # Nothing is executed by this backend, so there is nothing to check.
        pass


class SQLiteDBBackend(DBBackend):
    """Backend storing the database in a ``<db_name>.sqlite`` file under basetemp."""
    _type = 'sqlite'

    def get_default_connection_string(self):
        return 'sqlite:///{}/{}.sqlite'.format(self._basetemp, self.db_name)

    def setup_db(self):
        # dump schema for tests
        # cp -v $TEST_DB_NAME
        self._db_url = [{'app:main': {
            'sqlalchemy.db1.url': self.connection_string}}]

    def import_dump(self, dumpname):
        """Copy the fixture dump ``dumpname`` over the sqlite database file."""
        dump = os.path.join(self.fixture_store, dumpname)
        target = os.path.join(self._basetemp, f'{self.db_name}.sqlite')
        return self.execute(f'cp -v {dump} {target}')

    def teardown_db(self):
        target_db = os.path.join(self._basetemp, self.db_name)
        return self.execute(f"rm -rf {target_db}.sqlite")


class MySQLDBBackend(DBBackend):
    """Backend driving a MySQL server through the ``mysql`` command line client."""
    _type = 'mysql'

    def get_default_connection_string(self):
        return 'mysql://root:qweqwe@127.0.0.1/{}'.format(self.db_name)

    def setup_db(self):
        # dump schema for tests
        # mysqldump -uroot -pqweqwe $TEST_DB_NAME
        self._db_url = [{'app:main': {
            'sqlalchemy.db1.url': self.connection_string}}]
        return self.execute("mysql -v -u{} -p{} -e 'create database '{}';'".format(
            self.user, self.password, self.db_name))

    def import_dump(self, dumpname):
        """Load the fixture SQL dump ``dumpname`` into the test database."""
        dump = os.path.join(self.fixture_store, dumpname)
        return self.execute("mysql -u{} -p{} {} < {}".format(
            self.user, self.password, self.db_name, dump))

    def teardown_db(self):
        return self.execute("mysql -v -u{} -p{} -e 'drop database '{}';'".format(
            self.user, self.password, self.db_name))


class PostgresDBBackend(DBBackend):
    """Backend driving a PostgreSQL server through the ``psql`` command line client."""
    _type = 'postgres'

    def get_default_connection_string(self):
        return 'postgresql://postgres:qweqwe@localhost/{}'.format(self.db_name)

    def _psql(self, args):
        """
        Run ``psql`` with ``args`` using the credentials and host from the
        connection string. Falls back to ``localhost`` when the string has no host.
        """
        host = self.host or 'localhost'
        cmd = f"PGPASSWORD={self.password} psql -U {self.user} -h {host} {args}"
        return self.execute(cmd)

    def setup_db(self):
        # dump schema for tests
        # pg_dump -U postgres -h localhost $TEST_DB_NAME
        self._db_url = [{'app:main': {'sqlalchemy.db1.url': self.connection_string}}]
        return self._psql(f"-c 'create database '{self.db_name}';'")

    def teardown_db(self):
        return self._psql(f"-c 'drop database if exists '{self.db_name}';'")

    def import_dump(self, dumpname):
        """Load the fixture SQL dump ``dumpname`` into the test database in one transaction."""
        dump = os.path.join(self.fixture_store, dumpname)
        return self._psql(f"-d {self.db_name} -1 -f {dump}")