"""Tests for kernel connection utilities

Authors
-------
* MinRK
"""
#-----------------------------------------------------------------------------
#  Copyright (c) 2013, the IPython Development Team.
#
#  Distributed under the terms of the Modified BSD License.
#
#  The full license is in the file COPYING.txt, distributed with this software.
#-----------------------------------------------------------------------------

#-----------------------------------------------------------------------------
# Imports
#-----------------------------------------------------------------------------

import json
import os

import nose.tools as nt

from IPython.config import Config
from IPython.consoleapp import IPythonConsoleApp
from IPython.core.application import BaseIPythonApplication
from IPython.utils.tempdir import TemporaryDirectory, TemporaryWorkingDirectory
from IPython.utils.py3compat import str_to_bytes
from IPython.kernel import connect

#-----------------------------------------------------------------------------
# Classes and functions
#-----------------------------------------------------------------------------

class DummyConsoleApp(BaseIPythonApplication, IPythonConsoleApp):
    """Minimal console application used to exercise connection-file handling."""

    def initialize(self, argv=[]):
        """Initialize the base application with `argv`, then set up the connection file.

        Every test in this module passes ``argv=[]`` explicitly; the default
        list is never mutated here.
        """
        BaseIPythonApplication.initialize(self, argv=argv)
        self.init_connection_file()


# Connection parameters written to disk and compared against in the tests below.
sample_info = dict(ip='1.2.3.4', transport='ipc',
    shell_port=1, hb_port=2, iopub_port=3, stdin_port=4, control_port=5,
    key=b'abc123', signature_scheme='hmac-md5',
)


def test_write_connection_file():
    """write_connection_file creates a JSON file that round-trips to sample_info."""
    with TemporaryDirectory() as d:
        cf = os.path.join(d, 'kernel.json')
        connect.write_connection_file(cf, **sample_info)
        nt.assert_true(os.path.exists(cf))
        with open(cf, 'r') as f:
            info = json.load(f)
    # JSON has no bytes type, so the key is converted back before comparing.
    info['key'] = str_to_bytes(info['key'])
    nt.assert_equal(info, sample_info)


def test_app_load_connection_file():
    """test `ipython console --existing` loads a connection file"""
    with TemporaryDirectory() as d:
        cf = os.path.join(d, 'kernel.json')
        connect.write_connection_file(cf, **sample_info)
        app = DummyConsoleApp(connection_file=cf)
        app.initialize(argv=[])

    for attr, expected in sample_info.items():
        # 'key' and 'signature_scheme' are not compared as app attributes.
        if attr in ('key', 'signature_scheme'):
            continue
        value = getattr(app, attr)
        nt.assert_equal(value, expected, "app.%s = %s != %s" % (attr, value, expected))


def test_get_connection_file():
    """get_connection_file(app) returns the file in the profile's security dir.

    Checked both with the path the app resolved during initialization and
    after resetting ``app.connection_file`` to the bare file name.
    """
    cfg = Config()
    with TemporaryWorkingDirectory() as d:
        cfg.ProfileDir.location = d
        cf = 'kernel.json'
        app = DummyConsoleApp(config=cfg, connection_file=cf)
        app.initialize(argv=[])

        profile_cf = os.path.join(app.profile_dir.location, 'security', cf)
        nt.assert_equal(profile_cf, app.connection_file)
        with open(profile_cf, 'w') as f:
            f.write("{}")
        nt.assert_true(os.path.exists(profile_cf))
        nt.assert_equal(connect.get_connection_file(app), profile_cf)

        app.connection_file = cf
        nt.assert_equal(connect.get_connection_file(app), profile_cf)


def test_find_connection_file():
    """find_connection_file resolves exact names and glob patterns.

    The dummy app is installed as ``BaseIPythonApplication._instance`` for the
    duration of the lookups (presumably so ``find_connection_file`` can locate
    the profile through the application singleton -- confirm against
    ``connect``).
    """
    cfg = Config()
    with TemporaryDirectory() as d:
        cfg.ProfileDir.location = d
        cf = 'kernel.json'
        app = DummyConsoleApp(config=cfg, connection_file=cf)
        app.initialize(argv=[])
        BaseIPythonApplication._instance = app
        try:
            profile_cf = os.path.join(app.profile_dir.location, 'security', cf)
            with open(profile_cf, 'w') as f:
                f.write("{}")

            for query in (
                'kernel.json',
                'kern*',
                '*ernel*',
                'k*',
            ):
                nt.assert_equal(connect.find_connection_file(query), profile_cf)
        finally:
            # Always clear the class-level singleton, even when an assertion
            # fails, so this test cannot leak its app into later tests.
            BaseIPythonApplication._instance = None


def test_get_connection_info():
    """get_connection_info returns a JSON string, or a dict when unpack=True."""
    with TemporaryDirectory() as d:
        cf = os.path.join(d, 'kernel.json')
        connect.write_connection_file(cf, **sample_info)
        json_info = connect.get_connection_info(cf)
        info = connect.get_connection_info(cf, unpack=True)

    nt.assert_equal(type(json_info), type(""))
    nt.assert_equal(info, sample_info)

    info2 = json.loads(json_info)
    # The key comes back from JSON as text; convert it before comparing.
    info2['key'] = str_to_bytes(info2['key'])
    nt.assert_equal(info2, sample_info)