# -*- coding: utf-8 -*-

# Copyright (C) 2010-2018 RhodeCode GmbH
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License, version 3
# (only), as published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# This program is dual-licensed. If you wish to learn more about the
# RhodeCode Enterprise Edition, including its added features, Support services,
# and proprietary license terms, please see https://rhodecode.com/licenses/

import json
import logging
from StringIO import StringIO

import mock
import pytest

from rhodecode.lib import hooks_daemon
from rhodecode.tests.utils import assert_message_in_log


class TestDummyHooksCallbackDaemon(object):
    """Tests for the no-op daemon used when hooks are called directly."""

    def test_hooks_module_path_set_properly(self):
        dummy = hooks_daemon.DummyHooksCallbackDaemon()
        assert dummy.hooks_module == 'rhodecode.lib.hooks_daemon'

    def test_logs_entering_the_hook(self):
        dummy = hooks_daemon.DummyHooksCallbackDaemon()
        debug_patcher = mock.patch.object(hooks_daemon.log, 'debug')
        with debug_patcher as debug_mock, dummy as entered:
            # Checked while still inside the context, so only the call made
            # on entry has happened yet.
            debug_mock.assert_called_once_with(
                'Running dummy hooks callback daemon')
        assert entered == dummy

    def test_logs_exiting_the_hook(self):
        dummy = hooks_daemon.DummyHooksCallbackDaemon()
        debug_patcher = mock.patch.object(hooks_daemon.log, 'debug')
        with debug_patcher as debug_mock, dummy:
            pass
        # The most recent debug call must be the one emitted on exit.
        debug_mock.assert_called_with('Exiting dummy hooks callback daemon')


class TestHooks(object):
    """Tests for the Hooks context-manager protocol."""

    def test_hooks_can_be_used_as_a_context_processor(self):
        hooks = hooks_daemon.Hooks()
        with hooks as entered:
            pass
        # __enter__ must hand back the Hooks instance itself.
        assert hooks == entered


class TestHooksHttpHandler(object):
    """Tests for the HTTP request handler that dispatches hook calls."""

    def test_read_request_parses_method_name_and_arguments(self):
        method_name = 'test'
        hook_extras = {'param1': 1, 'param2': 'a'}
        request = self._generate_post_request(
            {'method': method_name, 'extras': hook_extras})

        hook_patcher = mock.patch.object(
            hooks_daemon.Hooks, method_name, create=True, return_value=1)
        with hook_patcher as hook_mock:
            MockServer(hooks_daemon.HooksHttpHandler, request)

        hook_mock.assert_called_once_with(hook_extras)

    def test_hooks_serialized_result_is_returned(self):
        hook_result = {'first': 'one', 'second': 2}
        server = self._serve_with_patched_hook(
            'test', return_value=hook_result)

        last_write = server.request.output_stream.buflist[-1]
        assert last_write == json.dumps(hook_result)

    def test_exception_is_returned_in_response(self):
        server = self._serve_with_patched_hook(
            'test', side_effect=Exception('Test exception'))

        response = json.loads(server.request.output_stream.buflist[-1])
        # The traceback text is environment-specific, so it is echoed back
        # from the response; its presence is still required.
        assert response == {
            'exception': 'Exception',
            'exception_traceback': response['exception_traceback'],
            'exception_args': ['Test exception'],
        }

    def test_log_message_writes_to_debug_log(self, caplog):
        client_address = ('0.0.0.0', 8888)
        handler = hooks_daemon.HooksHttpHandler(
            MockRequest('POST /'), client_address, mock.Mock())
        fake_date = '1/Nov/2015 00:00:00'
        date_patcher = mock.patch.object(
            handler, 'log_date_time_string', return_value=fake_date)
        with date_patcher, caplog.at_level(logging.DEBUG):
            handler.log_message('Some message %d, %s', 123, 'string')

        host = client_address[0]
        assert_message_in_log(
            caplog.records,
            '{} - - [{}] Some message 123, string'.format(host, fake_date),
            levelno=logging.DEBUG, module='hooks_daemon')

    def _serve_with_patched_hook(self, rpc_method, **hook_kwargs):
        """Serve an empty POST whose parsed request resolves to `rpc_method`.

        `_read_request` is patched to return ``(rpc_method, {})`` and
        ``Hooks.<rpc_method>`` is created as a mock configured with
        `hook_kwargs` (e.g. ``return_value`` or ``side_effect``).

        Responses are read from ``output_stream.buflist``. The request
        handler closes its output stream when it finishes, after which
        ``StringIO.getvalue()`` is unusable. The raw writes are still
        available in ``buflist``.
        """
        request = self._generate_post_request({})
        read_patcher = mock.patch.object(
            hooks_daemon.HooksHttpHandler, '_read_request',
            return_value=(rpc_method, {}))
        hook_patcher = mock.patch.object(
            hooks_daemon.Hooks, rpc_method, create=True, **hook_kwargs)
        with read_patcher, hook_patcher:
            return MockServer(hooks_daemon.HooksHttpHandler, request)

    def _generate_post_request(self, data):
        """Build a raw HTTP/1.0 POST request whose body is `data` as JSON."""
        body = json.dumps(data)
        head = 'POST / HTTP/1.0\nContent-Length: {}\n\n'.format(len(body))
        return head + body


class TestThreadedHookCallbackDaemon(object):
    """Tests for the ThreadedHookCallbackDaemon context-manager lifecycle.

    The ``Test`` prefix is required: pytest only collects classes matching
    its ``python_classes`` pattern (``Test*`` by default). Without it, none
    of these tests would run.
    """

    def test_constructor_calls_prepare(self):
        prepare_daemon_patcher = mock.patch.object(
            hooks_daemon.ThreadedHookCallbackDaemon, '_prepare')
        with prepare_daemon_patcher as prepare_daemon_mock:
            hooks_daemon.ThreadedHookCallbackDaemon()
        # Only the call count is pinned. The arguments forwarded to _prepare
        # depend on the constructor's signature, which this test does not
        # fix (NOTE(review): confirm against hooks_daemon if tightening).
        assert prepare_daemon_mock.call_count == 1

    def test_run_is_called_on_context_start(self):
        # __exit__ is mocked too, so leaving the context does not try to
        # tear down a daemon whose _prepare/_run never really happened.
        patchers = mock.patch.multiple(
            hooks_daemon.ThreadedHookCallbackDaemon,
            _run=mock.DEFAULT, _prepare=mock.DEFAULT, __exit__=mock.DEFAULT)

        with patchers as mocks:
            daemon = hooks_daemon.ThreadedHookCallbackDaemon()
            with daemon as daemon_context:
                pass
            mocks['_run'].assert_called_once_with()
        assert daemon_context == daemon

    def test_stop_is_called_on_context_exit(self):
        patchers = mock.patch.multiple(
            hooks_daemon.ThreadedHookCallbackDaemon,
            _run=mock.DEFAULT, _prepare=mock.DEFAULT, _stop=mock.DEFAULT)

        with patchers as mocks:
            daemon = hooks_daemon.ThreadedHookCallbackDaemon()
            with daemon as daemon_context:
                # _stop must not run while the context is still active.
                assert mocks['_stop'].call_count == 0

            mocks['_stop'].assert_called_once_with()
        assert daemon_context == daemon


class TestHttpHooksCallbackDaemon(object):
    """Tests for the HTTP callback daemon, with TCPServer/Thread mocked."""

    def test_prepare_inits_daemon_variable(self, tcp_server, caplog):
        with self._tcp_patcher(tcp_server), caplog.at_level(logging.DEBUG):
            daemon = hooks_daemon.HttpHooksCallbackDaemon()
        assert daemon._daemon == tcp_server

        self._assert_prepare_logged(caplog, self._expected_uri(tcp_server))

    def test_prepare_inits_hooks_uri_and_logs_it(
            self, tcp_server, caplog):
        with self._tcp_patcher(tcp_server), caplog.at_level(logging.DEBUG):
            daemon = hooks_daemon.HttpHooksCallbackDaemon()

        uri = self._expected_uri(tcp_server)
        assert daemon.hooks_uri == uri
        self._assert_prepare_logged(caplog, uri)

    def test_run_creates_a_thread(self, tcp_server):
        fake_thread = mock.Mock()

        with self._tcp_patcher(tcp_server):
            daemon = hooks_daemon.HttpHooksCallbackDaemon()

        with self._thread_patcher(fake_thread) as thread_cls:
            daemon._run()

        thread_cls.assert_called_once_with(
            target=tcp_server.serve_forever,
            kwargs={'poll_interval': daemon.POLL_INTERVAL})
        assert fake_thread.daemon is True
        fake_thread.start.assert_called_once_with()

    def test_run_logs(self, tcp_server, caplog):
        with self._tcp_patcher(tcp_server):
            daemon = hooks_daemon.HttpHooksCallbackDaemon()

        thread_patcher = self._thread_patcher(mock.Mock())
        with thread_patcher, caplog.at_level(logging.DEBUG):
            daemon._run()

        assert_message_in_log(
            caplog.records,
            'Running event loop of callback daemon in background thread',
            levelno=logging.DEBUG, module='hooks_daemon')

    def test_stop_cleans_up_the_connection(self, tcp_server, caplog):
        fake_thread = mock.Mock()

        with self._tcp_patcher(tcp_server):
            daemon = hooks_daemon.HttpHooksCallbackDaemon()

        thread_patcher = self._thread_patcher(fake_thread)
        with thread_patcher, caplog.at_level(logging.DEBUG):
            with daemon:
                assert daemon._daemon == tcp_server
                assert daemon._callback_thread == fake_thread

        # After leaving the context, server and thread references are dropped.
        assert daemon._daemon is None
        assert daemon._callback_thread is None
        tcp_server.shutdown.assert_called_with()
        fake_thread.join.assert_called_once_with()

        assert_message_in_log(
            caplog.records, 'Waiting for background thread to finish.',
            levelno=logging.DEBUG, module='hooks_daemon')

    @staticmethod
    def _expected_uri(tcp_server):
        """Return the 127.0.0.1 hooks URI for the mocked server's port."""
        _, port = tcp_server.server_address
        return '{}:{}'.format('127.0.0.1', port)

    @staticmethod
    def _assert_prepare_logged(caplog, expected_uri):
        """Check that the daemon logged its preparation at `expected_uri`."""
        msg = ('Preparing HTTP callback daemon at `{}` and '
               'registering hook object'.format(expected_uri))
        assert_message_in_log(
            caplog.records, msg,
            levelno=logging.DEBUG, module='hooks_daemon')

    def _tcp_patcher(self, tcp_server):
        """Patch hooks_daemon.TCPServer so construction yields `tcp_server`."""
        return mock.patch.object(
            hooks_daemon, 'TCPServer', return_value=tcp_server)

    def _thread_patcher(self, thread):
        """Patch threading.Thread so construction yields `thread`."""
        return mock.patch.object(
            hooks_daemon.threading, 'Thread', return_value=thread)


class TestPrepareHooksDaemon(object):
    """Tests for prepare_callback_daemon's daemon choice and extras."""

    @pytest.mark.parametrize('protocol', ('http',))
    def test_returns_dummy_hooks_callback_daemon_when_using_direct_calls(
            self, protocol):
        expected_extras = {'extra1': 'value1'}
        callback, extras = hooks_daemon.prepare_callback_daemon(
            expected_extras.copy(), protocol=protocol,
            host='127.0.0.1', use_direct_calls=True)
        assert isinstance(callback, hooks_daemon.DummyHooksCallbackDaemon)
        # Pin the caller-supplied extra and the hooks module path. Other
        # bookkeeping keys (e.g. 'time') are intentionally not compared.
        assert extras['extra1'] == expected_extras['extra1']
        assert extras['hooks_module'] == 'rhodecode.lib.hooks_daemon'

    @pytest.mark.parametrize('protocol, expected_class', (
        ('http', hooks_daemon.HttpHooksCallbackDaemon),
    ))
    def test_returns_real_hooks_callback_daemon_when_protocol_is_specified(
            self, protocol, expected_class):
        expected_extras = {
            'extra1': 'value1',
            'txn_id': 'txnid2',
            'hooks_protocol': protocol.lower()
        }
        callback, extras = hooks_daemon.prepare_callback_daemon(
            expected_extras.copy(), protocol=protocol, host='127.0.0.1',
            use_direct_calls=False,
            txn_id='txnid2')
        assert isinstance(callback, expected_class)
        # The URI depends on the bound port and the timestamp on the clock,
        # so neither is compared literally.
        extras.pop('hooks_uri')
        expected_extras['time'] = extras['time']
        assert extras == expected_extras

    @pytest.mark.parametrize('protocol', (
        'invalid',
        'Http',
        'HTTP',
    ))
    def test_raises_on_invalid_protocol(self, protocol):
        # Protocol matching is case-sensitive: only the exact 'http' is valid.
        extras = {
            'extra1': 'value1',
            'hooks_protocol': protocol.lower()
        }
        with pytest.raises(Exception):
            hooks_daemon.prepare_callback_daemon(
                extras.copy(),
                protocol=protocol, host='127.0.0.1',
                use_direct_calls=False)


class MockRequest(object):
    """Socket stand-in that feeds a canned request and captures the reply.

    Stream request handlers obtain their read/write streams through
    ``makefile``. This fake serves the raw `request` text for reads and
    collects writes in a separate in-memory stream.
    """

    def __init__(self, request):
        self.request = request
        self.input_stream = StringIO(b'{}'.format(request))
        self.output_stream = StringIO()

    def makefile(self, mode, *args, **kwargs):
        # Buffer-size arguments are accepted and ignored.
        if mode == 'wb':
            return self.output_stream
        return self.input_stream


class MockServer(object):
    """Minimal server that runs one `Handler` over a MockRequest.

    Stdlib-style request handlers process the request inside their
    constructor, so building a MockServer is enough to serve `request`.
    The server passes itself as the handler's ``server`` argument.
    """

    def __init__(self, Handler, request):
        client_address = ('0.0.0.0', 8888)
        self.request = MockRequest(request)
        self.handler = Handler(self.request, client_address, self)


@pytest.fixture
def tcp_server():
    """Mock TCP server whose ``server_address`` is a fixed loopback port."""
    return mock.Mock(server_address=('127.0.0.1', 8881))