# Copyright (C) 2010-2023 RhodeCode GmbH
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License, version 3
# (only), as published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# This program is dual-licensed. If you wish to learn more about the
# RhodeCode Enterprise Edition, including its added features, Support services,
# and proprietary license terms, please see https://rhodecode.com/licenses/

"""Tests for the hook callback daemons and the HTTP hooks request handler."""

import io
import logging
import tempfile

import mock
import msgpack
import pytest

from rhodecode.lib.hook_daemon import http_hooks_deamon
from rhodecode.lib.hook_daemon import celery_hooks_deamon
from rhodecode.lib.hook_daemon import hook_module
from rhodecode.lib.hook_daemon import base as hook_base
from rhodecode.lib.str_utils import safe_bytes
from rhodecode.tests.utils import assert_message_in_log
from rhodecode.lib.ext_json import json

# Wire protocol used by every handler test in this module.
test_proto = http_hooks_deamon.HooksHttpHandler.MSGPACK_HOOKS_PROTO


class TestHooks(object):
    """Tests for ``hook_module.Hooks``."""

    def test_hooks_can_be_used_as_a_context_processor(self):
        hooks = hook_module.Hooks()
        with hooks as return_value:
            pass
        assert hooks == return_value


class TestHooksHttpHandler(object):
    """Drive ``HooksHttpHandler`` over an in-memory ``MockRequest``."""

    def test_read_request_parses_method_name_and_arguments(self):
        data = {
            'method': 'test',
            'extras': {
                'param1': 1,
                'param2': 'a'
            }
        }
        request = self._generate_post_request(data)
        hooks_patcher = mock.patch.object(
            hook_module.Hooks, data['method'], create=True, return_value=1)

        with hooks_patcher as hooks_mock:
            self._serve(request)

        hooks_mock.assert_called_once_with(data['extras'])

    def test_hooks_serialized_result_is_returned(self):
        request = self._generate_post_request({})
        rpc_method = 'test'
        hook_result = {
            'first': 'one',
            'second': 2
        }
        extras = {}

        # patching our _read to return test method and proto used
        read_patcher = mock.patch.object(
            http_hooks_deamon.HooksHttpHandler, '_read_request',
            return_value=(test_proto, rpc_method, extras))

        # patch Hooks instance to return hook_result data on 'test' call
        hooks_patcher = mock.patch.object(
            hook_module.Hooks, rpc_method, create=True,
            return_value=hook_result)

        with read_patcher, hooks_patcher:
            server = self._serve(request)

        expected_result = http_hooks_deamon.HooksHttpHandler.serialize_data(hook_result)

        server.request.output_stream.seek(0)
        assert server.request.output_stream.readlines()[-1] == expected_result

    def test_exception_is_returned_in_response(self):
        request = self._generate_post_request({})
        rpc_method = 'test'

        read_patcher = mock.patch.object(
            http_hooks_deamon.HooksHttpHandler, '_read_request',
            return_value=(test_proto, rpc_method, {}))

        hooks_patcher = mock.patch.object(
            hook_module.Hooks, rpc_method, create=True,
            side_effect=Exception('Test exception'))

        with read_patcher, hooks_patcher:
            server = self._serve(request)

        server.request.output_stream.seek(0)
        data = server.request.output_stream.readlines()
        # The first five lines are treated as the status line and headers.
        # The msgpack body may itself contain b'\n', so the remaining lines
        # are joined back together.
        # NOTE(review): assumes the handler always emits exactly five
        # header lines before the body; confirm if the handler changes.
        msgpack_data = b''.join(data[5:])
        org_exc = http_hooks_deamon.HooksHttpHandler.deserialize_data(msgpack_data)
        expected_result = {
            'exception': 'Exception',
            'exception_traceback': org_exc['exception_traceback'],
            'exception_args': ['Test exception']
        }
        assert org_exc == expected_result

    def test_log_message_writes_to_debug_log(self, caplog):
        ip_port = ('0.0.0.0', 8888)
        handler = http_hooks_deamon.HooksHttpHandler(
            MockRequest('POST /'), ip_port, mock.Mock())
        fake_date = '1/Nov/2015 00:00:00'
        date_patcher = mock.patch.object(
            handler, 'log_date_time_string', return_value=fake_date)

        with date_patcher, caplog.at_level(logging.DEBUG):
            handler.log_message('Some message %d, %s', 123, 'string')

        expected_message = f"HOOKS: client={ip_port} - - [{fake_date}] Some message 123, string"

        assert_message_in_log(
            caplog.records, expected_message,
            levelno=logging.DEBUG, module='http_hooks_deamon')

    @staticmethod
    def _serve(request):
        """Handle ``request`` once with ``HooksHttpHandler``; return the server.

        ``DEFAULT_HOOKS_PROTO`` and ``wbufsize`` are assigned directly on the
        handler class and are not restored afterwards. This matches the
        original per-test setup, which other tests may rely on.

        A non-zero ``wbufsize`` makes ``socketserver.StreamRequestHandler``
        obtain its write file through ``makefile('wb', ...)``.
        ``MockRequest`` maps that call to its ``output_stream``.
        """
        handler = http_hooks_deamon.HooksHttpHandler
        handler.DEFAULT_HOOKS_PROTO = test_proto
        handler.wbufsize = 10240
        return MockServer(handler, request)

    def _generate_post_request(self, data, proto=test_proto):
        """Build a raw HTTP/1.0 POST request whose body encodes ``data``.

        The body is msgpack for the msgpack protocol and JSON otherwise.
        It is always converted to bytes, so it can be interpolated with
        ``%b`` and ``Content-Length`` reports its size in bytes.
        """
        if proto == http_hooks_deamon.HooksHttpHandler.MSGPACK_HOOKS_PROTO:
            payload = msgpack.packb(data)
        else:
            # json.dumps may return str; %b below requires bytes
            payload = safe_bytes(json.dumps(data))

        return b'POST / HTTP/1.0\nContent-Length: %d\n\n%b' % (
            len(payload), payload)


# NOTE(review): this class lacks the ``Test`` prefix. Under pytest's default
# ``python_classes = Test*`` it is never collected, so none of these tests
# run. Before renaming it to enable them, verify they still pass against the
# current ``ThreadedHookCallbackDaemon`` constructor.
class ThreadedHookCallbackDaemon(object):
    """Context-manager behaviour of ``ThreadedHookCallbackDaemon``."""

    def test_constructor_calls_prepare(self):
        prepare_daemon_patcher = mock.patch.object(
            http_hooks_deamon.ThreadedHookCallbackDaemon, '_prepare')
        with prepare_daemon_patcher as prepare_daemon_mock:
            http_hooks_deamon.ThreadedHookCallbackDaemon()
        prepare_daemon_mock.assert_called_once_with()

    def test_run_is_called_on_context_start(self):
        patchers = mock.patch.multiple(
            http_hooks_deamon.ThreadedHookCallbackDaemon,
            _run=mock.DEFAULT, _prepare=mock.DEFAULT, __exit__=mock.DEFAULT)

        with patchers as mocks:
            daemon = http_hooks_deamon.ThreadedHookCallbackDaemon()
            with daemon as daemon_context:
                pass
        mocks['_run'].assert_called_once_with()
        assert daemon_context == daemon

    def test_stop_is_called_on_context_exit(self):
        patchers = mock.patch.multiple(
            http_hooks_deamon.ThreadedHookCallbackDaemon,
            _run=mock.DEFAULT, _prepare=mock.DEFAULT, _stop=mock.DEFAULT)

        with patchers as mocks:
            daemon = http_hooks_deamon.ThreadedHookCallbackDaemon()
            with daemon as daemon_context:
                assert mocks['_stop'].call_count == 0

            mocks['_stop'].assert_called_once_with()
            assert daemon_context == daemon


class TestHttpHooksCallbackDaemon(object):
    """Tests for ``HttpHooksCallbackDaemon`` preparation, run and stop."""

    def test_hooks_callback_generates_new_port(self, caplog):
        with caplog.at_level(logging.DEBUG):
            daemon = http_hooks_deamon.HttpHooksCallbackDaemon(host='127.0.0.1', port=8881)
        try:
            assert daemon._daemon.server_address == ('127.0.0.1', 8881)
        finally:
            # the daemon binds a real socket; release it instead of leaking it
            daemon._daemon.server_close()

        with caplog.at_level(logging.DEBUG):
            daemon = http_hooks_deamon.HttpHooksCallbackDaemon(host=None, port=None)
        try:
            # valid TCP ports are 0..65535
            assert daemon._daemon.server_address[1] in range(0, 65536)
            assert daemon._daemon.server_address[0] != '127.0.0.1'
        finally:
            daemon._daemon.server_close()

    def test_prepare_inits_daemon_variable(self, tcp_server, caplog):
        with self._tcp_patcher(tcp_server), caplog.at_level(logging.DEBUG):
            daemon = http_hooks_deamon.HttpHooksCallbackDaemon(host='127.0.0.1', port=8881)
        assert daemon._daemon == tcp_server

        _, port = tcp_server.server_address

        msg = f"HOOKS: 127.0.0.1:{port} Preparing HTTP callback daemon registering " \
              f"hook object: "
        assert_message_in_log(
            caplog.records, msg, levelno=logging.DEBUG, module='http_hooks_deamon')

    def test_prepare_inits_hooks_uri_and_logs_it(
            self, tcp_server, caplog):
        with self._tcp_patcher(tcp_server), caplog.at_level(logging.DEBUG):
            daemon = http_hooks_deamon.HttpHooksCallbackDaemon(host='127.0.0.1', port=8881)

        _, port = tcp_server.server_address
        expected_uri = f'127.0.0.1:{port}'
        assert daemon.hooks_uri == expected_uri

        msg = f"HOOKS: 127.0.0.1:{port} Preparing HTTP callback daemon registering " \
              f"hook object: "
        assert_message_in_log(
            caplog.records, msg,
            levelno=logging.DEBUG, module='http_hooks_deamon')

    def test_run_creates_a_thread(self, tcp_server):
        thread = mock.Mock()

        with self._tcp_patcher(tcp_server):
            daemon = http_hooks_deamon.HttpHooksCallbackDaemon()

        with self._thread_patcher(thread) as thread_mock:
            daemon._run()

        thread_mock.assert_called_once_with(
            target=tcp_server.serve_forever,
            kwargs={'poll_interval': daemon.POLL_INTERVAL})
        assert thread.daemon is True
        thread.start.assert_called_once_with()

    def test_run_logs(self, tcp_server, caplog):

        with self._tcp_patcher(tcp_server):
            daemon = http_hooks_deamon.HttpHooksCallbackDaemon()

        with self._thread_patcher(mock.Mock()), caplog.at_level(logging.DEBUG):
            daemon._run()

        assert_message_in_log(
            caplog.records,
            'Running thread-based loop of callback daemon in background',
            levelno=logging.DEBUG, module='http_hooks_deamon')

    def test_stop_cleans_up_the_connection(self, tcp_server, caplog):
        thread = mock.Mock()

        with self._tcp_patcher(tcp_server):
            daemon = http_hooks_deamon.HttpHooksCallbackDaemon()

        with self._thread_patcher(thread), caplog.at_level(logging.DEBUG):
            with daemon:
                assert daemon._daemon == tcp_server
                assert daemon._callback_thread == thread

        assert daemon._daemon is None
        assert daemon._callback_thread is None
        tcp_server.shutdown.assert_called_with()
        thread.join.assert_called_once_with()

        assert_message_in_log(
            caplog.records, 'Waiting for background thread to finish.',
            levelno=logging.DEBUG, module='http_hooks_deamon')

    def _tcp_patcher(self, tcp_server):
        """Patch the module's ``TCPServer`` so it returns ``tcp_server``."""
        return mock.patch.object(
            http_hooks_deamon, 'TCPServer', return_value=tcp_server)

    def _thread_patcher(self, thread):
        """Patch ``threading.Thread`` in the daemon module to return ``thread``."""
        return mock.patch.object(
            http_hooks_deamon.threading, 'Thread', return_value=thread)


class TestPrepareHooksDaemon(object):
    """Tests for ``hook_base.prepare_callback_daemon`` protocol dispatch."""

    @pytest.mark.parametrize('protocol', ('celery',))
    def test_returns_celery_hooks_callback_daemon_when_celery_protocol_specified(
            self, protocol):
        with tempfile.NamedTemporaryFile(mode='w') as temp_file:
            temp_file.write("[app:main]\ncelery.broker_url = redis://redis/0\n"
                            "celery.result_backend = redis://redis/0")
            temp_file.flush()
            expected_extras = {'config': temp_file.name}
            callback, extras = hook_base.prepare_callback_daemon(
                expected_extras, protocol=protocol, host='')
            assert isinstance(callback, celery_hooks_deamon.CeleryHooksCallbackDaemon)

    @pytest.mark.parametrize('protocol, expected_class', (
        ('http', http_hooks_deamon.HttpHooksCallbackDaemon),
    ))
    def test_returns_real_hooks_callback_daemon_when_protocol_is_specified(
            self, protocol, expected_class):
        expected_extras = {
            'extra1': 'value1',
            'txn_id': 'txnid2',
            'hooks_protocol': protocol.lower(),
            'task_backend': '',
            'task_queue': ''
        }
        callback, extras = hook_base.prepare_callback_daemon(
            expected_extras.copy(), protocol=protocol, host='127.0.0.1',
            txn_id='txnid2')
        assert isinstance(callback, expected_class)
        extras.pop('hooks_uri')
        expected_extras['time'] = extras['time']
        assert extras == expected_extras

    @pytest.mark.parametrize('protocol', (
        'invalid',
        'Http',
        'HTTP',
    ))
    def test_raises_on_invalid_protocol(self, protocol):
        expected_extras = {
            'extra1': 'value1',
            'hooks_protocol': protocol.lower()
        }
        with pytest.raises(Exception):
            callback, extras = hook_base.prepare_callback_daemon(
                expected_extras.copy(),
                protocol=protocol, host='127.0.0.1')


class MockRequest(object):
    """In-memory stand-in for a socket connection handed to a request handler.

    ``input_stream`` holds the raw request bytes and ``output_stream``
    collects whatever the handler writes back.
    """

    def __init__(self, request):
        self.request = request
        self.input_stream = io.BytesIO(safe_bytes(self.request))
        self.output_stream = io.BytesIO()  # make it un-closable for testing investigation
        self.output_stream.close = lambda: None

    def makefile(self, mode, *args, **kwargs):
        """Return the write stream for mode ``'wb'``, otherwise the read stream."""
        return self.output_stream if mode == 'wb' else self.input_stream


class MockServer(object):
    """Minimal server object that runs ``handler_cls`` on one ``MockRequest``."""

    def __init__(self, handler_cls, request):
        ip_port = ('0.0.0.0', 8888)
        self.request = MockRequest(request)
        self.server_address = ip_port
        self.handler = handler_cls(self.request, ip_port, self)


@pytest.fixture()
def tcp_server():
    """Provide a ``mock.Mock`` standing in for the daemon's TCP server."""
    server = mock.Mock()
    server.server_address = ('127.0.0.1', 8881)
    server.wbufsize = 1024
    return server