# -*- coding: utf-8 -*-

# Copyright (C) 2010-2018 RhodeCode GmbH
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License, version 3
# (only), as published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# This program is dual-licensed. If you wish to learn more about the
# RhodeCode Enterprise Edition, including its added features, Support services,
# and proprietary license terms, please see https://rhodecode.com/licenses/

from subprocess32 import Popen, PIPE
import os
import shutil
import sys
import tempfile

import pytest
from sqlalchemy.engine import url

from rhodecode.tests.fixture import TestINI


def _get_dbs_from_metafunc(metafunc):
    if hasattr(metafunc.function, 'dbs'):
        # Supported backends by this test function, created from
        # pytest.mark.dbs
        backends = metafunc.definition.get_closest_marker('dbs').args
    else:
        backends = metafunc.config.getoption('--dbs')
    return backends


def pytest_generate_tests(metafunc):
    """
    Parametrize ``db_backend_name`` for tests that use the ``db_backend``
    fixture.

    The parameters are the backends that are both requested via ``--dbs`` and
    supported by the test (see ``_get_dbs_from_metafunc``). Tests with no
    matching backend are flagged with ``_skip`` so that
    ``pytest_collection_modifyitems`` drops them.
    """
    if 'db_backend' not in metafunc.fixturenames:
        return

    requested_backends = set(metafunc.config.getoption('--dbs'))
    # Build an ordered, de-duplicated list rather than a set: parametrize
    # order must be deterministic (stable test ids, and pytest-xdist requires
    # every worker to collect the same tests in the same order).
    backends = []
    for backend in _get_dbs_from_metafunc(metafunc):
        if backend in requested_backends and backend not in backends:
            backends.append(backend)

    # TODO(johbo): Disabling a backend did not work out with
    # parametrization, find better way to achieve this.
    if not backends:
        metafunc.function._skip = True
    metafunc.parametrize('db_backend_name', backends)


def pytest_collection_modifyitems(session, config, items):
    """
    Remove collected items whose underlying test function carries a truthy
    ``_skip`` attribute (set by ``pytest_generate_tests`` when no requested
    backend is supported). ``items`` is modified in place.
    """
    # Items that do not wrap a Python function may not expose ``.obj``;
    # look it up defensively instead of raising AttributeError mid-collection.
    items[:] = [
        item for item in items
        if not getattr(getattr(item, 'obj', None), '_skip', False)]


@pytest.fixture
def db_backend(
        request, db_backend_name, ini_config, tmpdir_factory):
    """
    Build the backend object for the parametrized ``db_backend_name``.

    An explicit ``--<backend>-connection-string`` option overrides the
    backend's default connection string; an empty value counts as unset.
    """
    base_temp_dir = tmpdir_factory.getbasetemp().strpath
    backend_class = _get_backend(db_backend_name)

    override = request.config.getoption(
        '--%s-connection-string' % (db_backend_name,))
    if not override:
        override = None

    return backend_class(
        config_file=ini_config,
        basetemp=base_temp_dir,
        connection_string=override,
    )


def _get_backend(backend_type):
    """
    Map a backend name to its ``DBBackend`` subclass.

    The empty string selects ``EmptyDBBackend``; unknown names raise
    ``KeyError``.
    """
    registry = {
        '': EmptyDBBackend,
        'mysql': MySQLDBBackend,
        'postgres': PostgresDBBackend,
        'sqlite': SQLiteDBBackend,
    }
    backend_class = registry[backend_type]
    return backend_class


class DBBackend(object):
    """
    Base class for the database backends used by the DB migration tests.

    Subclasses set ``_type`` and implement ``get_default_connection_string``,
    ``setup_db``, ``teardown_db`` and ``import_dump``. Shell commands are run
    through :meth:`execute`, which keeps the last process (``self.p``) and
    its ``stdout``/``stderr`` on the instance for later assertions.
    """
    # Directory of this module; per-backend dump fixtures live in
    # ``<_store>/<_type>`` (see ``fixture_store`` in ``__init__``).
    _store = os.path.dirname(os.path.abspath(__file__))
    # Backend identifier, overridden by subclasses.
    _type = None
    # Default ini overrides used when no ``ini_params`` are given.
    # NOTE: these class-level lists are shared by every instance and must not
    # be mutated in place; ``_build_ini_params`` always returns a new list.
    _base_ini_config = [{'app:main': {'vcs.start_server': 'false',
                                      'startup.import_repos': 'false',
                                      'is_test': 'False'}}]
    # Database URL override; subclasses rebind it per instance in setup_db().
    _db_url = [{'app:main': {'sqlalchemy.db1.url': ''}}]
    _base_db_name = 'rhodecode_test_db_backend'

    def __init__(
            self, config_file, db_name=None, basetemp=None,
            connection_string=None):
        """
        :param config_file: base ini file the generated test ini derives from.
        :param db_name: database name; defaults to ``_base_db_name``.
        :param basetemp: directory for repositories and large-file stores;
            defaults to the system temp directory.
        :param connection_string: optional SQLAlchemy URL, which may contain
            a ``{db_name}`` placeholder. When empty, the subclass's
            ``get_default_connection_string()`` is used.
        """
        from rhodecode.lib.vcs.backends.hg import largefiles_store
        from rhodecode.lib.vcs.backends.git import lfs_store

        self.fixture_store = os.path.join(self._store, self._type)
        self.db_name = db_name or self._base_db_name
        self._base_ini_file = config_file
        self.stderr = ''
        self.stdout = ''
        self._basetemp = basetemp or tempfile.gettempdir()
        self._repos_location = os.path.join(self._basetemp, 'rc_test_repos')
        self._repos_hg_largefiles_store = largefiles_store(self._basetemp)
        self._repos_git_lfs_store = lfs_store(self._basetemp)
        # Goes through the property setter, which also fills user/password/host.
        self.connection_string = connection_string

    @property
    def connection_string(self):
        """The effective SQLAlchemy connection URL for this backend."""
        return self._connection_string

    @connection_string.setter
    def connection_string(self, new_connection_string):
        """
        Store a connection string and cache its user, password and host.

        An empty value falls back to ``get_default_connection_string()``;
        otherwise any ``{db_name}`` placeholder is filled in.
        """
        if not new_connection_string:
            new_connection_string = self.get_default_connection_string()
        else:
            new_connection_string = new_connection_string.format(
                db_name=self.db_name)
        url_parts = url.make_url(new_connection_string)
        self._connection_string = new_connection_string
        self.user = url_parts.username
        self.password = url_parts.password
        self.host = url_parts.host

    def get_default_connection_string(self):
        """Return the backend's default URL; subclasses must override."""
        raise NotImplementedError('default connection_string is required.')

    def execute(self, cmd, env=None, *args):
        """
        Run ``cmd`` (with ``args`` appended, space separated) through the shell.

        The environment is a copy of ``os.environ`` with
        ``PYTHONIOENCODING=UTF-8`` plus any ``env`` overrides. The process is
        kept as ``self.p`` and its output as ``self.stdout``/``self.stderr``.

        Note that ``env`` precedes ``*args``, so extra positional arguments
        only reach ``args`` once ``env`` has been filled.

        :returns: ``(stdout, stderr)`` of the finished process.
        """
        command = cmd + ' ' + ' '.join(args)
        sys.stdout.write(command)

        # Tell Python to use UTF-8 encoding on stdout
        _env = os.environ.copy()
        _env['PYTHONIOENCODING'] = 'UTF-8'
        if env:
            _env.update(env)
        # shell=True: backend commands rely on shell features such as
        # env-var prefixes (PGPASSWORD=...) and input redirection (<).
        self.p = Popen(command, shell=True, stdout=PIPE, stderr=PIPE, env=_env)
        self.stdout, self.stderr = self.p.communicate()
        sys.stdout.write('COMMAND:'+command+'\n')
        sys.stdout.write(self.stdout)
        return self.stdout, self.stderr

    def assert_returncode_success(self):
        """Raise AssertionError (after printing stderr) if the last command failed."""
        if not self.p.returncode == 0:
            print(self.stderr)
            raise AssertionError('non 0 retcode:{}'.format(self.p.returncode))

    def assert_correct_output(self, stdout, version):
        """Assert ``stdout`` reports that upgrade step ``version`` completed."""
        assert 'UPGRADE FOR STEP {} COMPLETED'.format(version) in stdout

    def _build_ini_params(self, ini_params):
        """
        Return a NEW list of ini overrides: ``ini_params`` (or the class
        defaults when empty/None) followed by ``self._db_url``.

        Building a fresh list keeps both the caller's list and the shared
        class-level ``_base_ini_config`` from growing on every call.
        """
        base_params = ini_params or self._base_ini_config
        return list(base_params) + list(self._db_url)

    @staticmethod
    def _ensure_dirs(*paths):
        """Create each directory in ``paths`` that does not exist yet."""
        for path in paths:
            if not os.path.isdir(path):
                os.makedirs(path)

    def setup_rhodecode_db(self, ini_params=None, env=None):
        """
        Run ``rc-setup-app`` with an ini generated by ``TestINI`` from
        ``ini_params`` plus this backend's DB URL.

        The repository and large-file store directories are created first.

        :returns: ``(stdout, stderr)`` from :meth:`execute`.
        """
        ini_params = self._build_ini_params(ini_params)
        with TestINI(self._base_ini_file, ini_params,
                     self._type, destroy=True) as _ini_file:

            self._ensure_dirs(
                self._repos_location,
                self._repos_hg_largefiles_store,
                self._repos_git_lfs_store)

            return self.execute(
                "rc-setup-app {0} --user=marcink "
                "--email=marcin@rhodeocode.com --password={1} "
                "--repos={2} --force-yes".format(
                    _ini_file, 'qweqwe', self._repos_location), env=env)

    def upgrade_database(self, ini_params=None, env=None):
        """
        Run ``rc-upgrade-db`` with an ini generated by ``TestINI`` from
        ``ini_params`` plus this backend's DB URL.

        :param env: optional extra environment variables for the command.
        :returns: ``(stdout, stderr)`` from :meth:`execute`.
        """
        ini_params = self._build_ini_params(ini_params)

        test_ini = TestINI(
            self._base_ini_file, ini_params, self._type, destroy=True)
        with test_ini as ini_file:
            self._ensure_dirs(self._repos_location)

            return self.execute(
                "rc-upgrade-db {0} --force-yes".format(ini_file), env=env)

    def setup_db(self):
        """Create the database; implemented by subclasses."""
        raise NotImplementedError

    def teardown_db(self):
        """Drop the database; implemented by subclasses."""
        raise NotImplementedError

    def import_dump(self, dumpname):
        """Load the dump ``dumpname`` from ``fixture_store``; implemented by subclasses."""
        raise NotImplementedError


class EmptyDBBackend(DBBackend):
    """
    No-op backend: it never creates, drops or loads a database, and its
    return-code check always passes.

    NOTE(review): it does not override ``get_default_connection_string``, so
    constructing it without a connection string raises NotImplementedError.
    """
    _type = ''

    def import_dump(self, dumpname):
        """Nothing to import for the empty backend."""
        return None

    def setup_db(self):
        """Nothing to create for the empty backend."""
        return None

    def teardown_db(self):
        """Nothing to drop for the empty backend."""
        return None

    def assert_returncode_success(self):
        """Always succeeds; no command result is inspected."""
        return None


class SQLiteDBBackend(DBBackend):
    """
    SQLite backend: the database is a ``<db_name>.sqlite`` file in the base
    temp directory, managed with plain ``cp``/``rm`` shell commands.
    """
    _type = 'sqlite'

    def get_default_connection_string(self):
        """File-based SQLite URL inside the base temp directory."""
        template = 'sqlite:///{}/{}.sqlite'
        return template.format(self._basetemp, self.db_name)

    def _sqlite_file(self):
        """
        Path of the SQLite file that dumps are copied to and that teardown
        removes.

        NOTE(review): always under ``_basetemp``, even when a custom
        connection string points elsewhere.
        """
        file_name = '{}.sqlite'.format(self.db_name)
        return os.path.join(self._basetemp, file_name)

    def setup_db(self):
        """
        Point the app config at this database. SQLite needs no separate
        create step.

        Dump fixtures are presumably refreshed by copying $TEST_DB_NAME.
        """
        url_override = {'sqlalchemy.db1.url': self.connection_string}
        self._db_url = [{'app:main': url_override}]

    def import_dump(self, dumpname):
        """Copy the fixture dump file into place as the database file."""
        source_path = os.path.join(self.fixture_store, dumpname)
        copy_cmd = 'cp -v {} {}'.format(source_path, self._sqlite_file())
        return self.execute(copy_cmd)

    def teardown_db(self):
        """Delete the database file."""
        remove_cmd = "rm -rf {}".format(self._sqlite_file())
        return self.execute(remove_cmd)


class MySQLDBBackend(DBBackend):
    """
    MySQL backend driven through the ``mysql`` command-line client.

    User and password are taken from the connection string. No ``-h`` is
    passed to the client, so the host part of the URL is not used here.
    """
    _type = 'mysql'

    def get_default_connection_string(self):
        """Local ``root`` account on 127.0.0.1 for the test database."""
        template = 'mysql://root:qweqwe@127.0.0.1/{}'
        return template.format(self.db_name)

    def setup_db(self):
        """
        Point the app config at this database and create it.

        Dump fixtures were presumably produced with:
        mysqldump -uroot -pqweqwe $TEST_DB_NAME
        """
        self._db_url = [
            {'app:main': {'sqlalchemy.db1.url': self.connection_string}}]
        create_cmd = "mysql -v -u{} -p{} -e 'create database '{}';'".format(
            self.user, self.password, self.db_name)
        return self.execute(create_cmd)

    def import_dump(self, dumpname):
        """Feed the fixture SQL dump into the database via stdin redirection."""
        dump_path = os.path.join(self.fixture_store, dumpname)
        load_cmd = "mysql -u{} -p{} {} < {}".format(
            self.user, self.password, self.db_name, dump_path)
        return self.execute(load_cmd)

    def teardown_db(self):
        """Drop the test database."""
        drop_cmd = "mysql -v -u{} -p{} -e 'drop database '{}';'".format(
            self.user, self.password, self.db_name)
        return self.execute(drop_cmd)


class PostgresDBBackend(DBBackend):
    """
    PostgreSQL backend driven through the ``psql`` command-line client.

    The password is passed via ``PGPASSWORD``. User, host and port are taken
    from the connection string.
    """
    _type = 'postgres'

    def get_default_connection_string(self):
        """Local ``postgres`` account on localhost for the test database."""
        return 'postgresql://postgres:qweqwe@localhost/{}'.format(self.db_name)

    def _psql_command(self):
        """
        Return the ``psql`` invocation prefix (credentials, host, port).

        Host and port come from the connection string instead of a hard-coded
        ``localhost``, so a custom ``--postgres-connection-string`` pointing at
        another server is honoured. The fallback is ``localhost`` without
        ``-p``, which matches the previous behaviour for the default URL.
        """
        host = self.host or 'localhost'
        command = "PGPASSWORD={} psql -U {} -h {}".format(
            self.password, self.user, host)
        port = url.make_url(self.connection_string).port
        if port:
            command += " -p {}".format(port)
        return command

    def setup_db(self):
        """
        Point the app config at this database and create it.

        Dump fixtures were presumably produced with:
        pg_dump -U postgres -h localhost $TEST_DB_NAME
        """
        self._db_url = [{'app:main': {
            'sqlalchemy.db1.url':
                self.connection_string}}]
        return self.execute(
            "{} -c 'create database '{}';'".format(
                self._psql_command(), self.db_name))

    def teardown_db(self):
        """Drop the test database if it exists."""
        return self.execute(
            "{} -c 'drop database if exists '{}';'".format(
                self._psql_command(), self.db_name))

    def import_dump(self, dumpname):
        """Load the fixture SQL dump in a single transaction (``-1``)."""
        dump = os.path.join(self.fixture_store, dumpname)
        return self.execute(
            "{} -d {} -1 -f {}".format(
                self._psql_command(), self.db_name, dump))