##// END OF EJS Templates
s/assert_/assertTrue/
Bradley M. Froehle -
Show More
@@ -1,183 +1,183 b''
1 # encoding: utf-8
1 # encoding: utf-8
2 """
2 """
3 Tests for IPython.config.configurable
3 Tests for IPython.config.configurable
4
4
5 Authors:
5 Authors:
6
6
7 * Brian Granger
7 * Brian Granger
8 * Fernando Perez (design help)
8 * Fernando Perez (design help)
9 """
9 """
10
10
11 #-----------------------------------------------------------------------------
11 #-----------------------------------------------------------------------------
12 # Copyright (C) 2008-2011 The IPython Development Team
12 # Copyright (C) 2008-2011 The IPython Development Team
13 #
13 #
14 # Distributed under the terms of the BSD License. The full license is in
14 # Distributed under the terms of the BSD License. The full license is in
15 # the file COPYING, distributed as part of this software.
15 # the file COPYING, distributed as part of this software.
16 #-----------------------------------------------------------------------------
16 #-----------------------------------------------------------------------------
17
17
18 #-----------------------------------------------------------------------------
18 #-----------------------------------------------------------------------------
19 # Imports
19 # Imports
20 #-----------------------------------------------------------------------------
20 #-----------------------------------------------------------------------------
21
21
22 from unittest import TestCase
22 from unittest import TestCase
23
23
24 from IPython.config.configurable import (
24 from IPython.config.configurable import (
25 Configurable,
25 Configurable,
26 SingletonConfigurable
26 SingletonConfigurable
27 )
27 )
28
28
29 from IPython.utils.traitlets import (
29 from IPython.utils.traitlets import (
30 Integer, Float, Unicode
30 Integer, Float, Unicode
31 )
31 )
32
32
33 from IPython.config.loader import Config
33 from IPython.config.loader import Config
34 from IPython.utils.py3compat import PY3
34 from IPython.utils.py3compat import PY3
35
35
36 #-----------------------------------------------------------------------------
36 #-----------------------------------------------------------------------------
37 # Test cases
37 # Test cases
38 #-----------------------------------------------------------------------------
38 #-----------------------------------------------------------------------------
39
39
40
40
41 class MyConfigurable(Configurable):
41 class MyConfigurable(Configurable):
42 a = Integer(1, config=True, help="The integer a.")
42 a = Integer(1, config=True, help="The integer a.")
43 b = Float(1.0, config=True, help="The integer b.")
43 b = Float(1.0, config=True, help="The integer b.")
44 c = Unicode('no config')
44 c = Unicode('no config')
45
45
46
46
47 mc_help=u"""MyConfigurable options
47 mc_help=u"""MyConfigurable options
48 ----------------------
48 ----------------------
49 --MyConfigurable.a=<Integer>
49 --MyConfigurable.a=<Integer>
50 Default: 1
50 Default: 1
51 The integer a.
51 The integer a.
52 --MyConfigurable.b=<Float>
52 --MyConfigurable.b=<Float>
53 Default: 1.0
53 Default: 1.0
54 The integer b."""
54 The integer b."""
55
55
56 mc_help_inst=u"""MyConfigurable options
56 mc_help_inst=u"""MyConfigurable options
57 ----------------------
57 ----------------------
58 --MyConfigurable.a=<Integer>
58 --MyConfigurable.a=<Integer>
59 Current: 5
59 Current: 5
60 The integer a.
60 The integer a.
61 --MyConfigurable.b=<Float>
61 --MyConfigurable.b=<Float>
62 Current: 4.0
62 Current: 4.0
63 The integer b."""
63 The integer b."""
64
64
65 # On Python 3, the Integer trait is a synonym for Int
65 # On Python 3, the Integer trait is a synonym for Int
66 if PY3:
66 if PY3:
67 mc_help = mc_help.replace(u"<Integer>", u"<Int>")
67 mc_help = mc_help.replace(u"<Integer>", u"<Int>")
68 mc_help_inst = mc_help_inst.replace(u"<Integer>", u"<Int>")
68 mc_help_inst = mc_help_inst.replace(u"<Integer>", u"<Int>")
69
69
70 class Foo(Configurable):
70 class Foo(Configurable):
71 a = Integer(0, config=True, help="The integer a.")
71 a = Integer(0, config=True, help="The integer a.")
72 b = Unicode('nope', config=True)
72 b = Unicode('nope', config=True)
73
73
74
74
75 class Bar(Foo):
75 class Bar(Foo):
76 b = Unicode('gotit', config=False, help="The string b.")
76 b = Unicode('gotit', config=False, help="The string b.")
77 c = Float(config=True, help="The string c.")
77 c = Float(config=True, help="The string c.")
78
78
79
79
80 class TestConfigurable(TestCase):
80 class TestConfigurable(TestCase):
81
81
82 def test_default(self):
82 def test_default(self):
83 c1 = Configurable()
83 c1 = Configurable()
84 c2 = Configurable(config=c1.config)
84 c2 = Configurable(config=c1.config)
85 c3 = Configurable(config=c2.config)
85 c3 = Configurable(config=c2.config)
86 self.assertEqual(c1.config, c2.config)
86 self.assertEqual(c1.config, c2.config)
87 self.assertEqual(c2.config, c3.config)
87 self.assertEqual(c2.config, c3.config)
88
88
89 def test_custom(self):
89 def test_custom(self):
90 config = Config()
90 config = Config()
91 config.foo = 'foo'
91 config.foo = 'foo'
92 config.bar = 'bar'
92 config.bar = 'bar'
93 c1 = Configurable(config=config)
93 c1 = Configurable(config=config)
94 c2 = Configurable(config=c1.config)
94 c2 = Configurable(config=c1.config)
95 c3 = Configurable(config=c2.config)
95 c3 = Configurable(config=c2.config)
96 self.assertEqual(c1.config, config)
96 self.assertEqual(c1.config, config)
97 self.assertEqual(c2.config, config)
97 self.assertEqual(c2.config, config)
98 self.assertEqual(c3.config, config)
98 self.assertEqual(c3.config, config)
99 # Test that copies are not made
99 # Test that copies are not made
100 self.assert_(c1.config is config)
100 self.assertTrue(c1.config is config)
101 self.assert_(c2.config is config)
101 self.assertTrue(c2.config is config)
102 self.assert_(c3.config is config)
102 self.assertTrue(c3.config is config)
103 self.assert_(c1.config is c2.config)
103 self.assertTrue(c1.config is c2.config)
104 self.assert_(c2.config is c3.config)
104 self.assertTrue(c2.config is c3.config)
105
105
106 def test_inheritance(self):
106 def test_inheritance(self):
107 config = Config()
107 config = Config()
108 config.MyConfigurable.a = 2
108 config.MyConfigurable.a = 2
109 config.MyConfigurable.b = 2.0
109 config.MyConfigurable.b = 2.0
110 c1 = MyConfigurable(config=config)
110 c1 = MyConfigurable(config=config)
111 c2 = MyConfigurable(config=c1.config)
111 c2 = MyConfigurable(config=c1.config)
112 self.assertEqual(c1.a, config.MyConfigurable.a)
112 self.assertEqual(c1.a, config.MyConfigurable.a)
113 self.assertEqual(c1.b, config.MyConfigurable.b)
113 self.assertEqual(c1.b, config.MyConfigurable.b)
114 self.assertEqual(c2.a, config.MyConfigurable.a)
114 self.assertEqual(c2.a, config.MyConfigurable.a)
115 self.assertEqual(c2.b, config.MyConfigurable.b)
115 self.assertEqual(c2.b, config.MyConfigurable.b)
116
116
117 def test_parent(self):
117 def test_parent(self):
118 config = Config()
118 config = Config()
119 config.Foo.a = 10
119 config.Foo.a = 10
120 config.Foo.b = "wow"
120 config.Foo.b = "wow"
121 config.Bar.b = 'later'
121 config.Bar.b = 'later'
122 config.Bar.c = 100.0
122 config.Bar.c = 100.0
123 f = Foo(config=config)
123 f = Foo(config=config)
124 b = Bar(config=f.config)
124 b = Bar(config=f.config)
125 self.assertEqual(f.a, 10)
125 self.assertEqual(f.a, 10)
126 self.assertEqual(f.b, 'wow')
126 self.assertEqual(f.b, 'wow')
127 self.assertEqual(b.b, 'gotit')
127 self.assertEqual(b.b, 'gotit')
128 self.assertEqual(b.c, 100.0)
128 self.assertEqual(b.c, 100.0)
129
129
130 def test_override1(self):
130 def test_override1(self):
131 config = Config()
131 config = Config()
132 config.MyConfigurable.a = 2
132 config.MyConfigurable.a = 2
133 config.MyConfigurable.b = 2.0
133 config.MyConfigurable.b = 2.0
134 c = MyConfigurable(a=3, config=config)
134 c = MyConfigurable(a=3, config=config)
135 self.assertEqual(c.a, 3)
135 self.assertEqual(c.a, 3)
136 self.assertEqual(c.b, config.MyConfigurable.b)
136 self.assertEqual(c.b, config.MyConfigurable.b)
137 self.assertEqual(c.c, 'no config')
137 self.assertEqual(c.c, 'no config')
138
138
139 def test_override2(self):
139 def test_override2(self):
140 config = Config()
140 config = Config()
141 config.Foo.a = 1
141 config.Foo.a = 1
142 config.Bar.b = 'or' # Up above b is config=False, so this won't do it.
142 config.Bar.b = 'or' # Up above b is config=False, so this won't do it.
143 config.Bar.c = 10.0
143 config.Bar.c = 10.0
144 c = Bar(config=config)
144 c = Bar(config=config)
145 self.assertEqual(c.a, config.Foo.a)
145 self.assertEqual(c.a, config.Foo.a)
146 self.assertEqual(c.b, 'gotit')
146 self.assertEqual(c.b, 'gotit')
147 self.assertEqual(c.c, config.Bar.c)
147 self.assertEqual(c.c, config.Bar.c)
148 c = Bar(a=2, b='and', c=20.0, config=config)
148 c = Bar(a=2, b='and', c=20.0, config=config)
149 self.assertEqual(c.a, 2)
149 self.assertEqual(c.a, 2)
150 self.assertEqual(c.b, 'and')
150 self.assertEqual(c.b, 'and')
151 self.assertEqual(c.c, 20.0)
151 self.assertEqual(c.c, 20.0)
152
152
153 def test_help(self):
153 def test_help(self):
154 self.assertEqual(MyConfigurable.class_get_help(), mc_help)
154 self.assertEqual(MyConfigurable.class_get_help(), mc_help)
155
155
156 def test_help_inst(self):
156 def test_help_inst(self):
157 inst = MyConfigurable(a=5, b=4)
157 inst = MyConfigurable(a=5, b=4)
158 self.assertEqual(MyConfigurable.class_get_help(inst), mc_help_inst)
158 self.assertEqual(MyConfigurable.class_get_help(inst), mc_help_inst)
159
159
160
160
161 class TestSingletonConfigurable(TestCase):
161 class TestSingletonConfigurable(TestCase):
162
162
163 def test_instance(self):
163 def test_instance(self):
164 from IPython.config.configurable import SingletonConfigurable
164 from IPython.config.configurable import SingletonConfigurable
165 class Foo(SingletonConfigurable): pass
165 class Foo(SingletonConfigurable): pass
166 self.assertEqual(Foo.initialized(), False)
166 self.assertEqual(Foo.initialized(), False)
167 foo = Foo.instance()
167 foo = Foo.instance()
168 self.assertEqual(Foo.initialized(), True)
168 self.assertEqual(Foo.initialized(), True)
169 self.assertEqual(foo, Foo.instance())
169 self.assertEqual(foo, Foo.instance())
170 self.assertEqual(SingletonConfigurable._instance, None)
170 self.assertEqual(SingletonConfigurable._instance, None)
171
171
172 def test_inheritance(self):
172 def test_inheritance(self):
173 class Bar(SingletonConfigurable): pass
173 class Bar(SingletonConfigurable): pass
174 class Bam(Bar): pass
174 class Bam(Bar): pass
175 self.assertEqual(Bar.initialized(), False)
175 self.assertEqual(Bar.initialized(), False)
176 self.assertEqual(Bam.initialized(), False)
176 self.assertEqual(Bam.initialized(), False)
177 bam = Bam.instance()
177 bam = Bam.instance()
178 bam == Bar.instance()
178 bam == Bar.instance()
179 self.assertEqual(Bar.initialized(), True)
179 self.assertEqual(Bar.initialized(), True)
180 self.assertEqual(Bam.initialized(), True)
180 self.assertEqual(Bam.initialized(), True)
181 self.assertEqual(bam, Bam._instance)
181 self.assertEqual(bam, Bam._instance)
182 self.assertEqual(bam, Bar._instance)
182 self.assertEqual(bam, Bar._instance)
183 self.assertEqual(SingletonConfigurable._instance, None)
183 self.assertEqual(SingletonConfigurable._instance, None)
@@ -1,263 +1,263 b''
1 # encoding: utf-8
1 # encoding: utf-8
2 """
2 """
3 Tests for IPython.config.loader
3 Tests for IPython.config.loader
4
4
5 Authors:
5 Authors:
6
6
7 * Brian Granger
7 * Brian Granger
8 * Fernando Perez (design help)
8 * Fernando Perez (design help)
9 """
9 """
10
10
11 #-----------------------------------------------------------------------------
11 #-----------------------------------------------------------------------------
12 # Copyright (C) 2008-2011 The IPython Development Team
12 # Copyright (C) 2008-2011 The IPython Development Team
13 #
13 #
14 # Distributed under the terms of the BSD License. The full license is in
14 # Distributed under the terms of the BSD License. The full license is in
15 # the file COPYING, distributed as part of this software.
15 # the file COPYING, distributed as part of this software.
16 #-----------------------------------------------------------------------------
16 #-----------------------------------------------------------------------------
17
17
18 #-----------------------------------------------------------------------------
18 #-----------------------------------------------------------------------------
19 # Imports
19 # Imports
20 #-----------------------------------------------------------------------------
20 #-----------------------------------------------------------------------------
21
21
22 import os
22 import os
23 import sys
23 import sys
24 from tempfile import mkstemp
24 from tempfile import mkstemp
25 from unittest import TestCase
25 from unittest import TestCase
26
26
27 from nose import SkipTest
27 from nose import SkipTest
28
28
29 from IPython.testing.tools import mute_warn
29 from IPython.testing.tools import mute_warn
30
30
31 from IPython.utils.traitlets import Unicode
31 from IPython.utils.traitlets import Unicode
32 from IPython.config.configurable import Configurable
32 from IPython.config.configurable import Configurable
33 from IPython.config.loader import (
33 from IPython.config.loader import (
34 Config,
34 Config,
35 PyFileConfigLoader,
35 PyFileConfigLoader,
36 KeyValueConfigLoader,
36 KeyValueConfigLoader,
37 ArgParseConfigLoader,
37 ArgParseConfigLoader,
38 KVArgParseConfigLoader,
38 KVArgParseConfigLoader,
39 ConfigError
39 ConfigError
40 )
40 )
41
41
42 #-----------------------------------------------------------------------------
42 #-----------------------------------------------------------------------------
43 # Actual tests
43 # Actual tests
44 #-----------------------------------------------------------------------------
44 #-----------------------------------------------------------------------------
45
45
46
46
47 pyfile = """
47 pyfile = """
48 c = get_config()
48 c = get_config()
49 c.a=10
49 c.a=10
50 c.b=20
50 c.b=20
51 c.Foo.Bar.value=10
51 c.Foo.Bar.value=10
52 c.Foo.Bam.value=list(range(10)) # list() is just so it's the same on Python 3
52 c.Foo.Bam.value=list(range(10)) # list() is just so it's the same on Python 3
53 c.D.C.value='hi there'
53 c.D.C.value='hi there'
54 """
54 """
55
55
56 class TestPyFileCL(TestCase):
56 class TestPyFileCL(TestCase):
57
57
58 def test_basic(self):
58 def test_basic(self):
59 fd, fname = mkstemp('.py')
59 fd, fname = mkstemp('.py')
60 f = os.fdopen(fd, 'w')
60 f = os.fdopen(fd, 'w')
61 f.write(pyfile)
61 f.write(pyfile)
62 f.close()
62 f.close()
63 # Unlink the file
63 # Unlink the file
64 cl = PyFileConfigLoader(fname)
64 cl = PyFileConfigLoader(fname)
65 config = cl.load_config()
65 config = cl.load_config()
66 self.assertEqual(config.a, 10)
66 self.assertEqual(config.a, 10)
67 self.assertEqual(config.b, 20)
67 self.assertEqual(config.b, 20)
68 self.assertEqual(config.Foo.Bar.value, 10)
68 self.assertEqual(config.Foo.Bar.value, 10)
69 self.assertEqual(config.Foo.Bam.value, range(10))
69 self.assertEqual(config.Foo.Bam.value, range(10))
70 self.assertEqual(config.D.C.value, 'hi there')
70 self.assertEqual(config.D.C.value, 'hi there')
71
71
72 class MyLoader1(ArgParseConfigLoader):
72 class MyLoader1(ArgParseConfigLoader):
73 def _add_arguments(self, aliases=None, flags=None):
73 def _add_arguments(self, aliases=None, flags=None):
74 p = self.parser
74 p = self.parser
75 p.add_argument('-f', '--foo', dest='Global.foo', type=str)
75 p.add_argument('-f', '--foo', dest='Global.foo', type=str)
76 p.add_argument('-b', dest='MyClass.bar', type=int)
76 p.add_argument('-b', dest='MyClass.bar', type=int)
77 p.add_argument('-n', dest='n', action='store_true')
77 p.add_argument('-n', dest='n', action='store_true')
78 p.add_argument('Global.bam', type=str)
78 p.add_argument('Global.bam', type=str)
79
79
80 class MyLoader2(ArgParseConfigLoader):
80 class MyLoader2(ArgParseConfigLoader):
81 def _add_arguments(self, aliases=None, flags=None):
81 def _add_arguments(self, aliases=None, flags=None):
82 subparsers = self.parser.add_subparsers(dest='subparser_name')
82 subparsers = self.parser.add_subparsers(dest='subparser_name')
83 subparser1 = subparsers.add_parser('1')
83 subparser1 = subparsers.add_parser('1')
84 subparser1.add_argument('-x',dest='Global.x')
84 subparser1.add_argument('-x',dest='Global.x')
85 subparser2 = subparsers.add_parser('2')
85 subparser2 = subparsers.add_parser('2')
86 subparser2.add_argument('y')
86 subparser2.add_argument('y')
87
87
88 class TestArgParseCL(TestCase):
88 class TestArgParseCL(TestCase):
89
89
90 def test_basic(self):
90 def test_basic(self):
91 cl = MyLoader1()
91 cl = MyLoader1()
92 config = cl.load_config('-f hi -b 10 -n wow'.split())
92 config = cl.load_config('-f hi -b 10 -n wow'.split())
93 self.assertEqual(config.Global.foo, 'hi')
93 self.assertEqual(config.Global.foo, 'hi')
94 self.assertEqual(config.MyClass.bar, 10)
94 self.assertEqual(config.MyClass.bar, 10)
95 self.assertEqual(config.n, True)
95 self.assertEqual(config.n, True)
96 self.assertEqual(config.Global.bam, 'wow')
96 self.assertEqual(config.Global.bam, 'wow')
97 config = cl.load_config(['wow'])
97 config = cl.load_config(['wow'])
98 self.assertEqual(config.keys(), ['Global'])
98 self.assertEqual(config.keys(), ['Global'])
99 self.assertEqual(config.Global.keys(), ['bam'])
99 self.assertEqual(config.Global.keys(), ['bam'])
100 self.assertEqual(config.Global.bam, 'wow')
100 self.assertEqual(config.Global.bam, 'wow')
101
101
102 def test_add_arguments(self):
102 def test_add_arguments(self):
103 cl = MyLoader2()
103 cl = MyLoader2()
104 config = cl.load_config('2 frobble'.split())
104 config = cl.load_config('2 frobble'.split())
105 self.assertEqual(config.subparser_name, '2')
105 self.assertEqual(config.subparser_name, '2')
106 self.assertEqual(config.y, 'frobble')
106 self.assertEqual(config.y, 'frobble')
107 config = cl.load_config('1 -x frobble'.split())
107 config = cl.load_config('1 -x frobble'.split())
108 self.assertEqual(config.subparser_name, '1')
108 self.assertEqual(config.subparser_name, '1')
109 self.assertEqual(config.Global.x, 'frobble')
109 self.assertEqual(config.Global.x, 'frobble')
110
110
111 def test_argv(self):
111 def test_argv(self):
112 cl = MyLoader1(argv='-f hi -b 10 -n wow'.split())
112 cl = MyLoader1(argv='-f hi -b 10 -n wow'.split())
113 config = cl.load_config()
113 config = cl.load_config()
114 self.assertEqual(config.Global.foo, 'hi')
114 self.assertEqual(config.Global.foo, 'hi')
115 self.assertEqual(config.MyClass.bar, 10)
115 self.assertEqual(config.MyClass.bar, 10)
116 self.assertEqual(config.n, True)
116 self.assertEqual(config.n, True)
117 self.assertEqual(config.Global.bam, 'wow')
117 self.assertEqual(config.Global.bam, 'wow')
118
118
119
119
120 class TestKeyValueCL(TestCase):
120 class TestKeyValueCL(TestCase):
121 klass = KeyValueConfigLoader
121 klass = KeyValueConfigLoader
122
122
123 def test_basic(self):
123 def test_basic(self):
124 cl = self.klass()
124 cl = self.klass()
125 argv = ['--'+s.strip('c.') for s in pyfile.split('\n')[2:-1]]
125 argv = ['--'+s.strip('c.') for s in pyfile.split('\n')[2:-1]]
126 with mute_warn():
126 with mute_warn():
127 config = cl.load_config(argv)
127 config = cl.load_config(argv)
128 self.assertEqual(config.a, 10)
128 self.assertEqual(config.a, 10)
129 self.assertEqual(config.b, 20)
129 self.assertEqual(config.b, 20)
130 self.assertEqual(config.Foo.Bar.value, 10)
130 self.assertEqual(config.Foo.Bar.value, 10)
131 self.assertEqual(config.Foo.Bam.value, range(10))
131 self.assertEqual(config.Foo.Bam.value, range(10))
132 self.assertEqual(config.D.C.value, 'hi there')
132 self.assertEqual(config.D.C.value, 'hi there')
133
133
134 def test_expanduser(self):
134 def test_expanduser(self):
135 cl = self.klass()
135 cl = self.klass()
136 argv = ['--a=~/1/2/3', '--b=~', '--c=~/', '--d="~/"']
136 argv = ['--a=~/1/2/3', '--b=~', '--c=~/', '--d="~/"']
137 with mute_warn():
137 with mute_warn():
138 config = cl.load_config(argv)
138 config = cl.load_config(argv)
139 self.assertEqual(config.a, os.path.expanduser('~/1/2/3'))
139 self.assertEqual(config.a, os.path.expanduser('~/1/2/3'))
140 self.assertEqual(config.b, os.path.expanduser('~'))
140 self.assertEqual(config.b, os.path.expanduser('~'))
141 self.assertEqual(config.c, os.path.expanduser('~/'))
141 self.assertEqual(config.c, os.path.expanduser('~/'))
142 self.assertEqual(config.d, '~/')
142 self.assertEqual(config.d, '~/')
143
143
144 def test_extra_args(self):
144 def test_extra_args(self):
145 cl = self.klass()
145 cl = self.klass()
146 with mute_warn():
146 with mute_warn():
147 config = cl.load_config(['--a=5', 'b', '--c=10', 'd'])
147 config = cl.load_config(['--a=5', 'b', '--c=10', 'd'])
148 self.assertEqual(cl.extra_args, ['b', 'd'])
148 self.assertEqual(cl.extra_args, ['b', 'd'])
149 self.assertEqual(config.a, 5)
149 self.assertEqual(config.a, 5)
150 self.assertEqual(config.c, 10)
150 self.assertEqual(config.c, 10)
151 with mute_warn():
151 with mute_warn():
152 config = cl.load_config(['--', '--a=5', '--c=10'])
152 config = cl.load_config(['--', '--a=5', '--c=10'])
153 self.assertEqual(cl.extra_args, ['--a=5', '--c=10'])
153 self.assertEqual(cl.extra_args, ['--a=5', '--c=10'])
154
154
155 def test_unicode_args(self):
155 def test_unicode_args(self):
156 cl = self.klass()
156 cl = self.klass()
157 argv = [u'--a=épsîlön']
157 argv = [u'--a=épsîlön']
158 with mute_warn():
158 with mute_warn():
159 config = cl.load_config(argv)
159 config = cl.load_config(argv)
160 self.assertEqual(config.a, u'épsîlön')
160 self.assertEqual(config.a, u'épsîlön')
161
161
162 def test_unicode_bytes_args(self):
162 def test_unicode_bytes_args(self):
163 uarg = u'--a=é'
163 uarg = u'--a=é'
164 try:
164 try:
165 barg = uarg.encode(sys.stdin.encoding)
165 barg = uarg.encode(sys.stdin.encoding)
166 except (TypeError, UnicodeEncodeError):
166 except (TypeError, UnicodeEncodeError):
167 raise SkipTest("sys.stdin.encoding can't handle 'é'")
167 raise SkipTest("sys.stdin.encoding can't handle 'é'")
168
168
169 cl = self.klass()
169 cl = self.klass()
170 with mute_warn():
170 with mute_warn():
171 config = cl.load_config([barg])
171 config = cl.load_config([barg])
172 self.assertEqual(config.a, u'é')
172 self.assertEqual(config.a, u'é')
173
173
174 def test_unicode_alias(self):
174 def test_unicode_alias(self):
175 cl = self.klass()
175 cl = self.klass()
176 argv = [u'--a=épsîlön']
176 argv = [u'--a=épsîlön']
177 with mute_warn():
177 with mute_warn():
178 config = cl.load_config(argv, aliases=dict(a='A.a'))
178 config = cl.load_config(argv, aliases=dict(a='A.a'))
179 self.assertEqual(config.A.a, u'épsîlön')
179 self.assertEqual(config.A.a, u'épsîlön')
180
180
181
181
182 class TestArgParseKVCL(TestKeyValueCL):
182 class TestArgParseKVCL(TestKeyValueCL):
183 klass = KVArgParseConfigLoader
183 klass = KVArgParseConfigLoader
184
184
185 def test_expanduser2(self):
185 def test_expanduser2(self):
186 cl = self.klass()
186 cl = self.klass()
187 argv = ['-a', '~/1/2/3', '--b', "'~/1/2/3'"]
187 argv = ['-a', '~/1/2/3', '--b', "'~/1/2/3'"]
188 with mute_warn():
188 with mute_warn():
189 config = cl.load_config(argv, aliases=dict(a='A.a', b='A.b'))
189 config = cl.load_config(argv, aliases=dict(a='A.a', b='A.b'))
190 self.assertEqual(config.A.a, os.path.expanduser('~/1/2/3'))
190 self.assertEqual(config.A.a, os.path.expanduser('~/1/2/3'))
191 self.assertEqual(config.A.b, '~/1/2/3')
191 self.assertEqual(config.A.b, '~/1/2/3')
192
192
193 def test_eval(self):
193 def test_eval(self):
194 cl = self.klass()
194 cl = self.klass()
195 argv = ['-c', 'a=5']
195 argv = ['-c', 'a=5']
196 with mute_warn():
196 with mute_warn():
197 config = cl.load_config(argv, aliases=dict(c='A.c'))
197 config = cl.load_config(argv, aliases=dict(c='A.c'))
198 self.assertEqual(config.A.c, u"a=5")
198 self.assertEqual(config.A.c, u"a=5")
199
199
200
200
201 class TestConfig(TestCase):
201 class TestConfig(TestCase):
202
202
203 def test_setget(self):
203 def test_setget(self):
204 c = Config()
204 c = Config()
205 c.a = 10
205 c.a = 10
206 self.assertEqual(c.a, 10)
206 self.assertEqual(c.a, 10)
207 self.assertEqual('b' in c, False)
207 self.assertEqual('b' in c, False)
208
208
209 def test_auto_section(self):
209 def test_auto_section(self):
210 c = Config()
210 c = Config()
211 self.assertEqual('A' in c, True)
211 self.assertEqual('A' in c, True)
212 self.assertEqual(c._has_section('A'), False)
212 self.assertEqual(c._has_section('A'), False)
213 A = c.A
213 A = c.A
214 A.foo = 'hi there'
214 A.foo = 'hi there'
215 self.assertEqual(c._has_section('A'), True)
215 self.assertEqual(c._has_section('A'), True)
216 self.assertEqual(c.A.foo, 'hi there')
216 self.assertEqual(c.A.foo, 'hi there')
217 del c.A
217 del c.A
218 self.assertEqual(len(c.A.keys()),0)
218 self.assertEqual(len(c.A.keys()),0)
219
219
220 def test_merge_doesnt_exist(self):
220 def test_merge_doesnt_exist(self):
221 c1 = Config()
221 c1 = Config()
222 c2 = Config()
222 c2 = Config()
223 c2.bar = 10
223 c2.bar = 10
224 c2.Foo.bar = 10
224 c2.Foo.bar = 10
225 c1._merge(c2)
225 c1._merge(c2)
226 self.assertEqual(c1.Foo.bar, 10)
226 self.assertEqual(c1.Foo.bar, 10)
227 self.assertEqual(c1.bar, 10)
227 self.assertEqual(c1.bar, 10)
228 c2.Bar.bar = 10
228 c2.Bar.bar = 10
229 c1._merge(c2)
229 c1._merge(c2)
230 self.assertEqual(c1.Bar.bar, 10)
230 self.assertEqual(c1.Bar.bar, 10)
231
231
232 def test_merge_exists(self):
232 def test_merge_exists(self):
233 c1 = Config()
233 c1 = Config()
234 c2 = Config()
234 c2 = Config()
235 c1.Foo.bar = 10
235 c1.Foo.bar = 10
236 c1.Foo.bam = 30
236 c1.Foo.bam = 30
237 c2.Foo.bar = 20
237 c2.Foo.bar = 20
238 c2.Foo.wow = 40
238 c2.Foo.wow = 40
239 c1._merge(c2)
239 c1._merge(c2)
240 self.assertEqual(c1.Foo.bam, 30)
240 self.assertEqual(c1.Foo.bam, 30)
241 self.assertEqual(c1.Foo.bar, 20)
241 self.assertEqual(c1.Foo.bar, 20)
242 self.assertEqual(c1.Foo.wow, 40)
242 self.assertEqual(c1.Foo.wow, 40)
243 c2.Foo.Bam.bam = 10
243 c2.Foo.Bam.bam = 10
244 c1._merge(c2)
244 c1._merge(c2)
245 self.assertEqual(c1.Foo.Bam.bam, 10)
245 self.assertEqual(c1.Foo.Bam.bam, 10)
246
246
247 def test_deepcopy(self):
247 def test_deepcopy(self):
248 c1 = Config()
248 c1 = Config()
249 c1.Foo.bar = 10
249 c1.Foo.bar = 10
250 c1.Foo.bam = 30
250 c1.Foo.bam = 30
251 c1.a = 'asdf'
251 c1.a = 'asdf'
252 c1.b = range(10)
252 c1.b = range(10)
253 import copy
253 import copy
254 c2 = copy.deepcopy(c1)
254 c2 = copy.deepcopy(c1)
255 self.assertEqual(c1, c2)
255 self.assertEqual(c1, c2)
256 self.assert_(c1 is not c2)
256 self.assertTrue(c1 is not c2)
257 self.assert_(c1.Foo is not c2.Foo)
257 self.assertTrue(c1.Foo is not c2.Foo)
258
258
259 def test_builtin(self):
259 def test_builtin(self):
260 c1 = Config()
260 c1 = Config()
261 exec 'foo = True' in c1
261 exec 'foo = True' in c1
262 self.assertEqual(c1.foo, True)
262 self.assertEqual(c1.foo, True)
263 self.assertRaises(ConfigError, setattr, c1, 'ValueError', 10)
263 self.assertRaises(ConfigError, setattr, c1, 'ValueError', 10)
@@ -1,27 +1,27 b''
1 """Tests for the notebook kernel and session manager."""
1 """Tests for the notebook kernel and session manager."""
2
2
3 from unittest import TestCase
3 from unittest import TestCase
4
4
5 from IPython.frontend.html.notebook.kernelmanager import MultiKernelManager
5 from IPython.frontend.html.notebook.kernelmanager import MultiKernelManager
6
6
7 class TestKernelManager(TestCase):
7 class TestKernelManager(TestCase):
8
8
9 def test_km_lifecycle(self):
9 def test_km_lifecycle(self):
10 km = MultiKernelManager()
10 km = MultiKernelManager()
11 kid = km.start_kernel()
11 kid = km.start_kernel()
12 self.assert_(kid in km)
12 self.assertTrue(kid in km)
13 self.assertEqual(len(km),1)
13 self.assertEqual(len(km),1)
14 km.kill_kernel(kid)
14 km.kill_kernel(kid)
15 self.assert_(not kid in km)
15 self.assertTrue(not kid in km)
16
16
17 kid = km.start_kernel()
17 kid = km.start_kernel()
18 self.assertEqual('127.0.0.1',km.get_kernel_ip(kid))
18 self.assertEqual('127.0.0.1',km.get_kernel_ip(kid))
19 port_dict = km.get_kernel_ports(kid)
19 port_dict = km.get_kernel_ports(kid)
20 self.assert_('stdin_port' in port_dict)
20 self.assertTrue('stdin_port' in port_dict)
21 self.assert_('iopub_port' in port_dict)
21 self.assertTrue('iopub_port' in port_dict)
22 self.assert_('shell_port' in port_dict)
22 self.assertTrue('shell_port' in port_dict)
23 self.assert_('hb_port' in port_dict)
23 self.assertTrue('hb_port' in port_dict)
24 km.get_kernel(kid)
24 km.get_kernel(kid)
25 km.kill_kernel(kid)
25 km.kill_kernel(kid)
26
26
27
27
@@ -1,85 +1,85 b''
1 # Standard library imports
1 # Standard library imports
2 import unittest
2 import unittest
3
3
4 # System library imports
4 # System library imports
5 from IPython.external.qt import QtCore, QtGui
5 from IPython.external.qt import QtCore, QtGui
6
6
7 # Local imports
7 # Local imports
8 from IPython.frontend.qt.console.kill_ring import KillRing, QtKillRing
8 from IPython.frontend.qt.console.kill_ring import KillRing, QtKillRing
9
9
10
10
11 class TestKillRing(unittest.TestCase):
11 class TestKillRing(unittest.TestCase):
12
12
13 @classmethod
13 @classmethod
14 def setUpClass(cls):
14 def setUpClass(cls):
15 """ Create the application for the test case.
15 """ Create the application for the test case.
16 """
16 """
17 cls._app = QtGui.QApplication.instance()
17 cls._app = QtGui.QApplication.instance()
18 if cls._app is None:
18 if cls._app is None:
19 cls._app = QtGui.QApplication([])
19 cls._app = QtGui.QApplication([])
20 cls._app.setQuitOnLastWindowClosed(False)
20 cls._app.setQuitOnLastWindowClosed(False)
21
21
22 @classmethod
22 @classmethod
23 def tearDownClass(cls):
23 def tearDownClass(cls):
24 """ Exit the application.
24 """ Exit the application.
25 """
25 """
26 QtGui.QApplication.quit()
26 QtGui.QApplication.quit()
27
27
28 def test_generic(self):
28 def test_generic(self):
29 """ Does the generic kill ring work?
29 """ Does the generic kill ring work?
30 """
30 """
31 ring = KillRing()
31 ring = KillRing()
32 self.assert_(ring.yank() is None)
32 self.assertTrue(ring.yank() is None)
33 self.assert_(ring.rotate() is None)
33 self.assertTrue(ring.rotate() is None)
34
34
35 ring.kill('foo')
35 ring.kill('foo')
36 self.assertEqual(ring.yank(), 'foo')
36 self.assertEqual(ring.yank(), 'foo')
37 self.assert_(ring.rotate() is None)
37 self.assertTrue(ring.rotate() is None)
38 self.assertEqual(ring.yank(), 'foo')
38 self.assertEqual(ring.yank(), 'foo')
39
39
40 ring.kill('bar')
40 ring.kill('bar')
41 self.assertEqual(ring.yank(), 'bar')
41 self.assertEqual(ring.yank(), 'bar')
42 self.assertEqual(ring.rotate(), 'foo')
42 self.assertEqual(ring.rotate(), 'foo')
43
43
44 ring.clear()
44 ring.clear()
45 self.assert_(ring.yank() is None)
45 self.assertTrue(ring.yank() is None)
46 self.assert_(ring.rotate() is None)
46 self.assertTrue(ring.rotate() is None)
47
47
48 def test_qt_basic(self):
48 def test_qt_basic(self):
49 """ Does the Qt kill ring work?
49 """ Does the Qt kill ring work?
50 """
50 """
51 text_edit = QtGui.QPlainTextEdit()
51 text_edit = QtGui.QPlainTextEdit()
52 ring = QtKillRing(text_edit)
52 ring = QtKillRing(text_edit)
53
53
54 ring.kill('foo')
54 ring.kill('foo')
55 ring.kill('bar')
55 ring.kill('bar')
56 ring.yank()
56 ring.yank()
57 ring.rotate()
57 ring.rotate()
58 ring.yank()
58 ring.yank()
59 self.assertEqual(text_edit.toPlainText(), 'foobar')
59 self.assertEqual(text_edit.toPlainText(), 'foobar')
60
60
61 text_edit.clear()
61 text_edit.clear()
62 ring.kill('baz')
62 ring.kill('baz')
63 ring.yank()
63 ring.yank()
64 ring.rotate()
64 ring.rotate()
65 ring.rotate()
65 ring.rotate()
66 ring.rotate()
66 ring.rotate()
67 self.assertEqual(text_edit.toPlainText(), 'foo')
67 self.assertEqual(text_edit.toPlainText(), 'foo')
68
68
69 def test_qt_cursor(self):
69 def test_qt_cursor(self):
70 """ Does the Qt kill ring maintain state with cursor movement?
70 """ Does the Qt kill ring maintain state with cursor movement?
71 """
71 """
72 text_edit = QtGui.QPlainTextEdit()
72 text_edit = QtGui.QPlainTextEdit()
73 ring = QtKillRing(text_edit)
73 ring = QtKillRing(text_edit)
74
74
75 ring.kill('foo')
75 ring.kill('foo')
76 ring.kill('bar')
76 ring.kill('bar')
77 ring.yank()
77 ring.yank()
78 text_edit.moveCursor(QtGui.QTextCursor.Left)
78 text_edit.moveCursor(QtGui.QTextCursor.Left)
79 ring.rotate()
79 ring.rotate()
80 self.assertEqual(text_edit.toPlainText(), 'bar')
80 self.assertEqual(text_edit.toPlainText(), 'bar')
81
81
82
82
83 if __name__ == '__main__':
83 if __name__ == '__main__':
84 import nose
84 import nose
85 nose.main()
85 nose.main()
@@ -1,180 +1,180 b''
1 """Test suite for the irunner module.
1 """Test suite for the irunner module.
2
2
3 Not the most elegant or fine-grained, but it does cover at least the bulk
3 Not the most elegant or fine-grained, but it does cover at least the bulk
4 functionality."""
4 functionality."""
5 from __future__ import print_function
5 from __future__ import print_function
6
6
7 # Global to make tests extra verbose and help debugging
7 # Global to make tests extra verbose and help debugging
8 VERBOSE = True
8 VERBOSE = True
9
9
10 # stdlib imports
10 # stdlib imports
11 import StringIO
11 import StringIO
12 import sys
12 import sys
13 import unittest
13 import unittest
14
14
15 # IPython imports
15 # IPython imports
16 from IPython.lib import irunner
16 from IPython.lib import irunner
17 from IPython.utils.py3compat import doctest_refactor_print
17 from IPython.utils.py3compat import doctest_refactor_print
18
18
19 # Testing code begins
19 # Testing code begins
20 class RunnerTestCase(unittest.TestCase):
20 class RunnerTestCase(unittest.TestCase):
21
21
22 def setUp(self):
22 def setUp(self):
23 self.out = StringIO.StringIO()
23 self.out = StringIO.StringIO()
24 #self.out = sys.stdout
24 #self.out = sys.stdout
25
25
26 def _test_runner(self,runner,source,output):
26 def _test_runner(self,runner,source,output):
27 """Test that a given runner's input/output match."""
27 """Test that a given runner's input/output match."""
28
28
29 runner.run_source(source)
29 runner.run_source(source)
30 out = self.out.getvalue()
30 out = self.out.getvalue()
31 #out = ''
31 #out = ''
32 # this output contains nasty \r\n lineends, and the initial ipython
32 # this output contains nasty \r\n lineends, and the initial ipython
33 # banner. clean it up for comparison, removing lines of whitespace
33 # banner. clean it up for comparison, removing lines of whitespace
34 output_l = [l for l in output.splitlines() if l and not l.isspace()]
34 output_l = [l for l in output.splitlines() if l and not l.isspace()]
35 out_l = [l for l in out.splitlines() if l and not l.isspace()]
35 out_l = [l for l in out.splitlines() if l and not l.isspace()]
36 mismatch = 0
36 mismatch = 0
37 if len(output_l) != len(out_l):
37 if len(output_l) != len(out_l):
38 message = ("Mismatch in number of lines\n\n"
38 message = ("Mismatch in number of lines\n\n"
39 "Expected:\n"
39 "Expected:\n"
40 "~~~~~~~~~\n"
40 "~~~~~~~~~\n"
41 "%s\n\n"
41 "%s\n\n"
42 "Got:\n"
42 "Got:\n"
43 "~~~~~~~~~\n"
43 "~~~~~~~~~\n"
44 "%s"
44 "%s"
45 ) % ("\n".join(output_l), "\n".join(out_l))
45 ) % ("\n".join(output_l), "\n".join(out_l))
46 self.fail(message)
46 self.fail(message)
47 for n in range(len(output_l)):
47 for n in range(len(output_l)):
48 # Do a line-by-line comparison
48 # Do a line-by-line comparison
49 ol1 = output_l[n].strip()
49 ol1 = output_l[n].strip()
50 ol2 = out_l[n].strip()
50 ol2 = out_l[n].strip()
51 if ol1 != ol2:
51 if ol1 != ol2:
52 mismatch += 1
52 mismatch += 1
53 if VERBOSE:
53 if VERBOSE:
54 print('<<< line %s does not match:' % n)
54 print('<<< line %s does not match:' % n)
55 print(repr(ol1))
55 print(repr(ol1))
56 print(repr(ol2))
56 print(repr(ol2))
57 print('>>>')
57 print('>>>')
58 self.assert_(mismatch==0,'Number of mismatched lines: %s' %
58 self.assertTrue(mismatch==0,'Number of mismatched lines: %s' %
59 mismatch)
59 mismatch)
60
60
61 def testIPython(self):
61 def testIPython(self):
62 """Test the IPython runner."""
62 """Test the IPython runner."""
63 source = doctest_refactor_print("""
63 source = doctest_refactor_print("""
64 print 'hello, this is python'
64 print 'hello, this is python'
65 # some more code
65 # some more code
66 x=1;y=2
66 x=1;y=2
67 x+y**2
67 x+y**2
68
68
69 # An example of autocall functionality
69 # An example of autocall functionality
70 from math import *
70 from math import *
71 autocall 1
71 autocall 1
72 cos pi
72 cos pi
73 autocall 0
73 autocall 0
74 cos pi
74 cos pi
75 cos(pi)
75 cos(pi)
76
76
77 for i in range(5):
77 for i in range(5):
78 print i
78 print i
79
79
80 print "that's all folks!"
80 print "that's all folks!"
81
81
82 exit
82 exit
83 """)
83 """)
84 output = doctest_refactor_print("""\
84 output = doctest_refactor_print("""\
85 In [1]: print 'hello, this is python'
85 In [1]: print 'hello, this is python'
86 hello, this is python
86 hello, this is python
87
87
88
88
89 # some more code
89 # some more code
90 In [2]: x=1;y=2
90 In [2]: x=1;y=2
91
91
92 In [3]: x+y**2
92 In [3]: x+y**2
93 Out[3]: 5
93 Out[3]: 5
94
94
95
95
96 # An example of autocall functionality
96 # An example of autocall functionality
97 In [4]: from math import *
97 In [4]: from math import *
98
98
99 In [5]: autocall 1
99 In [5]: autocall 1
100 Automatic calling is: Smart
100 Automatic calling is: Smart
101
101
102 In [6]: cos pi
102 In [6]: cos pi
103 ------> cos(pi)
103 ------> cos(pi)
104 Out[6]: -1.0
104 Out[6]: -1.0
105
105
106 In [7]: autocall 0
106 In [7]: autocall 0
107 Automatic calling is: OFF
107 Automatic calling is: OFF
108
108
109 In [8]: cos pi
109 In [8]: cos pi
110 File "<ipython-input-8-6bd7313dd9a9>", line 1
110 File "<ipython-input-8-6bd7313dd9a9>", line 1
111 cos pi
111 cos pi
112 ^
112 ^
113 SyntaxError: invalid syntax
113 SyntaxError: invalid syntax
114
114
115
115
116 In [9]: cos(pi)
116 In [9]: cos(pi)
117 Out[9]: -1.0
117 Out[9]: -1.0
118
118
119
119
120 In [10]: for i in range(5):
120 In [10]: for i in range(5):
121 ....: print i
121 ....: print i
122 ....:
122 ....:
123 0
123 0
124 1
124 1
125 2
125 2
126 3
126 3
127 4
127 4
128
128
129 In [11]: print "that's all folks!"
129 In [11]: print "that's all folks!"
130 that's all folks!
130 that's all folks!
131
131
132
132
133 In [12]: exit
133 In [12]: exit
134 """)
134 """)
135 runner = irunner.IPythonRunner(out=self.out)
135 runner = irunner.IPythonRunner(out=self.out)
136 self._test_runner(runner,source,output)
136 self._test_runner(runner,source,output)
137
137
138 def testPython(self):
138 def testPython(self):
139 """Test the Python runner."""
139 """Test the Python runner."""
140 runner = irunner.PythonRunner(out=self.out)
140 runner = irunner.PythonRunner(out=self.out)
141 source = doctest_refactor_print("""
141 source = doctest_refactor_print("""
142 print 'hello, this is python'
142 print 'hello, this is python'
143
143
144 # some more code
144 # some more code
145 x=1;y=2
145 x=1;y=2
146 x+y**2
146 x+y**2
147
147
148 from math import *
148 from math import *
149 cos(pi)
149 cos(pi)
150
150
151 for i in range(5):
151 for i in range(5):
152 print i
152 print i
153
153
154 print "that's all folks!"
154 print "that's all folks!"
155 """)
155 """)
156 output = doctest_refactor_print("""\
156 output = doctest_refactor_print("""\
157 >>> print 'hello, this is python'
157 >>> print 'hello, this is python'
158 hello, this is python
158 hello, this is python
159
159
160 # some more code
160 # some more code
161 >>> x=1;y=2
161 >>> x=1;y=2
162 >>> x+y**2
162 >>> x+y**2
163 5
163 5
164
164
165 >>> from math import *
165 >>> from math import *
166 >>> cos(pi)
166 >>> cos(pi)
167 -1.0
167 -1.0
168
168
169 >>> for i in range(5):
169 >>> for i in range(5):
170 ... print i
170 ... print i
171 ...
171 ...
172 0
172 0
173 1
173 1
174 2
174 2
175 3
175 3
176 4
176 4
177 >>> print "that's all folks!"
177 >>> print "that's all folks!"
178 that's all folks!
178 that's all folks!
179 """)
179 """)
180 self._test_runner(runner,source,output)
180 self._test_runner(runner,source,output)
@@ -1,119 +1,119 b''
1 """Test suite for pylab_import_all magic
1 """Test suite for pylab_import_all magic
2 Modified from the irunner module but using regex.
2 Modified from the irunner module but using regex.
3 """
3 """
4 from __future__ import print_function
4 from __future__ import print_function
5
5
6 # Global to make tests extra verbose and help debugging
6 # Global to make tests extra verbose and help debugging
7 VERBOSE = True
7 VERBOSE = True
8
8
9 # stdlib imports
9 # stdlib imports
10 import StringIO
10 import StringIO
11 import sys
11 import sys
12 import unittest
12 import unittest
13 import re
13 import re
14
14
15 # IPython imports
15 # IPython imports
16 from IPython.lib import irunner
16 from IPython.lib import irunner
17 from IPython.testing import decorators
17 from IPython.testing import decorators
18
18
19 def pylab_not_importable():
19 def pylab_not_importable():
20 """Test if importing pylab fails with RuntimeError (true when having no display)"""
20 """Test if importing pylab fails with RuntimeError (true when having no display)"""
21 try:
21 try:
22 import pylab
22 import pylab
23 return False
23 return False
24 except RuntimeError:
24 except RuntimeError:
25 return True
25 return True
26
26
27 # Testing code begins
27 # Testing code begins
28 class RunnerTestCase(unittest.TestCase):
28 class RunnerTestCase(unittest.TestCase):
29
29
30 def setUp(self):
30 def setUp(self):
31 self.out = StringIO.StringIO()
31 self.out = StringIO.StringIO()
32 #self.out = sys.stdout
32 #self.out = sys.stdout
33
33
34 def _test_runner(self,runner,source,output):
34 def _test_runner(self,runner,source,output):
35 """Test that a given runner's input/output match."""
35 """Test that a given runner's input/output match."""
36
36
37 runner.run_source(source)
37 runner.run_source(source)
38 out = self.out.getvalue()
38 out = self.out.getvalue()
39 #out = ''
39 #out = ''
40 # this output contains nasty \r\n lineends, and the initial ipython
40 # this output contains nasty \r\n lineends, and the initial ipython
41 # banner. clean it up for comparison, removing lines of whitespace
41 # banner. clean it up for comparison, removing lines of whitespace
42 output_l = [l for l in output.splitlines() if l and not l.isspace()]
42 output_l = [l for l in output.splitlines() if l and not l.isspace()]
43 out_l = [l for l in out.splitlines() if l and not l.isspace()]
43 out_l = [l for l in out.splitlines() if l and not l.isspace()]
44 mismatch = 0
44 mismatch = 0
45 if len(output_l) != len(out_l):
45 if len(output_l) != len(out_l):
46 message = ("Mismatch in number of lines\n\n"
46 message = ("Mismatch in number of lines\n\n"
47 "Expected:\n"
47 "Expected:\n"
48 "~~~~~~~~~\n"
48 "~~~~~~~~~\n"
49 "%s\n\n"
49 "%s\n\n"
50 "Got:\n"
50 "Got:\n"
51 "~~~~~~~~~\n"
51 "~~~~~~~~~\n"
52 "%s"
52 "%s"
53 ) % ("\n".join(output_l), "\n".join(out_l))
53 ) % ("\n".join(output_l), "\n".join(out_l))
54 self.fail(message)
54 self.fail(message)
55 for n in range(len(output_l)):
55 for n in range(len(output_l)):
56 # Do a line-by-line comparison
56 # Do a line-by-line comparison
57 ol1 = output_l[n].strip()
57 ol1 = output_l[n].strip()
58 ol2 = out_l[n].strip()
58 ol2 = out_l[n].strip()
59 if not re.match(ol1,ol2):
59 if not re.match(ol1,ol2):
60 mismatch += 1
60 mismatch += 1
61 if VERBOSE:
61 if VERBOSE:
62 print('<<< line %s does not match:' % n)
62 print('<<< line %s does not match:' % n)
63 print(repr(ol1))
63 print(repr(ol1))
64 print(repr(ol2))
64 print(repr(ol2))
65 print('>>>')
65 print('>>>')
66 self.assert_(mismatch==0,'Number of mismatched lines: %s' %
66 self.assertTrue(mismatch==0,'Number of mismatched lines: %s' %
67 mismatch)
67 mismatch)
68
68
69 @decorators.skipif_not_matplotlib
69 @decorators.skipif_not_matplotlib
70 @decorators.skipif(pylab_not_importable, "Likely a run without X.")
70 @decorators.skipif(pylab_not_importable, "Likely a run without X.")
71 def test_pylab_import_all_enabled(self):
71 def test_pylab_import_all_enabled(self):
72 "Verify that plot is available when pylab_import_all = True"
72 "Verify that plot is available when pylab_import_all = True"
73 source = """
73 source = """
74 from IPython.config.application import Application
74 from IPython.config.application import Application
75 app = Application.instance()
75 app = Application.instance()
76 app.pylab_import_all = True
76 app.pylab_import_all = True
77 pylab
77 pylab
78 ip=get_ipython()
78 ip=get_ipython()
79 'plot' in ip.user_ns
79 'plot' in ip.user_ns
80 """
80 """
81 output = """
81 output = """
82 In \[1\]: from IPython\.config\.application import Application
82 In \[1\]: from IPython\.config\.application import Application
83 In \[2\]: app = Application\.instance\(\)
83 In \[2\]: app = Application\.instance\(\)
84 In \[3\]: app\.pylab_import_all = True
84 In \[3\]: app\.pylab_import_all = True
85 In \[4\]: pylab
85 In \[4\]: pylab
86 ^Welcome to pylab, a matplotlib-based Python environment
86 ^Welcome to pylab, a matplotlib-based Python environment
87 For more information, type 'help\(pylab\)'\.
87 For more information, type 'help\(pylab\)'\.
88 In \[5\]: ip=get_ipython\(\)
88 In \[5\]: ip=get_ipython\(\)
89 In \[6\]: \'plot\' in ip\.user_ns
89 In \[6\]: \'plot\' in ip\.user_ns
90 Out\[6\]: True
90 Out\[6\]: True
91 """
91 """
92 runner = irunner.IPythonRunner(out=self.out)
92 runner = irunner.IPythonRunner(out=self.out)
93 self._test_runner(runner,source,output)
93 self._test_runner(runner,source,output)
94
94
95 @decorators.skipif_not_matplotlib
95 @decorators.skipif_not_matplotlib
96 @decorators.skipif(pylab_not_importable, "Likely a run without X.")
96 @decorators.skipif(pylab_not_importable, "Likely a run without X.")
97 def test_pylab_import_all_disabled(self):
97 def test_pylab_import_all_disabled(self):
98 "Verify that plot is not available when pylab_import_all = False"
98 "Verify that plot is not available when pylab_import_all = False"
99 source = """
99 source = """
100 from IPython.config.application import Application
100 from IPython.config.application import Application
101 app = Application.instance()
101 app = Application.instance()
102 app.pylab_import_all = False
102 app.pylab_import_all = False
103 pylab
103 pylab
104 ip=get_ipython()
104 ip=get_ipython()
105 'plot' in ip.user_ns
105 'plot' in ip.user_ns
106 """
106 """
107 output = """
107 output = """
108 In \[1\]: from IPython\.config\.application import Application
108 In \[1\]: from IPython\.config\.application import Application
109 In \[2\]: app = Application\.instance\(\)
109 In \[2\]: app = Application\.instance\(\)
110 In \[3\]: app\.pylab_import_all = False
110 In \[3\]: app\.pylab_import_all = False
111 In \[4\]: pylab
111 In \[4\]: pylab
112 ^Welcome to pylab, a matplotlib-based Python environment
112 ^Welcome to pylab, a matplotlib-based Python environment
113 For more information, type 'help\(pylab\)'\.
113 For more information, type 'help\(pylab\)'\.
114 In \[5\]: ip=get_ipython\(\)
114 In \[5\]: ip=get_ipython\(\)
115 In \[6\]: \'plot\' in ip\.user_ns
115 In \[6\]: \'plot\' in ip\.user_ns
116 Out\[6\]: False
116 Out\[6\]: False
117 """
117 """
118 runner = irunner.IPythonRunner(out=self.out)
118 runner = irunner.IPythonRunner(out=self.out)
119 self._test_runner(runner,source,output)
119 self._test_runner(runner,source,output)
@@ -1,455 +1,455 b''
1 """Tests for parallel client.py
1 """Tests for parallel client.py
2
2
3 Authors:
3 Authors:
4
4
5 * Min RK
5 * Min RK
6 """
6 """
7
7
8 #-------------------------------------------------------------------------------
8 #-------------------------------------------------------------------------------
9 # Copyright (C) 2011 The IPython Development Team
9 # Copyright (C) 2011 The IPython Development Team
10 #
10 #
11 # Distributed under the terms of the BSD License. The full license is in
11 # Distributed under the terms of the BSD License. The full license is in
12 # the file COPYING, distributed as part of this software.
12 # the file COPYING, distributed as part of this software.
13 #-------------------------------------------------------------------------------
13 #-------------------------------------------------------------------------------
14
14
15 #-------------------------------------------------------------------------------
15 #-------------------------------------------------------------------------------
16 # Imports
16 # Imports
17 #-------------------------------------------------------------------------------
17 #-------------------------------------------------------------------------------
18
18
19 from __future__ import division
19 from __future__ import division
20
20
21 import time
21 import time
22 from datetime import datetime
22 from datetime import datetime
23 from tempfile import mktemp
23 from tempfile import mktemp
24
24
25 import zmq
25 import zmq
26
26
27 from IPython import parallel
27 from IPython import parallel
28 from IPython.parallel.client import client as clientmod
28 from IPython.parallel.client import client as clientmod
29 from IPython.parallel import error
29 from IPython.parallel import error
30 from IPython.parallel import AsyncResult, AsyncHubResult
30 from IPython.parallel import AsyncResult, AsyncHubResult
31 from IPython.parallel import LoadBalancedView, DirectView
31 from IPython.parallel import LoadBalancedView, DirectView
32
32
33 from clienttest import ClusterTestCase, segfault, wait, add_engines
33 from clienttest import ClusterTestCase, segfault, wait, add_engines
34
34
35 def setup():
35 def setup():
36 add_engines(4, total=True)
36 add_engines(4, total=True)
37
37
38 class TestClient(ClusterTestCase):
38 class TestClient(ClusterTestCase):
39
39
40 def test_ids(self):
40 def test_ids(self):
41 n = len(self.client.ids)
41 n = len(self.client.ids)
42 self.add_engines(2)
42 self.add_engines(2)
43 self.assertEqual(len(self.client.ids), n+2)
43 self.assertEqual(len(self.client.ids), n+2)
44
44
45 def test_view_indexing(self):
45 def test_view_indexing(self):
46 """test index access for views"""
46 """test index access for views"""
47 self.minimum_engines(4)
47 self.minimum_engines(4)
48 targets = self.client._build_targets('all')[-1]
48 targets = self.client._build_targets('all')[-1]
49 v = self.client[:]
49 v = self.client[:]
50 self.assertEqual(v.targets, targets)
50 self.assertEqual(v.targets, targets)
51 t = self.client.ids[2]
51 t = self.client.ids[2]
52 v = self.client[t]
52 v = self.client[t]
53 self.assert_(isinstance(v, DirectView))
53 self.assertTrue(isinstance(v, DirectView))
54 self.assertEqual(v.targets, t)
54 self.assertEqual(v.targets, t)
55 t = self.client.ids[2:4]
55 t = self.client.ids[2:4]
56 v = self.client[t]
56 v = self.client[t]
57 self.assert_(isinstance(v, DirectView))
57 self.assertTrue(isinstance(v, DirectView))
58 self.assertEqual(v.targets, t)
58 self.assertEqual(v.targets, t)
59 v = self.client[::2]
59 v = self.client[::2]
60 self.assert_(isinstance(v, DirectView))
60 self.assertTrue(isinstance(v, DirectView))
61 self.assertEqual(v.targets, targets[::2])
61 self.assertEqual(v.targets, targets[::2])
62 v = self.client[1::3]
62 v = self.client[1::3]
63 self.assert_(isinstance(v, DirectView))
63 self.assertTrue(isinstance(v, DirectView))
64 self.assertEqual(v.targets, targets[1::3])
64 self.assertEqual(v.targets, targets[1::3])
65 v = self.client[:-3]
65 v = self.client[:-3]
66 self.assert_(isinstance(v, DirectView))
66 self.assertTrue(isinstance(v, DirectView))
67 self.assertEqual(v.targets, targets[:-3])
67 self.assertEqual(v.targets, targets[:-3])
68 v = self.client[-1]
68 v = self.client[-1]
69 self.assert_(isinstance(v, DirectView))
69 self.assertTrue(isinstance(v, DirectView))
70 self.assertEqual(v.targets, targets[-1])
70 self.assertEqual(v.targets, targets[-1])
71 self.assertRaises(TypeError, lambda : self.client[None])
71 self.assertRaises(TypeError, lambda : self.client[None])
72
72
73 def test_lbview_targets(self):
73 def test_lbview_targets(self):
74 """test load_balanced_view targets"""
74 """test load_balanced_view targets"""
75 v = self.client.load_balanced_view()
75 v = self.client.load_balanced_view()
76 self.assertEqual(v.targets, None)
76 self.assertEqual(v.targets, None)
77 v = self.client.load_balanced_view(-1)
77 v = self.client.load_balanced_view(-1)
78 self.assertEqual(v.targets, [self.client.ids[-1]])
78 self.assertEqual(v.targets, [self.client.ids[-1]])
79 v = self.client.load_balanced_view('all')
79 v = self.client.load_balanced_view('all')
80 self.assertEqual(v.targets, None)
80 self.assertEqual(v.targets, None)
81
81
82 def test_dview_targets(self):
82 def test_dview_targets(self):
83 """test direct_view targets"""
83 """test direct_view targets"""
84 v = self.client.direct_view()
84 v = self.client.direct_view()
85 self.assertEqual(v.targets, 'all')
85 self.assertEqual(v.targets, 'all')
86 v = self.client.direct_view('all')
86 v = self.client.direct_view('all')
87 self.assertEqual(v.targets, 'all')
87 self.assertEqual(v.targets, 'all')
88 v = self.client.direct_view(-1)
88 v = self.client.direct_view(-1)
89 self.assertEqual(v.targets, self.client.ids[-1])
89 self.assertEqual(v.targets, self.client.ids[-1])
90
90
91 def test_lazy_all_targets(self):
91 def test_lazy_all_targets(self):
92 """test lazy evaluation of rc.direct_view('all')"""
92 """test lazy evaluation of rc.direct_view('all')"""
93 v = self.client.direct_view()
93 v = self.client.direct_view()
94 self.assertEqual(v.targets, 'all')
94 self.assertEqual(v.targets, 'all')
95
95
96 def double(x):
96 def double(x):
97 return x*2
97 return x*2
98 seq = range(100)
98 seq = range(100)
99 ref = [ double(x) for x in seq ]
99 ref = [ double(x) for x in seq ]
100
100
101 # add some engines, which should be used
101 # add some engines, which should be used
102 self.add_engines(1)
102 self.add_engines(1)
103 n1 = len(self.client.ids)
103 n1 = len(self.client.ids)
104
104
105 # simple apply
105 # simple apply
106 r = v.apply_sync(lambda : 1)
106 r = v.apply_sync(lambda : 1)
107 self.assertEqual(r, [1] * n1)
107 self.assertEqual(r, [1] * n1)
108
108
109 # map goes through remotefunction
109 # map goes through remotefunction
110 r = v.map_sync(double, seq)
110 r = v.map_sync(double, seq)
111 self.assertEqual(r, ref)
111 self.assertEqual(r, ref)
112
112
113 # add a couple more engines, and try again
113 # add a couple more engines, and try again
114 self.add_engines(2)
114 self.add_engines(2)
115 n2 = len(self.client.ids)
115 n2 = len(self.client.ids)
116 self.assertNotEquals(n2, n1)
116 self.assertNotEquals(n2, n1)
117
117
118 # apply
118 # apply
119 r = v.apply_sync(lambda : 1)
119 r = v.apply_sync(lambda : 1)
120 self.assertEqual(r, [1] * n2)
120 self.assertEqual(r, [1] * n2)
121
121
122 # map
122 # map
123 r = v.map_sync(double, seq)
123 r = v.map_sync(double, seq)
124 self.assertEqual(r, ref)
124 self.assertEqual(r, ref)
125
125
126 def test_targets(self):
126 def test_targets(self):
127 """test various valid targets arguments"""
127 """test various valid targets arguments"""
128 build = self.client._build_targets
128 build = self.client._build_targets
129 ids = self.client.ids
129 ids = self.client.ids
130 idents,targets = build(None)
130 idents,targets = build(None)
131 self.assertEqual(ids, targets)
131 self.assertEqual(ids, targets)
132
132
133 def test_clear(self):
133 def test_clear(self):
134 """test clear behavior"""
134 """test clear behavior"""
135 self.minimum_engines(2)
135 self.minimum_engines(2)
136 v = self.client[:]
136 v = self.client[:]
137 v.block=True
137 v.block=True
138 v.push(dict(a=5))
138 v.push(dict(a=5))
139 v.pull('a')
139 v.pull('a')
140 id0 = self.client.ids[-1]
140 id0 = self.client.ids[-1]
141 self.client.clear(targets=id0, block=True)
141 self.client.clear(targets=id0, block=True)
142 a = self.client[:-1].get('a')
142 a = self.client[:-1].get('a')
143 self.assertRaisesRemote(NameError, self.client[id0].get, 'a')
143 self.assertRaisesRemote(NameError, self.client[id0].get, 'a')
144 self.client.clear(block=True)
144 self.client.clear(block=True)
145 for i in self.client.ids:
145 for i in self.client.ids:
146 self.assertRaisesRemote(NameError, self.client[i].get, 'a')
146 self.assertRaisesRemote(NameError, self.client[i].get, 'a')
147
147
148 def test_get_result(self):
148 def test_get_result(self):
149 """test getting results from the Hub."""
149 """test getting results from the Hub."""
150 c = clientmod.Client(profile='iptest')
150 c = clientmod.Client(profile='iptest')
151 t = c.ids[-1]
151 t = c.ids[-1]
152 ar = c[t].apply_async(wait, 1)
152 ar = c[t].apply_async(wait, 1)
153 # give the monitor time to notice the message
153 # give the monitor time to notice the message
154 time.sleep(.25)
154 time.sleep(.25)
155 ahr = self.client.get_result(ar.msg_ids)
155 ahr = self.client.get_result(ar.msg_ids)
156 self.assertTrue(isinstance(ahr, AsyncHubResult))
156 self.assertTrue(isinstance(ahr, AsyncHubResult))
157 self.assertEqual(ahr.get(), ar.get())
157 self.assertEqual(ahr.get(), ar.get())
158 ar2 = self.client.get_result(ar.msg_ids)
158 ar2 = self.client.get_result(ar.msg_ids)
159 self.assertFalse(isinstance(ar2, AsyncHubResult))
159 self.assertFalse(isinstance(ar2, AsyncHubResult))
160 c.close()
160 c.close()
161
161
162 def test_get_execute_result(self):
162 def test_get_execute_result(self):
163 """test getting execute results from the Hub."""
163 """test getting execute results from the Hub."""
164 c = clientmod.Client(profile='iptest')
164 c = clientmod.Client(profile='iptest')
165 t = c.ids[-1]
165 t = c.ids[-1]
166 cell = '\n'.join([
166 cell = '\n'.join([
167 'import time',
167 'import time',
168 'time.sleep(0.25)',
168 'time.sleep(0.25)',
169 '5'
169 '5'
170 ])
170 ])
171 ar = c[t].execute("import time; time.sleep(1)", silent=False)
171 ar = c[t].execute("import time; time.sleep(1)", silent=False)
172 # give the monitor time to notice the message
172 # give the monitor time to notice the message
173 time.sleep(.25)
173 time.sleep(.25)
174 ahr = self.client.get_result(ar.msg_ids)
174 ahr = self.client.get_result(ar.msg_ids)
175 self.assertTrue(isinstance(ahr, AsyncHubResult))
175 self.assertTrue(isinstance(ahr, AsyncHubResult))
176 self.assertEqual(ahr.get().pyout, ar.get().pyout)
176 self.assertEqual(ahr.get().pyout, ar.get().pyout)
177 ar2 = self.client.get_result(ar.msg_ids)
177 ar2 = self.client.get_result(ar.msg_ids)
178 self.assertFalse(isinstance(ar2, AsyncHubResult))
178 self.assertFalse(isinstance(ar2, AsyncHubResult))
179 c.close()
179 c.close()
180
180
181 def test_ids_list(self):
181 def test_ids_list(self):
182 """test client.ids"""
182 """test client.ids"""
183 ids = self.client.ids
183 ids = self.client.ids
184 self.assertEqual(ids, self.client._ids)
184 self.assertEqual(ids, self.client._ids)
185 self.assertFalse(ids is self.client._ids)
185 self.assertFalse(ids is self.client._ids)
186 ids.remove(ids[-1])
186 ids.remove(ids[-1])
187 self.assertNotEquals(ids, self.client._ids)
187 self.assertNotEquals(ids, self.client._ids)
188
188
189 def test_queue_status(self):
189 def test_queue_status(self):
190 ids = self.client.ids
190 ids = self.client.ids
191 id0 = ids[0]
191 id0 = ids[0]
192 qs = self.client.queue_status(targets=id0)
192 qs = self.client.queue_status(targets=id0)
193 self.assertTrue(isinstance(qs, dict))
193 self.assertTrue(isinstance(qs, dict))
194 self.assertEqual(sorted(qs.keys()), ['completed', 'queue', 'tasks'])
194 self.assertEqual(sorted(qs.keys()), ['completed', 'queue', 'tasks'])
195 allqs = self.client.queue_status()
195 allqs = self.client.queue_status()
196 self.assertTrue(isinstance(allqs, dict))
196 self.assertTrue(isinstance(allqs, dict))
197 intkeys = list(allqs.keys())
197 intkeys = list(allqs.keys())
198 intkeys.remove('unassigned')
198 intkeys.remove('unassigned')
199 self.assertEqual(sorted(intkeys), sorted(self.client.ids))
199 self.assertEqual(sorted(intkeys), sorted(self.client.ids))
200 unassigned = allqs.pop('unassigned')
200 unassigned = allqs.pop('unassigned')
201 for eid,qs in allqs.items():
201 for eid,qs in allqs.items():
202 self.assertTrue(isinstance(qs, dict))
202 self.assertTrue(isinstance(qs, dict))
203 self.assertEqual(sorted(qs.keys()), ['completed', 'queue', 'tasks'])
203 self.assertEqual(sorted(qs.keys()), ['completed', 'queue', 'tasks'])
204
204
205 def test_shutdown(self):
205 def test_shutdown(self):
206 ids = self.client.ids
206 ids = self.client.ids
207 id0 = ids[0]
207 id0 = ids[0]
208 self.client.shutdown(id0, block=True)
208 self.client.shutdown(id0, block=True)
209 while id0 in self.client.ids:
209 while id0 in self.client.ids:
210 time.sleep(0.1)
210 time.sleep(0.1)
211 self.client.spin()
211 self.client.spin()
212
212
213 self.assertRaises(IndexError, lambda : self.client[id0])
213 self.assertRaises(IndexError, lambda : self.client[id0])
214
214
215 def test_result_status(self):
215 def test_result_status(self):
216 pass
216 pass
217 # to be written
217 # to be written
218
218
219 def test_db_query_dt(self):
219 def test_db_query_dt(self):
220 """test db query by date"""
220 """test db query by date"""
221 hist = self.client.hub_history()
221 hist = self.client.hub_history()
222 middle = self.client.db_query({'msg_id' : hist[len(hist)//2]})[0]
222 middle = self.client.db_query({'msg_id' : hist[len(hist)//2]})[0]
223 tic = middle['submitted']
223 tic = middle['submitted']
224 before = self.client.db_query({'submitted' : {'$lt' : tic}})
224 before = self.client.db_query({'submitted' : {'$lt' : tic}})
225 after = self.client.db_query({'submitted' : {'$gte' : tic}})
225 after = self.client.db_query({'submitted' : {'$gte' : tic}})
226 self.assertEqual(len(before)+len(after),len(hist))
226 self.assertEqual(len(before)+len(after),len(hist))
227 for b in before:
227 for b in before:
228 self.assertTrue(b['submitted'] < tic)
228 self.assertTrue(b['submitted'] < tic)
229 for a in after:
229 for a in after:
230 self.assertTrue(a['submitted'] >= tic)
230 self.assertTrue(a['submitted'] >= tic)
231 same = self.client.db_query({'submitted' : tic})
231 same = self.client.db_query({'submitted' : tic})
232 for s in same:
232 for s in same:
233 self.assertTrue(s['submitted'] == tic)
233 self.assertTrue(s['submitted'] == tic)
234
234
235 def test_db_query_keys(self):
235 def test_db_query_keys(self):
236 """test extracting subset of record keys"""
236 """test extracting subset of record keys"""
237 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
237 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
238 for rec in found:
238 for rec in found:
239 self.assertEqual(set(rec.keys()), set(['msg_id', 'submitted', 'completed']))
239 self.assertEqual(set(rec.keys()), set(['msg_id', 'submitted', 'completed']))
240
240
241 def test_db_query_default_keys(self):
241 def test_db_query_default_keys(self):
242 """default db_query excludes buffers"""
242 """default db_query excludes buffers"""
243 found = self.client.db_query({'msg_id': {'$ne' : ''}})
243 found = self.client.db_query({'msg_id': {'$ne' : ''}})
244 for rec in found:
244 for rec in found:
245 keys = set(rec.keys())
245 keys = set(rec.keys())
246 self.assertFalse('buffers' in keys, "'buffers' should not be in: %s" % keys)
246 self.assertFalse('buffers' in keys, "'buffers' should not be in: %s" % keys)
247 self.assertFalse('result_buffers' in keys, "'result_buffers' should not be in: %s" % keys)
247 self.assertFalse('result_buffers' in keys, "'result_buffers' should not be in: %s" % keys)
248
248
249 def test_db_query_msg_id(self):
249 def test_db_query_msg_id(self):
250 """ensure msg_id is always in db queries"""
250 """ensure msg_id is always in db queries"""
251 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
251 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
252 for rec in found:
252 for rec in found:
253 self.assertTrue('msg_id' in rec.keys())
253 self.assertTrue('msg_id' in rec.keys())
254 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted'])
254 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted'])
255 for rec in found:
255 for rec in found:
256 self.assertTrue('msg_id' in rec.keys())
256 self.assertTrue('msg_id' in rec.keys())
257 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['msg_id'])
257 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['msg_id'])
258 for rec in found:
258 for rec in found:
259 self.assertTrue('msg_id' in rec.keys())
259 self.assertTrue('msg_id' in rec.keys())
260
260
261 def test_db_query_get_result(self):
261 def test_db_query_get_result(self):
262 """pop in db_query shouldn't pop from result itself"""
262 """pop in db_query shouldn't pop from result itself"""
263 self.client[:].apply_sync(lambda : 1)
263 self.client[:].apply_sync(lambda : 1)
264 found = self.client.db_query({'msg_id': {'$ne' : ''}})
264 found = self.client.db_query({'msg_id': {'$ne' : ''}})
265 rc2 = clientmod.Client(profile='iptest')
265 rc2 = clientmod.Client(profile='iptest')
266 # If this bug is not fixed, this call will hang:
266 # If this bug is not fixed, this call will hang:
267 ar = rc2.get_result(self.client.history[-1])
267 ar = rc2.get_result(self.client.history[-1])
268 ar.wait(2)
268 ar.wait(2)
269 self.assertTrue(ar.ready())
269 self.assertTrue(ar.ready())
270 ar.get()
270 ar.get()
271 rc2.close()
271 rc2.close()
272
272
273 def test_db_query_in(self):
273 def test_db_query_in(self):
274 """test db query with '$in','$nin' operators"""
274 """test db query with '$in','$nin' operators"""
275 hist = self.client.hub_history()
275 hist = self.client.hub_history()
276 even = hist[::2]
276 even = hist[::2]
277 odd = hist[1::2]
277 odd = hist[1::2]
278 recs = self.client.db_query({ 'msg_id' : {'$in' : even}})
278 recs = self.client.db_query({ 'msg_id' : {'$in' : even}})
279 found = [ r['msg_id'] for r in recs ]
279 found = [ r['msg_id'] for r in recs ]
280 self.assertEqual(set(even), set(found))
280 self.assertEqual(set(even), set(found))
281 recs = self.client.db_query({ 'msg_id' : {'$nin' : even}})
281 recs = self.client.db_query({ 'msg_id' : {'$nin' : even}})
282 found = [ r['msg_id'] for r in recs ]
282 found = [ r['msg_id'] for r in recs ]
283 self.assertEqual(set(odd), set(found))
283 self.assertEqual(set(odd), set(found))
284
284
285 def test_hub_history(self):
285 def test_hub_history(self):
286 hist = self.client.hub_history()
286 hist = self.client.hub_history()
287 recs = self.client.db_query({ 'msg_id' : {"$ne":''}})
287 recs = self.client.db_query({ 'msg_id' : {"$ne":''}})
288 recdict = {}
288 recdict = {}
289 for rec in recs:
289 for rec in recs:
290 recdict[rec['msg_id']] = rec
290 recdict[rec['msg_id']] = rec
291
291
292 latest = datetime(1984,1,1)
292 latest = datetime(1984,1,1)
293 for msg_id in hist:
293 for msg_id in hist:
294 rec = recdict[msg_id]
294 rec = recdict[msg_id]
295 newt = rec['submitted']
295 newt = rec['submitted']
296 self.assertTrue(newt >= latest)
296 self.assertTrue(newt >= latest)
297 latest = newt
297 latest = newt
298 ar = self.client[-1].apply_async(lambda : 1)
298 ar = self.client[-1].apply_async(lambda : 1)
299 ar.get()
299 ar.get()
300 time.sleep(0.25)
300 time.sleep(0.25)
301 self.assertEqual(self.client.hub_history()[-1:],ar.msg_ids)
301 self.assertEqual(self.client.hub_history()[-1:],ar.msg_ids)
302
302
303 def _wait_for_idle(self):
303 def _wait_for_idle(self):
304 """wait for an engine to become idle, according to the Hub"""
304 """wait for an engine to become idle, according to the Hub"""
305 rc = self.client
305 rc = self.client
306
306
307 # timeout 5s, polling every 100ms
307 # timeout 5s, polling every 100ms
308 qs = rc.queue_status()
308 qs = rc.queue_status()
309 for i in range(50):
309 for i in range(50):
310 if qs['unassigned'] or any(qs[eid]['tasks'] for eid in rc.ids):
310 if qs['unassigned'] or any(qs[eid]['tasks'] for eid in rc.ids):
311 time.sleep(0.1)
311 time.sleep(0.1)
312 qs = rc.queue_status()
312 qs = rc.queue_status()
313 else:
313 else:
314 break
314 break
315
315
316 # ensure Hub up to date:
316 # ensure Hub up to date:
317 self.assertEqual(qs['unassigned'], 0)
317 self.assertEqual(qs['unassigned'], 0)
318 for eid in rc.ids:
318 for eid in rc.ids:
319 self.assertEqual(qs[eid]['tasks'], 0)
319 self.assertEqual(qs[eid]['tasks'], 0)
320
320
321
321
322 def test_resubmit(self):
322 def test_resubmit(self):
323 def f():
323 def f():
324 import random
324 import random
325 return random.random()
325 return random.random()
326 v = self.client.load_balanced_view()
326 v = self.client.load_balanced_view()
327 ar = v.apply_async(f)
327 ar = v.apply_async(f)
328 r1 = ar.get(1)
328 r1 = ar.get(1)
329 # give the Hub a chance to notice:
329 # give the Hub a chance to notice:
330 self._wait_for_idle()
330 self._wait_for_idle()
331 ahr = self.client.resubmit(ar.msg_ids)
331 ahr = self.client.resubmit(ar.msg_ids)
332 r2 = ahr.get(1)
332 r2 = ahr.get(1)
333 self.assertFalse(r1 == r2)
333 self.assertFalse(r1 == r2)
334
334
335 def test_resubmit_chain(self):
335 def test_resubmit_chain(self):
336 """resubmit resubmitted tasks"""
336 """resubmit resubmitted tasks"""
337 v = self.client.load_balanced_view()
337 v = self.client.load_balanced_view()
338 ar = v.apply_async(lambda x: x, 'x'*1024)
338 ar = v.apply_async(lambda x: x, 'x'*1024)
339 ar.get()
339 ar.get()
340 self._wait_for_idle()
340 self._wait_for_idle()
341 ars = [ar]
341 ars = [ar]
342
342
343 for i in range(10):
343 for i in range(10):
344 ar = ars[-1]
344 ar = ars[-1]
345 ar2 = self.client.resubmit(ar.msg_ids)
345 ar2 = self.client.resubmit(ar.msg_ids)
346
346
347 [ ar.get() for ar in ars ]
347 [ ar.get() for ar in ars ]
348
348
349 def test_resubmit_header(self):
349 def test_resubmit_header(self):
350 """resubmit shouldn't clobber the whole header"""
350 """resubmit shouldn't clobber the whole header"""
351 def f():
351 def f():
352 import random
352 import random
353 return random.random()
353 return random.random()
354 v = self.client.load_balanced_view()
354 v = self.client.load_balanced_view()
355 v.retries = 1
355 v.retries = 1
356 ar = v.apply_async(f)
356 ar = v.apply_async(f)
357 r1 = ar.get(1)
357 r1 = ar.get(1)
358 # give the Hub a chance to notice:
358 # give the Hub a chance to notice:
359 self._wait_for_idle()
359 self._wait_for_idle()
360 ahr = self.client.resubmit(ar.msg_ids)
360 ahr = self.client.resubmit(ar.msg_ids)
361 ahr.get(1)
361 ahr.get(1)
362 time.sleep(0.5)
362 time.sleep(0.5)
363 records = self.client.db_query({'msg_id': {'$in': ar.msg_ids + ahr.msg_ids}}, keys='header')
363 records = self.client.db_query({'msg_id': {'$in': ar.msg_ids + ahr.msg_ids}}, keys='header')
364 h1,h2 = [ r['header'] for r in records ]
364 h1,h2 = [ r['header'] for r in records ]
365 for key in set(h1.keys()).union(set(h2.keys())):
365 for key in set(h1.keys()).union(set(h2.keys())):
366 if key in ('msg_id', 'date'):
366 if key in ('msg_id', 'date'):
367 self.assertNotEquals(h1[key], h2[key])
367 self.assertNotEquals(h1[key], h2[key])
368 else:
368 else:
369 self.assertEqual(h1[key], h2[key])
369 self.assertEqual(h1[key], h2[key])
370
370
371 def test_resubmit_aborted(self):
371 def test_resubmit_aborted(self):
372 def f():
372 def f():
373 import random
373 import random
374 return random.random()
374 return random.random()
375 v = self.client.load_balanced_view()
375 v = self.client.load_balanced_view()
376 # restrict to one engine, so we can put a sleep
376 # restrict to one engine, so we can put a sleep
377 # ahead of the task, so it will get aborted
377 # ahead of the task, so it will get aborted
378 eid = self.client.ids[-1]
378 eid = self.client.ids[-1]
379 v.targets = [eid]
379 v.targets = [eid]
380 sleep = v.apply_async(time.sleep, 0.5)
380 sleep = v.apply_async(time.sleep, 0.5)
381 ar = v.apply_async(f)
381 ar = v.apply_async(f)
382 ar.abort()
382 ar.abort()
383 self.assertRaises(error.TaskAborted, ar.get)
383 self.assertRaises(error.TaskAborted, ar.get)
384 # Give the Hub a chance to get up to date:
384 # Give the Hub a chance to get up to date:
385 self._wait_for_idle()
385 self._wait_for_idle()
386 ahr = self.client.resubmit(ar.msg_ids)
386 ahr = self.client.resubmit(ar.msg_ids)
387 r2 = ahr.get(1)
387 r2 = ahr.get(1)
388
388
389 def test_resubmit_inflight(self):
389 def test_resubmit_inflight(self):
390 """resubmit of inflight task"""
390 """resubmit of inflight task"""
391 v = self.client.load_balanced_view()
391 v = self.client.load_balanced_view()
392 ar = v.apply_async(time.sleep,1)
392 ar = v.apply_async(time.sleep,1)
393 # give the message a chance to arrive
393 # give the message a chance to arrive
394 time.sleep(0.2)
394 time.sleep(0.2)
395 ahr = self.client.resubmit(ar.msg_ids)
395 ahr = self.client.resubmit(ar.msg_ids)
396 ar.get(2)
396 ar.get(2)
397 ahr.get(2)
397 ahr.get(2)
398
398
399 def test_resubmit_badkey(self):
399 def test_resubmit_badkey(self):
400 """ensure KeyError on resubmit of nonexistant task"""
400 """ensure KeyError on resubmit of nonexistant task"""
401 self.assertRaisesRemote(KeyError, self.client.resubmit, ['invalid'])
401 self.assertRaisesRemote(KeyError, self.client.resubmit, ['invalid'])
402
402
403 def test_purge_results(self):
403 def test_purge_results(self):
404 # ensure there are some tasks
404 # ensure there are some tasks
405 for i in range(5):
405 for i in range(5):
406 self.client[:].apply_sync(lambda : 1)
406 self.client[:].apply_sync(lambda : 1)
407 # Wait for the Hub to realise the result is done:
407 # Wait for the Hub to realise the result is done:
408 # This prevents a race condition, where we
408 # This prevents a race condition, where we
409 # might purge a result the Hub still thinks is pending.
409 # might purge a result the Hub still thinks is pending.
410 time.sleep(0.1)
410 time.sleep(0.1)
411 rc2 = clientmod.Client(profile='iptest')
411 rc2 = clientmod.Client(profile='iptest')
412 hist = self.client.hub_history()
412 hist = self.client.hub_history()
413 ahr = rc2.get_result([hist[-1]])
413 ahr = rc2.get_result([hist[-1]])
414 ahr.wait(10)
414 ahr.wait(10)
415 self.client.purge_results(hist[-1])
415 self.client.purge_results(hist[-1])
416 newhist = self.client.hub_history()
416 newhist = self.client.hub_history()
417 self.assertEqual(len(newhist)+1,len(hist))
417 self.assertEqual(len(newhist)+1,len(hist))
418 rc2.spin()
418 rc2.spin()
419 rc2.close()
419 rc2.close()
420
420
421 def test_purge_all_results(self):
421 def test_purge_all_results(self):
422 self.client.purge_results('all')
422 self.client.purge_results('all')
423 hist = self.client.hub_history()
423 hist = self.client.hub_history()
424 self.assertEqual(len(hist), 0)
424 self.assertEqual(len(hist), 0)
425
425
426 def test_spin_thread(self):
426 def test_spin_thread(self):
427 self.client.spin_thread(0.01)
427 self.client.spin_thread(0.01)
428 ar = self.client[-1].apply_async(lambda : 1)
428 ar = self.client[-1].apply_async(lambda : 1)
429 time.sleep(0.1)
429 time.sleep(0.1)
430 self.assertTrue(ar.wall_time < 0.1,
430 self.assertTrue(ar.wall_time < 0.1,
431 "spin should have kept wall_time < 0.1, but got %f" % ar.wall_time
431 "spin should have kept wall_time < 0.1, but got %f" % ar.wall_time
432 )
432 )
433
433
434 def test_stop_spin_thread(self):
434 def test_stop_spin_thread(self):
435 self.client.spin_thread(0.01)
435 self.client.spin_thread(0.01)
436 self.client.stop_spin_thread()
436 self.client.stop_spin_thread()
437 ar = self.client[-1].apply_async(lambda : 1)
437 ar = self.client[-1].apply_async(lambda : 1)
438 time.sleep(0.15)
438 time.sleep(0.15)
439 self.assertTrue(ar.wall_time > 0.1,
439 self.assertTrue(ar.wall_time > 0.1,
440 "Shouldn't be spinning, but got wall_time=%f" % ar.wall_time
440 "Shouldn't be spinning, but got wall_time=%f" % ar.wall_time
441 )
441 )
442
442
443 def test_activate(self):
443 def test_activate(self):
444 ip = get_ipython()
444 ip = get_ipython()
445 magics = ip.magics_manager.magics
445 magics = ip.magics_manager.magics
446 self.assertTrue('px' in magics['line'])
446 self.assertTrue('px' in magics['line'])
447 self.assertTrue('px' in magics['cell'])
447 self.assertTrue('px' in magics['cell'])
448 v0 = self.client.activate(-1, '0')
448 v0 = self.client.activate(-1, '0')
449 self.assertTrue('px0' in magics['line'])
449 self.assertTrue('px0' in magics['line'])
450 self.assertTrue('px0' in magics['cell'])
450 self.assertTrue('px0' in magics['cell'])
451 self.assertEqual(v0.targets, self.client.ids[-1])
451 self.assertEqual(v0.targets, self.client.ids[-1])
452 v0 = self.client.activate('all', 'all')
452 v0 = self.client.activate('all', 'all')
453 self.assertTrue('pxall' in magics['line'])
453 self.assertTrue('pxall' in magics['line'])
454 self.assertTrue('pxall' in magics['cell'])
454 self.assertTrue('pxall' in magics['cell'])
455 self.assertEqual(v0.targets, 'all')
455 self.assertEqual(v0.targets, 'all')
@@ -1,150 +1,150 b''
1 # encoding: utf-8
1 # encoding: utf-8
2
2
3 """This file contains unittests for the notification.py module."""
3 """This file contains unittests for the notification.py module."""
4
4
5 #-----------------------------------------------------------------------------
5 #-----------------------------------------------------------------------------
6 # Copyright (C) 2008-2011 The IPython Development Team
6 # Copyright (C) 2008-2011 The IPython Development Team
7 #
7 #
8 # Distributed under the terms of the BSD License. The full license is
8 # Distributed under the terms of the BSD License. The full license is
9 # in the file COPYING, distributed as part of this software.
9 # in the file COPYING, distributed as part of this software.
10 #-----------------------------------------------------------------------------
10 #-----------------------------------------------------------------------------
11
11
12 #-----------------------------------------------------------------------------
12 #-----------------------------------------------------------------------------
13 # Imports
13 # Imports
14 #-----------------------------------------------------------------------------
14 #-----------------------------------------------------------------------------
15
15
16 import unittest
16 import unittest
17
17
18 from IPython.utils.notification import shared_center
18 from IPython.utils.notification import shared_center
19
19
20 #-----------------------------------------------------------------------------
20 #-----------------------------------------------------------------------------
21 # Support Classes
21 # Support Classes
22 #-----------------------------------------------------------------------------
22 #-----------------------------------------------------------------------------
23
23
24
24
25 class Observer(object):
25 class Observer(object):
26
26
27 def __init__(self, expected_ntype, expected_sender,
27 def __init__(self, expected_ntype, expected_sender,
28 center=shared_center, *args, **kwargs):
28 center=shared_center, *args, **kwargs):
29 super(Observer, self).__init__()
29 super(Observer, self).__init__()
30 self.expected_ntype = expected_ntype
30 self.expected_ntype = expected_ntype
31 self.expected_sender = expected_sender
31 self.expected_sender = expected_sender
32 self.expected_args = args
32 self.expected_args = args
33 self.expected_kwargs = kwargs
33 self.expected_kwargs = kwargs
34 self.recieved = False
34 self.recieved = False
35 center.add_observer(self.callback,
35 center.add_observer(self.callback,
36 self.expected_ntype,
36 self.expected_ntype,
37 self.expected_sender)
37 self.expected_sender)
38
38
39 def callback(self, ntype, sender, *args, **kwargs):
39 def callback(self, ntype, sender, *args, **kwargs):
40 assert(ntype == self.expected_ntype or
40 assert(ntype == self.expected_ntype or
41 self.expected_ntype == None)
41 self.expected_ntype == None)
42 assert(sender == self.expected_sender or
42 assert(sender == self.expected_sender or
43 self.expected_sender == None)
43 self.expected_sender == None)
44 assert(args == self.expected_args)
44 assert(args == self.expected_args)
45 assert(kwargs == self.expected_kwargs)
45 assert(kwargs == self.expected_kwargs)
46 self.recieved = True
46 self.recieved = True
47
47
48 def verify(self):
48 def verify(self):
49 assert(self.recieved)
49 assert(self.recieved)
50
50
51 def reset(self):
51 def reset(self):
52 self.recieved = False
52 self.recieved = False
53
53
54
54
55 class Notifier(object):
55 class Notifier(object):
56
56
57 def __init__(self, ntype, **kwargs):
57 def __init__(self, ntype, **kwargs):
58 super(Notifier, self).__init__()
58 super(Notifier, self).__init__()
59 self.ntype = ntype
59 self.ntype = ntype
60 self.kwargs = kwargs
60 self.kwargs = kwargs
61
61
62 def post(self, center=shared_center):
62 def post(self, center=shared_center):
63
63
64 center.post_notification(self.ntype, self,
64 center.post_notification(self.ntype, self,
65 **self.kwargs)
65 **self.kwargs)
66
66
67
67
68 #-----------------------------------------------------------------------------
68 #-----------------------------------------------------------------------------
69 # Tests
69 # Tests
70 #-----------------------------------------------------------------------------
70 #-----------------------------------------------------------------------------
71
71
72
72
73 class NotificationTests(unittest.TestCase):
73 class NotificationTests(unittest.TestCase):
74
74
75 def tearDown(self):
75 def tearDown(self):
76 shared_center.remove_all_observers()
76 shared_center.remove_all_observers()
77
77
78 def test_notification_delivered(self):
78 def test_notification_delivered(self):
79 """Test that notifications are delivered"""
79 """Test that notifications are delivered"""
80
80
81 expected_ntype = 'EXPECTED_TYPE'
81 expected_ntype = 'EXPECTED_TYPE'
82 sender = Notifier(expected_ntype)
82 sender = Notifier(expected_ntype)
83 observer = Observer(expected_ntype, sender)
83 observer = Observer(expected_ntype, sender)
84
84
85 sender.post()
85 sender.post()
86 observer.verify()
86 observer.verify()
87
87
88 def test_type_specificity(self):
88 def test_type_specificity(self):
89 """Test that observers are registered by type"""
89 """Test that observers are registered by type"""
90
90
91 expected_ntype = 1
91 expected_ntype = 1
92 unexpected_ntype = "UNEXPECTED_TYPE"
92 unexpected_ntype = "UNEXPECTED_TYPE"
93 sender = Notifier(expected_ntype)
93 sender = Notifier(expected_ntype)
94 unexpected_sender = Notifier(unexpected_ntype)
94 unexpected_sender = Notifier(unexpected_ntype)
95 observer = Observer(expected_ntype, sender)
95 observer = Observer(expected_ntype, sender)
96
96
97 sender.post()
97 sender.post()
98 unexpected_sender.post()
98 unexpected_sender.post()
99 observer.verify()
99 observer.verify()
100
100
101 def test_sender_specificity(self):
101 def test_sender_specificity(self):
102 """Test that observers are registered by sender"""
102 """Test that observers are registered by sender"""
103
103
104 expected_ntype = "EXPECTED_TYPE"
104 expected_ntype = "EXPECTED_TYPE"
105 sender1 = Notifier(expected_ntype)
105 sender1 = Notifier(expected_ntype)
106 sender2 = Notifier(expected_ntype)
106 sender2 = Notifier(expected_ntype)
107 observer = Observer(expected_ntype, sender1)
107 observer = Observer(expected_ntype, sender1)
108
108
109 sender1.post()
109 sender1.post()
110 sender2.post()
110 sender2.post()
111
111
112 observer.verify()
112 observer.verify()
113
113
114 def test_remove_all_observers(self):
114 def test_remove_all_observers(self):
115 """White-box test for remove_all_observers"""
115 """White-box test for remove_all_observers"""
116
116
117 for i in xrange(10):
117 for i in xrange(10):
118 Observer('TYPE', None, center=shared_center)
118 Observer('TYPE', None, center=shared_center)
119
119
120 self.assert_(len(shared_center.observers[('TYPE',None)]) >= 10,
120 self.assertTrue(len(shared_center.observers[('TYPE',None)]) >= 10,
121 "observers registered")
121 "observers registered")
122
122
123 shared_center.remove_all_observers()
123 shared_center.remove_all_observers()
124 self.assert_(len(shared_center.observers) == 0, "observers removed")
124 self.assertTrue(len(shared_center.observers) == 0, "observers removed")
125
125
126 def test_any_sender(self):
126 def test_any_sender(self):
127 expected_ntype = "EXPECTED_TYPE"
127 expected_ntype = "EXPECTED_TYPE"
128 sender1 = Notifier(expected_ntype)
128 sender1 = Notifier(expected_ntype)
129 sender2 = Notifier(expected_ntype)
129 sender2 = Notifier(expected_ntype)
130 observer = Observer(expected_ntype, None)
130 observer = Observer(expected_ntype, None)
131
131
132 sender1.post()
132 sender1.post()
133 observer.verify()
133 observer.verify()
134
134
135 observer.reset()
135 observer.reset()
136 sender2.post()
136 sender2.post()
137 observer.verify()
137 observer.verify()
138
138
139 def test_post_performance(self):
139 def test_post_performance(self):
140 """Test that post_notification, even with many registered irrelevant
140 """Test that post_notification, even with many registered irrelevant
141 observers is fast"""
141 observers is fast"""
142
142
143 for i in xrange(10):
143 for i in xrange(10):
144 Observer("UNRELATED_TYPE", None)
144 Observer("UNRELATED_TYPE", None)
145
145
146 o = Observer('EXPECTED_TYPE', None)
146 o = Observer('EXPECTED_TYPE', None)
147 shared_center.post_notification('EXPECTED_TYPE', self)
147 shared_center.post_notification('EXPECTED_TYPE', self)
148 o.verify()
148 o.verify()
149
149
150
150
@@ -1,908 +1,908 b''
1 # encoding: utf-8
1 # encoding: utf-8
2 """
2 """
3 Tests for IPython.utils.traitlets.
3 Tests for IPython.utils.traitlets.
4
4
5 Authors:
5 Authors:
6
6
7 * Brian Granger
7 * Brian Granger
8 * Enthought, Inc. Some of the code in this file comes from enthought.traits
8 * Enthought, Inc. Some of the code in this file comes from enthought.traits
9 and is licensed under the BSD license. Also, many of the ideas also come
9 and is licensed under the BSD license. Also, many of the ideas also come
10 from enthought.traits even though our implementation is very different.
10 from enthought.traits even though our implementation is very different.
11 """
11 """
12
12
13 #-----------------------------------------------------------------------------
13 #-----------------------------------------------------------------------------
14 # Copyright (C) 2008-2011 The IPython Development Team
14 # Copyright (C) 2008-2011 The IPython Development Team
15 #
15 #
16 # Distributed under the terms of the BSD License. The full license is in
16 # Distributed under the terms of the BSD License. The full license is in
17 # the file COPYING, distributed as part of this software.
17 # the file COPYING, distributed as part of this software.
18 #-----------------------------------------------------------------------------
18 #-----------------------------------------------------------------------------
19
19
20 #-----------------------------------------------------------------------------
20 #-----------------------------------------------------------------------------
21 # Imports
21 # Imports
22 #-----------------------------------------------------------------------------
22 #-----------------------------------------------------------------------------
23
23
24 import re
24 import re
25 import sys
25 import sys
26 from unittest import TestCase
26 from unittest import TestCase
27
27
28 from nose import SkipTest
28 from nose import SkipTest
29
29
30 from IPython.utils.traitlets import (
30 from IPython.utils.traitlets import (
31 HasTraits, MetaHasTraits, TraitType, Any, CBytes,
31 HasTraits, MetaHasTraits, TraitType, Any, CBytes,
32 Int, Long, Integer, Float, Complex, Bytes, Unicode, TraitError,
32 Int, Long, Integer, Float, Complex, Bytes, Unicode, TraitError,
33 Undefined, Type, This, Instance, TCPAddress, List, Tuple,
33 Undefined, Type, This, Instance, TCPAddress, List, Tuple,
34 ObjectName, DottedObjectName, CRegExp
34 ObjectName, DottedObjectName, CRegExp
35 )
35 )
36 from IPython.utils import py3compat
36 from IPython.utils import py3compat
37 from IPython.testing.decorators import skipif
37 from IPython.testing.decorators import skipif
38
38
39 #-----------------------------------------------------------------------------
39 #-----------------------------------------------------------------------------
40 # Helper classes for testing
40 # Helper classes for testing
41 #-----------------------------------------------------------------------------
41 #-----------------------------------------------------------------------------
42
42
43
43
44 class HasTraitsStub(HasTraits):
44 class HasTraitsStub(HasTraits):
45
45
46 def _notify_trait(self, name, old, new):
46 def _notify_trait(self, name, old, new):
47 self._notify_name = name
47 self._notify_name = name
48 self._notify_old = old
48 self._notify_old = old
49 self._notify_new = new
49 self._notify_new = new
50
50
51
51
52 #-----------------------------------------------------------------------------
52 #-----------------------------------------------------------------------------
53 # Test classes
53 # Test classes
54 #-----------------------------------------------------------------------------
54 #-----------------------------------------------------------------------------
55
55
56
56
57 class TestTraitType(TestCase):
57 class TestTraitType(TestCase):
58
58
59 def test_get_undefined(self):
59 def test_get_undefined(self):
60 class A(HasTraits):
60 class A(HasTraits):
61 a = TraitType
61 a = TraitType
62 a = A()
62 a = A()
63 self.assertEqual(a.a, Undefined)
63 self.assertEqual(a.a, Undefined)
64
64
65 def test_set(self):
65 def test_set(self):
66 class A(HasTraitsStub):
66 class A(HasTraitsStub):
67 a = TraitType
67 a = TraitType
68
68
69 a = A()
69 a = A()
70 a.a = 10
70 a.a = 10
71 self.assertEqual(a.a, 10)
71 self.assertEqual(a.a, 10)
72 self.assertEqual(a._notify_name, 'a')
72 self.assertEqual(a._notify_name, 'a')
73 self.assertEqual(a._notify_old, Undefined)
73 self.assertEqual(a._notify_old, Undefined)
74 self.assertEqual(a._notify_new, 10)
74 self.assertEqual(a._notify_new, 10)
75
75
76 def test_validate(self):
76 def test_validate(self):
77 class MyTT(TraitType):
77 class MyTT(TraitType):
78 def validate(self, inst, value):
78 def validate(self, inst, value):
79 return -1
79 return -1
80 class A(HasTraitsStub):
80 class A(HasTraitsStub):
81 tt = MyTT
81 tt = MyTT
82
82
83 a = A()
83 a = A()
84 a.tt = 10
84 a.tt = 10
85 self.assertEqual(a.tt, -1)
85 self.assertEqual(a.tt, -1)
86
86
87 def test_default_validate(self):
87 def test_default_validate(self):
88 class MyIntTT(TraitType):
88 class MyIntTT(TraitType):
89 def validate(self, obj, value):
89 def validate(self, obj, value):
90 if isinstance(value, int):
90 if isinstance(value, int):
91 return value
91 return value
92 self.error(obj, value)
92 self.error(obj, value)
93 class A(HasTraits):
93 class A(HasTraits):
94 tt = MyIntTT(10)
94 tt = MyIntTT(10)
95 a = A()
95 a = A()
96 self.assertEqual(a.tt, 10)
96 self.assertEqual(a.tt, 10)
97
97
98 # Defaults are validated when the HasTraits is instantiated
98 # Defaults are validated when the HasTraits is instantiated
99 class B(HasTraits):
99 class B(HasTraits):
100 tt = MyIntTT('bad default')
100 tt = MyIntTT('bad default')
101 self.assertRaises(TraitError, B)
101 self.assertRaises(TraitError, B)
102
102
103 def test_is_valid_for(self):
103 def test_is_valid_for(self):
104 class MyTT(TraitType):
104 class MyTT(TraitType):
105 def is_valid_for(self, value):
105 def is_valid_for(self, value):
106 return True
106 return True
107 class A(HasTraits):
107 class A(HasTraits):
108 tt = MyTT
108 tt = MyTT
109
109
110 a = A()
110 a = A()
111 a.tt = 10
111 a.tt = 10
112 self.assertEqual(a.tt, 10)
112 self.assertEqual(a.tt, 10)
113
113
114 def test_value_for(self):
114 def test_value_for(self):
115 class MyTT(TraitType):
115 class MyTT(TraitType):
116 def value_for(self, value):
116 def value_for(self, value):
117 return 20
117 return 20
118 class A(HasTraits):
118 class A(HasTraits):
119 tt = MyTT
119 tt = MyTT
120
120
121 a = A()
121 a = A()
122 a.tt = 10
122 a.tt = 10
123 self.assertEqual(a.tt, 20)
123 self.assertEqual(a.tt, 20)
124
124
125 def test_info(self):
125 def test_info(self):
126 class A(HasTraits):
126 class A(HasTraits):
127 tt = TraitType
127 tt = TraitType
128 a = A()
128 a = A()
129 self.assertEqual(A.tt.info(), 'any value')
129 self.assertEqual(A.tt.info(), 'any value')
130
130
131 def test_error(self):
131 def test_error(self):
132 class A(HasTraits):
132 class A(HasTraits):
133 tt = TraitType
133 tt = TraitType
134 a = A()
134 a = A()
135 self.assertRaises(TraitError, A.tt.error, a, 10)
135 self.assertRaises(TraitError, A.tt.error, a, 10)
136
136
137 def test_dynamic_initializer(self):
137 def test_dynamic_initializer(self):
138 class A(HasTraits):
138 class A(HasTraits):
139 x = Int(10)
139 x = Int(10)
140 def _x_default(self):
140 def _x_default(self):
141 return 11
141 return 11
142 class B(A):
142 class B(A):
143 x = Int(20)
143 x = Int(20)
144 class C(A):
144 class C(A):
145 def _x_default(self):
145 def _x_default(self):
146 return 21
146 return 21
147
147
148 a = A()
148 a = A()
149 self.assertEqual(a._trait_values, {})
149 self.assertEqual(a._trait_values, {})
150 self.assertEqual(a._trait_dyn_inits.keys(), ['x'])
150 self.assertEqual(a._trait_dyn_inits.keys(), ['x'])
151 self.assertEqual(a.x, 11)
151 self.assertEqual(a.x, 11)
152 self.assertEqual(a._trait_values, {'x': 11})
152 self.assertEqual(a._trait_values, {'x': 11})
153 b = B()
153 b = B()
154 self.assertEqual(b._trait_values, {'x': 20})
154 self.assertEqual(b._trait_values, {'x': 20})
155 self.assertEqual(a._trait_dyn_inits.keys(), ['x'])
155 self.assertEqual(a._trait_dyn_inits.keys(), ['x'])
156 self.assertEqual(b.x, 20)
156 self.assertEqual(b.x, 20)
157 c = C()
157 c = C()
158 self.assertEqual(c._trait_values, {})
158 self.assertEqual(c._trait_values, {})
159 self.assertEqual(a._trait_dyn_inits.keys(), ['x'])
159 self.assertEqual(a._trait_dyn_inits.keys(), ['x'])
160 self.assertEqual(c.x, 21)
160 self.assertEqual(c.x, 21)
161 self.assertEqual(c._trait_values, {'x': 21})
161 self.assertEqual(c._trait_values, {'x': 21})
162 # Ensure that the base class remains unmolested when the _default
162 # Ensure that the base class remains unmolested when the _default
163 # initializer gets overridden in a subclass.
163 # initializer gets overridden in a subclass.
164 a = A()
164 a = A()
165 c = C()
165 c = C()
166 self.assertEqual(a._trait_values, {})
166 self.assertEqual(a._trait_values, {})
167 self.assertEqual(a._trait_dyn_inits.keys(), ['x'])
167 self.assertEqual(a._trait_dyn_inits.keys(), ['x'])
168 self.assertEqual(a.x, 11)
168 self.assertEqual(a.x, 11)
169 self.assertEqual(a._trait_values, {'x': 11})
169 self.assertEqual(a._trait_values, {'x': 11})
170
170
171
171
172
172
173 class TestHasTraitsMeta(TestCase):
173 class TestHasTraitsMeta(TestCase):
174
174
175 def test_metaclass(self):
175 def test_metaclass(self):
176 self.assertEqual(type(HasTraits), MetaHasTraits)
176 self.assertEqual(type(HasTraits), MetaHasTraits)
177
177
178 class A(HasTraits):
178 class A(HasTraits):
179 a = Int
179 a = Int
180
180
181 a = A()
181 a = A()
182 self.assertEqual(type(a.__class__), MetaHasTraits)
182 self.assertEqual(type(a.__class__), MetaHasTraits)
183 self.assertEqual(a.a,0)
183 self.assertEqual(a.a,0)
184 a.a = 10
184 a.a = 10
185 self.assertEqual(a.a,10)
185 self.assertEqual(a.a,10)
186
186
187 class B(HasTraits):
187 class B(HasTraits):
188 b = Int()
188 b = Int()
189
189
190 b = B()
190 b = B()
191 self.assertEqual(b.b,0)
191 self.assertEqual(b.b,0)
192 b.b = 10
192 b.b = 10
193 self.assertEqual(b.b,10)
193 self.assertEqual(b.b,10)
194
194
195 class C(HasTraits):
195 class C(HasTraits):
196 c = Int(30)
196 c = Int(30)
197
197
198 c = C()
198 c = C()
199 self.assertEqual(c.c,30)
199 self.assertEqual(c.c,30)
200 c.c = 10
200 c.c = 10
201 self.assertEqual(c.c,10)
201 self.assertEqual(c.c,10)
202
202
203 def test_this_class(self):
203 def test_this_class(self):
204 class A(HasTraits):
204 class A(HasTraits):
205 t = This()
205 t = This()
206 tt = This()
206 tt = This()
207 class B(A):
207 class B(A):
208 tt = This()
208 tt = This()
209 ttt = This()
209 ttt = This()
210 self.assertEqual(A.t.this_class, A)
210 self.assertEqual(A.t.this_class, A)
211 self.assertEqual(B.t.this_class, A)
211 self.assertEqual(B.t.this_class, A)
212 self.assertEqual(B.tt.this_class, B)
212 self.assertEqual(B.tt.this_class, B)
213 self.assertEqual(B.ttt.this_class, B)
213 self.assertEqual(B.ttt.this_class, B)
214
214
215 class TestHasTraitsNotify(TestCase):
215 class TestHasTraitsNotify(TestCase):
216
216
217 def setUp(self):
217 def setUp(self):
218 self._notify1 = []
218 self._notify1 = []
219 self._notify2 = []
219 self._notify2 = []
220
220
221 def notify1(self, name, old, new):
221 def notify1(self, name, old, new):
222 self._notify1.append((name, old, new))
222 self._notify1.append((name, old, new))
223
223
224 def notify2(self, name, old, new):
224 def notify2(self, name, old, new):
225 self._notify2.append((name, old, new))
225 self._notify2.append((name, old, new))
226
226
227 def test_notify_all(self):
227 def test_notify_all(self):
228
228
229 class A(HasTraits):
229 class A(HasTraits):
230 a = Int
230 a = Int
231 b = Float
231 b = Float
232
232
233 a = A()
233 a = A()
234 a.on_trait_change(self.notify1)
234 a.on_trait_change(self.notify1)
235 a.a = 0
235 a.a = 0
236 self.assertEqual(len(self._notify1),0)
236 self.assertEqual(len(self._notify1),0)
237 a.b = 0.0
237 a.b = 0.0
238 self.assertEqual(len(self._notify1),0)
238 self.assertEqual(len(self._notify1),0)
239 a.a = 10
239 a.a = 10
240 self.assert_(('a',0,10) in self._notify1)
240 self.assertTrue(('a',0,10) in self._notify1)
241 a.b = 10.0
241 a.b = 10.0
242 self.assert_(('b',0.0,10.0) in self._notify1)
242 self.assertTrue(('b',0.0,10.0) in self._notify1)
243 self.assertRaises(TraitError,setattr,a,'a','bad string')
243 self.assertRaises(TraitError,setattr,a,'a','bad string')
244 self.assertRaises(TraitError,setattr,a,'b','bad string')
244 self.assertRaises(TraitError,setattr,a,'b','bad string')
245 self._notify1 = []
245 self._notify1 = []
246 a.on_trait_change(self.notify1,remove=True)
246 a.on_trait_change(self.notify1,remove=True)
247 a.a = 20
247 a.a = 20
248 a.b = 20.0
248 a.b = 20.0
249 self.assertEqual(len(self._notify1),0)
249 self.assertEqual(len(self._notify1),0)
250
250
251 def test_notify_one(self):
251 def test_notify_one(self):
252
252
253 class A(HasTraits):
253 class A(HasTraits):
254 a = Int
254 a = Int
255 b = Float
255 b = Float
256
256
257 a = A()
257 a = A()
258 a.on_trait_change(self.notify1, 'a')
258 a.on_trait_change(self.notify1, 'a')
259 a.a = 0
259 a.a = 0
260 self.assertEqual(len(self._notify1),0)
260 self.assertEqual(len(self._notify1),0)
261 a.a = 10
261 a.a = 10
262 self.assert_(('a',0,10) in self._notify1)
262 self.assertTrue(('a',0,10) in self._notify1)
263 self.assertRaises(TraitError,setattr,a,'a','bad string')
263 self.assertRaises(TraitError,setattr,a,'a','bad string')
264
264
265 def test_subclass(self):
265 def test_subclass(self):
266
266
267 class A(HasTraits):
267 class A(HasTraits):
268 a = Int
268 a = Int
269
269
270 class B(A):
270 class B(A):
271 b = Float
271 b = Float
272
272
273 b = B()
273 b = B()
274 self.assertEqual(b.a,0)
274 self.assertEqual(b.a,0)
275 self.assertEqual(b.b,0.0)
275 self.assertEqual(b.b,0.0)
276 b.a = 100
276 b.a = 100
277 b.b = 100.0
277 b.b = 100.0
278 self.assertEqual(b.a,100)
278 self.assertEqual(b.a,100)
279 self.assertEqual(b.b,100.0)
279 self.assertEqual(b.b,100.0)
280
280
281 def test_notify_subclass(self):
281 def test_notify_subclass(self):
282
282
283 class A(HasTraits):
283 class A(HasTraits):
284 a = Int
284 a = Int
285
285
286 class B(A):
286 class B(A):
287 b = Float
287 b = Float
288
288
289 b = B()
289 b = B()
290 b.on_trait_change(self.notify1, 'a')
290 b.on_trait_change(self.notify1, 'a')
291 b.on_trait_change(self.notify2, 'b')
291 b.on_trait_change(self.notify2, 'b')
292 b.a = 0
292 b.a = 0
293 b.b = 0.0
293 b.b = 0.0
294 self.assertEqual(len(self._notify1),0)
294 self.assertEqual(len(self._notify1),0)
295 self.assertEqual(len(self._notify2),0)
295 self.assertEqual(len(self._notify2),0)
296 b.a = 10
296 b.a = 10
297 b.b = 10.0
297 b.b = 10.0
298 self.assert_(('a',0,10) in self._notify1)
298 self.assertTrue(('a',0,10) in self._notify1)
299 self.assert_(('b',0.0,10.0) in self._notify2)
299 self.assertTrue(('b',0.0,10.0) in self._notify2)
300
300
301 def test_static_notify(self):
301 def test_static_notify(self):
302
302
303 class A(HasTraits):
303 class A(HasTraits):
304 a = Int
304 a = Int
305 _notify1 = []
305 _notify1 = []
306 def _a_changed(self, name, old, new):
306 def _a_changed(self, name, old, new):
307 self._notify1.append((name, old, new))
307 self._notify1.append((name, old, new))
308
308
309 a = A()
309 a = A()
310 a.a = 0
310 a.a = 0
311 # This is broken!!!
311 # This is broken!!!
312 self.assertEqual(len(a._notify1),0)
312 self.assertEqual(len(a._notify1),0)
313 a.a = 10
313 a.a = 10
314 self.assert_(('a',0,10) in a._notify1)
314 self.assertTrue(('a',0,10) in a._notify1)
315
315
316 class B(A):
316 class B(A):
317 b = Float
317 b = Float
318 _notify2 = []
318 _notify2 = []
319 def _b_changed(self, name, old, new):
319 def _b_changed(self, name, old, new):
320 self._notify2.append((name, old, new))
320 self._notify2.append((name, old, new))
321
321
322 b = B()
322 b = B()
323 b.a = 10
323 b.a = 10
324 b.b = 10.0
324 b.b = 10.0
325 self.assert_(('a',0,10) in b._notify1)
325 self.assertTrue(('a',0,10) in b._notify1)
326 self.assert_(('b',0.0,10.0) in b._notify2)
326 self.assertTrue(('b',0.0,10.0) in b._notify2)
327
327
328 def test_notify_args(self):
328 def test_notify_args(self):
329
329
330 def callback0():
330 def callback0():
331 self.cb = ()
331 self.cb = ()
332 def callback1(name):
332 def callback1(name):
333 self.cb = (name,)
333 self.cb = (name,)
334 def callback2(name, new):
334 def callback2(name, new):
335 self.cb = (name, new)
335 self.cb = (name, new)
336 def callback3(name, old, new):
336 def callback3(name, old, new):
337 self.cb = (name, old, new)
337 self.cb = (name, old, new)
338
338
339 class A(HasTraits):
339 class A(HasTraits):
340 a = Int
340 a = Int
341
341
342 a = A()
342 a = A()
343 a.on_trait_change(callback0, 'a')
343 a.on_trait_change(callback0, 'a')
344 a.a = 10
344 a.a = 10
345 self.assertEqual(self.cb,())
345 self.assertEqual(self.cb,())
346 a.on_trait_change(callback0, 'a', remove=True)
346 a.on_trait_change(callback0, 'a', remove=True)
347
347
348 a.on_trait_change(callback1, 'a')
348 a.on_trait_change(callback1, 'a')
349 a.a = 100
349 a.a = 100
350 self.assertEqual(self.cb,('a',))
350 self.assertEqual(self.cb,('a',))
351 a.on_trait_change(callback1, 'a', remove=True)
351 a.on_trait_change(callback1, 'a', remove=True)
352
352
353 a.on_trait_change(callback2, 'a')
353 a.on_trait_change(callback2, 'a')
354 a.a = 1000
354 a.a = 1000
355 self.assertEqual(self.cb,('a',1000))
355 self.assertEqual(self.cb,('a',1000))
356 a.on_trait_change(callback2, 'a', remove=True)
356 a.on_trait_change(callback2, 'a', remove=True)
357
357
358 a.on_trait_change(callback3, 'a')
358 a.on_trait_change(callback3, 'a')
359 a.a = 10000
359 a.a = 10000
360 self.assertEqual(self.cb,('a',1000,10000))
360 self.assertEqual(self.cb,('a',1000,10000))
361 a.on_trait_change(callback3, 'a', remove=True)
361 a.on_trait_change(callback3, 'a', remove=True)
362
362
363 self.assertEqual(len(a._trait_notifiers['a']),0)
363 self.assertEqual(len(a._trait_notifiers['a']),0)
364
364
365
365
366 class TestHasTraits(TestCase):
366 class TestHasTraits(TestCase):
367
367
368 def test_trait_names(self):
368 def test_trait_names(self):
369 class A(HasTraits):
369 class A(HasTraits):
370 i = Int
370 i = Int
371 f = Float
371 f = Float
372 a = A()
372 a = A()
373 self.assertEqual(sorted(a.trait_names()),['f','i'])
373 self.assertEqual(sorted(a.trait_names()),['f','i'])
374 self.assertEqual(sorted(A.class_trait_names()),['f','i'])
374 self.assertEqual(sorted(A.class_trait_names()),['f','i'])
375
375
376 def test_trait_metadata(self):
376 def test_trait_metadata(self):
377 class A(HasTraits):
377 class A(HasTraits):
378 i = Int(config_key='MY_VALUE')
378 i = Int(config_key='MY_VALUE')
379 a = A()
379 a = A()
380 self.assertEqual(a.trait_metadata('i','config_key'), 'MY_VALUE')
380 self.assertEqual(a.trait_metadata('i','config_key'), 'MY_VALUE')
381
381
382 def test_traits(self):
382 def test_traits(self):
383 class A(HasTraits):
383 class A(HasTraits):
384 i = Int
384 i = Int
385 f = Float
385 f = Float
386 a = A()
386 a = A()
387 self.assertEqual(a.traits(), dict(i=A.i, f=A.f))
387 self.assertEqual(a.traits(), dict(i=A.i, f=A.f))
388 self.assertEqual(A.class_traits(), dict(i=A.i, f=A.f))
388 self.assertEqual(A.class_traits(), dict(i=A.i, f=A.f))
389
389
390 def test_traits_metadata(self):
390 def test_traits_metadata(self):
391 class A(HasTraits):
391 class A(HasTraits):
392 i = Int(config_key='VALUE1', other_thing='VALUE2')
392 i = Int(config_key='VALUE1', other_thing='VALUE2')
393 f = Float(config_key='VALUE3', other_thing='VALUE2')
393 f = Float(config_key='VALUE3', other_thing='VALUE2')
394 j = Int(0)
394 j = Int(0)
395 a = A()
395 a = A()
396 self.assertEqual(a.traits(), dict(i=A.i, f=A.f, j=A.j))
396 self.assertEqual(a.traits(), dict(i=A.i, f=A.f, j=A.j))
397 traits = a.traits(config_key='VALUE1', other_thing='VALUE2')
397 traits = a.traits(config_key='VALUE1', other_thing='VALUE2')
398 self.assertEqual(traits, dict(i=A.i))
398 self.assertEqual(traits, dict(i=A.i))
399
399
400 # This passes, but it shouldn't because I am replicating a bug in
400 # This passes, but it shouldn't because I am replicating a bug in
401 # traits.
401 # traits.
402 traits = a.traits(config_key=lambda v: True)
402 traits = a.traits(config_key=lambda v: True)
403 self.assertEqual(traits, dict(i=A.i, f=A.f, j=A.j))
403 self.assertEqual(traits, dict(i=A.i, f=A.f, j=A.j))
404
404
405 def test_init(self):
405 def test_init(self):
406 class A(HasTraits):
406 class A(HasTraits):
407 i = Int()
407 i = Int()
408 x = Float()
408 x = Float()
409 a = A(i=1, x=10.0)
409 a = A(i=1, x=10.0)
410 self.assertEqual(a.i, 1)
410 self.assertEqual(a.i, 1)
411 self.assertEqual(a.x, 10.0)
411 self.assertEqual(a.x, 10.0)
412
412
413 #-----------------------------------------------------------------------------
413 #-----------------------------------------------------------------------------
414 # Tests for specific trait types
414 # Tests for specific trait types
415 #-----------------------------------------------------------------------------
415 #-----------------------------------------------------------------------------
416
416
417
417
418 class TestType(TestCase):
418 class TestType(TestCase):
419
419
420 def test_default(self):
420 def test_default(self):
421
421
422 class B(object): pass
422 class B(object): pass
423 class A(HasTraits):
423 class A(HasTraits):
424 klass = Type
424 klass = Type
425
425
426 a = A()
426 a = A()
427 self.assertEqual(a.klass, None)
427 self.assertEqual(a.klass, None)
428
428
429 a.klass = B
429 a.klass = B
430 self.assertEqual(a.klass, B)
430 self.assertEqual(a.klass, B)
431 self.assertRaises(TraitError, setattr, a, 'klass', 10)
431 self.assertRaises(TraitError, setattr, a, 'klass', 10)
432
432
433 def test_value(self):
433 def test_value(self):
434
434
435 class B(object): pass
435 class B(object): pass
436 class C(object): pass
436 class C(object): pass
437 class A(HasTraits):
437 class A(HasTraits):
438 klass = Type(B)
438 klass = Type(B)
439
439
440 a = A()
440 a = A()
441 self.assertEqual(a.klass, B)
441 self.assertEqual(a.klass, B)
442 self.assertRaises(TraitError, setattr, a, 'klass', C)
442 self.assertRaises(TraitError, setattr, a, 'klass', C)
443 self.assertRaises(TraitError, setattr, a, 'klass', object)
443 self.assertRaises(TraitError, setattr, a, 'klass', object)
444 a.klass = B
444 a.klass = B
445
445
446 def test_allow_none(self):
446 def test_allow_none(self):
447
447
448 class B(object): pass
448 class B(object): pass
449 class C(B): pass
449 class C(B): pass
450 class A(HasTraits):
450 class A(HasTraits):
451 klass = Type(B, allow_none=False)
451 klass = Type(B, allow_none=False)
452
452
453 a = A()
453 a = A()
454 self.assertEqual(a.klass, B)
454 self.assertEqual(a.klass, B)
455 self.assertRaises(TraitError, setattr, a, 'klass', None)
455 self.assertRaises(TraitError, setattr, a, 'klass', None)
456 a.klass = C
456 a.klass = C
457 self.assertEqual(a.klass, C)
457 self.assertEqual(a.klass, C)
458
458
459 def test_validate_klass(self):
459 def test_validate_klass(self):
460
460
461 class A(HasTraits):
461 class A(HasTraits):
462 klass = Type('no strings allowed')
462 klass = Type('no strings allowed')
463
463
464 self.assertRaises(ImportError, A)
464 self.assertRaises(ImportError, A)
465
465
466 class A(HasTraits):
466 class A(HasTraits):
467 klass = Type('rub.adub.Duck')
467 klass = Type('rub.adub.Duck')
468
468
469 self.assertRaises(ImportError, A)
469 self.assertRaises(ImportError, A)
470
470
471 def test_validate_default(self):
471 def test_validate_default(self):
472
472
473 class B(object): pass
473 class B(object): pass
474 class A(HasTraits):
474 class A(HasTraits):
475 klass = Type('bad default', B)
475 klass = Type('bad default', B)
476
476
477 self.assertRaises(ImportError, A)
477 self.assertRaises(ImportError, A)
478
478
479 class C(HasTraits):
479 class C(HasTraits):
480 klass = Type(None, B, allow_none=False)
480 klass = Type(None, B, allow_none=False)
481
481
482 self.assertRaises(TraitError, C)
482 self.assertRaises(TraitError, C)
483
483
484 def test_str_klass(self):
484 def test_str_klass(self):
485
485
486 class A(HasTraits):
486 class A(HasTraits):
487 klass = Type('IPython.utils.ipstruct.Struct')
487 klass = Type('IPython.utils.ipstruct.Struct')
488
488
489 from IPython.utils.ipstruct import Struct
489 from IPython.utils.ipstruct import Struct
490 a = A()
490 a = A()
491 a.klass = Struct
491 a.klass = Struct
492 self.assertEqual(a.klass, Struct)
492 self.assertEqual(a.klass, Struct)
493
493
494 self.assertRaises(TraitError, setattr, a, 'klass', 10)
494 self.assertRaises(TraitError, setattr, a, 'klass', 10)
495
495
496 class TestInstance(TestCase):
496 class TestInstance(TestCase):
497
497
498 def test_basic(self):
498 def test_basic(self):
499 class Foo(object): pass
499 class Foo(object): pass
500 class Bar(Foo): pass
500 class Bar(Foo): pass
501 class Bah(object): pass
501 class Bah(object): pass
502
502
503 class A(HasTraits):
503 class A(HasTraits):
504 inst = Instance(Foo)
504 inst = Instance(Foo)
505
505
506 a = A()
506 a = A()
507 self.assert_(a.inst is None)
507 self.assertTrue(a.inst is None)
508 a.inst = Foo()
508 a.inst = Foo()
509 self.assert_(isinstance(a.inst, Foo))
509 self.assertTrue(isinstance(a.inst, Foo))
510 a.inst = Bar()
510 a.inst = Bar()
511 self.assert_(isinstance(a.inst, Foo))
511 self.assertTrue(isinstance(a.inst, Foo))
512 self.assertRaises(TraitError, setattr, a, 'inst', Foo)
512 self.assertRaises(TraitError, setattr, a, 'inst', Foo)
513 self.assertRaises(TraitError, setattr, a, 'inst', Bar)
513 self.assertRaises(TraitError, setattr, a, 'inst', Bar)
514 self.assertRaises(TraitError, setattr, a, 'inst', Bah())
514 self.assertRaises(TraitError, setattr, a, 'inst', Bah())
515
515
516 def test_unique_default_value(self):
516 def test_unique_default_value(self):
517 class Foo(object): pass
517 class Foo(object): pass
518 class A(HasTraits):
518 class A(HasTraits):
519 inst = Instance(Foo,(),{})
519 inst = Instance(Foo,(),{})
520
520
521 a = A()
521 a = A()
522 b = A()
522 b = A()
523 self.assert_(a.inst is not b.inst)
523 self.assertTrue(a.inst is not b.inst)
524
524
525 def test_args_kw(self):
525 def test_args_kw(self):
526 class Foo(object):
526 class Foo(object):
527 def __init__(self, c): self.c = c
527 def __init__(self, c): self.c = c
528 class Bar(object): pass
528 class Bar(object): pass
529 class Bah(object):
529 class Bah(object):
530 def __init__(self, c, d):
530 def __init__(self, c, d):
531 self.c = c; self.d = d
531 self.c = c; self.d = d
532
532
533 class A(HasTraits):
533 class A(HasTraits):
534 inst = Instance(Foo, (10,))
534 inst = Instance(Foo, (10,))
535 a = A()
535 a = A()
536 self.assertEqual(a.inst.c, 10)
536 self.assertEqual(a.inst.c, 10)
537
537
538 class B(HasTraits):
538 class B(HasTraits):
539 inst = Instance(Bah, args=(10,), kw=dict(d=20))
539 inst = Instance(Bah, args=(10,), kw=dict(d=20))
540 b = B()
540 b = B()
541 self.assertEqual(b.inst.c, 10)
541 self.assertEqual(b.inst.c, 10)
542 self.assertEqual(b.inst.d, 20)
542 self.assertEqual(b.inst.d, 20)
543
543
544 class C(HasTraits):
544 class C(HasTraits):
545 inst = Instance(Foo)
545 inst = Instance(Foo)
546 c = C()
546 c = C()
547 self.assert_(c.inst is None)
547 self.assertTrue(c.inst is None)
548
548
549 def test_bad_default(self):
549 def test_bad_default(self):
550 class Foo(object): pass
550 class Foo(object): pass
551
551
552 class A(HasTraits):
552 class A(HasTraits):
553 inst = Instance(Foo, allow_none=False)
553 inst = Instance(Foo, allow_none=False)
554
554
555 self.assertRaises(TraitError, A)
555 self.assertRaises(TraitError, A)
556
556
557 def test_instance(self):
557 def test_instance(self):
558 class Foo(object): pass
558 class Foo(object): pass
559
559
560 def inner():
560 def inner():
561 class A(HasTraits):
561 class A(HasTraits):
562 inst = Instance(Foo())
562 inst = Instance(Foo())
563
563
564 self.assertRaises(TraitError, inner)
564 self.assertRaises(TraitError, inner)
565
565
566
566
567 class TestThis(TestCase):
567 class TestThis(TestCase):
568
568
569 def test_this_class(self):
569 def test_this_class(self):
570 class Foo(HasTraits):
570 class Foo(HasTraits):
571 this = This
571 this = This
572
572
573 f = Foo()
573 f = Foo()
574 self.assertEqual(f.this, None)
574 self.assertEqual(f.this, None)
575 g = Foo()
575 g = Foo()
576 f.this = g
576 f.this = g
577 self.assertEqual(f.this, g)
577 self.assertEqual(f.this, g)
578 self.assertRaises(TraitError, setattr, f, 'this', 10)
578 self.assertRaises(TraitError, setattr, f, 'this', 10)
579
579
580 def test_this_inst(self):
580 def test_this_inst(self):
581 class Foo(HasTraits):
581 class Foo(HasTraits):
582 this = This()
582 this = This()
583
583
584 f = Foo()
584 f = Foo()
585 f.this = Foo()
585 f.this = Foo()
586 self.assert_(isinstance(f.this, Foo))
586 self.assertTrue(isinstance(f.this, Foo))
587
587
588 def test_subclass(self):
588 def test_subclass(self):
589 class Foo(HasTraits):
589 class Foo(HasTraits):
590 t = This()
590 t = This()
591 class Bar(Foo):
591 class Bar(Foo):
592 pass
592 pass
593 f = Foo()
593 f = Foo()
594 b = Bar()
594 b = Bar()
595 f.t = b
595 f.t = b
596 b.t = f
596 b.t = f
597 self.assertEqual(f.t, b)
597 self.assertEqual(f.t, b)
598 self.assertEqual(b.t, f)
598 self.assertEqual(b.t, f)
599
599
600 def test_subclass_override(self):
600 def test_subclass_override(self):
601 class Foo(HasTraits):
601 class Foo(HasTraits):
602 t = This()
602 t = This()
603 class Bar(Foo):
603 class Bar(Foo):
604 t = This()
604 t = This()
605 f = Foo()
605 f = Foo()
606 b = Bar()
606 b = Bar()
607 f.t = b
607 f.t = b
608 self.assertEqual(f.t, b)
608 self.assertEqual(f.t, b)
609 self.assertRaises(TraitError, setattr, b, 't', f)
609 self.assertRaises(TraitError, setattr, b, 't', f)
610
610
611 class TraitTestBase(TestCase):
611 class TraitTestBase(TestCase):
612 """A best testing class for basic trait types."""
612 """A best testing class for basic trait types."""
613
613
614 def assign(self, value):
614 def assign(self, value):
615 self.obj.value = value
615 self.obj.value = value
616
616
617 def coerce(self, value):
617 def coerce(self, value):
618 return value
618 return value
619
619
620 def test_good_values(self):
620 def test_good_values(self):
621 if hasattr(self, '_good_values'):
621 if hasattr(self, '_good_values'):
622 for value in self._good_values:
622 for value in self._good_values:
623 self.assign(value)
623 self.assign(value)
624 self.assertEqual(self.obj.value, self.coerce(value))
624 self.assertEqual(self.obj.value, self.coerce(value))
625
625
626 def test_bad_values(self):
626 def test_bad_values(self):
627 if hasattr(self, '_bad_values'):
627 if hasattr(self, '_bad_values'):
628 for value in self._bad_values:
628 for value in self._bad_values:
629 try:
629 try:
630 self.assertRaises(TraitError, self.assign, value)
630 self.assertRaises(TraitError, self.assign, value)
631 except AssertionError:
631 except AssertionError:
632 assert False, value
632 assert False, value
633
633
634 def test_default_value(self):
634 def test_default_value(self):
635 if hasattr(self, '_default_value'):
635 if hasattr(self, '_default_value'):
636 self.assertEqual(self._default_value, self.obj.value)
636 self.assertEqual(self._default_value, self.obj.value)
637
637
638 def tearDown(self):
638 def tearDown(self):
639 # restore default value after tests, if set
639 # restore default value after tests, if set
640 if hasattr(self, '_default_value'):
640 if hasattr(self, '_default_value'):
641 self.obj.value = self._default_value
641 self.obj.value = self._default_value
642
642
643
643
644 class AnyTrait(HasTraits):
644 class AnyTrait(HasTraits):
645
645
646 value = Any
646 value = Any
647
647
648 class AnyTraitTest(TraitTestBase):
648 class AnyTraitTest(TraitTestBase):
649
649
650 obj = AnyTrait()
650 obj = AnyTrait()
651
651
652 _default_value = None
652 _default_value = None
653 _good_values = [10.0, 'ten', u'ten', [10], {'ten': 10},(10,), None, 1j]
653 _good_values = [10.0, 'ten', u'ten', [10], {'ten': 10},(10,), None, 1j]
654 _bad_values = []
654 _bad_values = []
655
655
656
656
657 class IntTrait(HasTraits):
657 class IntTrait(HasTraits):
658
658
659 value = Int(99)
659 value = Int(99)
660
660
661 class TestInt(TraitTestBase):
661 class TestInt(TraitTestBase):
662
662
663 obj = IntTrait()
663 obj = IntTrait()
664 _default_value = 99
664 _default_value = 99
665 _good_values = [10, -10]
665 _good_values = [10, -10]
666 _bad_values = ['ten', u'ten', [10], {'ten': 10},(10,), None, 1j,
666 _bad_values = ['ten', u'ten', [10], {'ten': 10},(10,), None, 1j,
667 10.1, -10.1, '10L', '-10L', '10.1', '-10.1', u'10L',
667 10.1, -10.1, '10L', '-10L', '10.1', '-10.1', u'10L',
668 u'-10L', u'10.1', u'-10.1', '10', '-10', u'10', u'-10']
668 u'-10L', u'10.1', u'-10.1', '10', '-10', u'10', u'-10']
669 if not py3compat.PY3:
669 if not py3compat.PY3:
670 _bad_values.extend([10L, -10L, 10*sys.maxint, -10*sys.maxint])
670 _bad_values.extend([10L, -10L, 10*sys.maxint, -10*sys.maxint])
671
671
672
672
673 class LongTrait(HasTraits):
673 class LongTrait(HasTraits):
674
674
675 value = Long(99L)
675 value = Long(99L)
676
676
677 class TestLong(TraitTestBase):
677 class TestLong(TraitTestBase):
678
678
679 obj = LongTrait()
679 obj = LongTrait()
680
680
681 _default_value = 99L
681 _default_value = 99L
682 _good_values = [10, -10, 10L, -10L]
682 _good_values = [10, -10, 10L, -10L]
683 _bad_values = ['ten', u'ten', [10], [10l], {'ten': 10},(10,),(10L,),
683 _bad_values = ['ten', u'ten', [10], [10l], {'ten': 10},(10,),(10L,),
684 None, 1j, 10.1, -10.1, '10', '-10', '10L', '-10L', '10.1',
684 None, 1j, 10.1, -10.1, '10', '-10', '10L', '-10L', '10.1',
685 '-10.1', u'10', u'-10', u'10L', u'-10L', u'10.1',
685 '-10.1', u'10', u'-10', u'10L', u'-10L', u'10.1',
686 u'-10.1']
686 u'-10.1']
687 if not py3compat.PY3:
687 if not py3compat.PY3:
688 # maxint undefined on py3, because int == long
688 # maxint undefined on py3, because int == long
689 _good_values.extend([10*sys.maxint, -10*sys.maxint])
689 _good_values.extend([10*sys.maxint, -10*sys.maxint])
690
690
691 @skipif(py3compat.PY3, "not relevant on py3")
691 @skipif(py3compat.PY3, "not relevant on py3")
692 def test_cast_small(self):
692 def test_cast_small(self):
693 """Long casts ints to long"""
693 """Long casts ints to long"""
694 self.obj.value = 10
694 self.obj.value = 10
695 self.assertEqual(type(self.obj.value), long)
695 self.assertEqual(type(self.obj.value), long)
696
696
697
697
698 class IntegerTrait(HasTraits):
698 class IntegerTrait(HasTraits):
699 value = Integer(1)
699 value = Integer(1)
700
700
701 class TestInteger(TestLong):
701 class TestInteger(TestLong):
702 obj = IntegerTrait()
702 obj = IntegerTrait()
703 _default_value = 1
703 _default_value = 1
704
704
705 def coerce(self, n):
705 def coerce(self, n):
706 return int(n)
706 return int(n)
707
707
708 @skipif(py3compat.PY3, "not relevant on py3")
708 @skipif(py3compat.PY3, "not relevant on py3")
709 def test_cast_small(self):
709 def test_cast_small(self):
710 """Integer casts small longs to int"""
710 """Integer casts small longs to int"""
711 if py3compat.PY3:
711 if py3compat.PY3:
712 raise SkipTest("not relevant on py3")
712 raise SkipTest("not relevant on py3")
713
713
714 self.obj.value = 100L
714 self.obj.value = 100L
715 self.assertEqual(type(self.obj.value), int)
715 self.assertEqual(type(self.obj.value), int)
716
716
717
717
718 class FloatTrait(HasTraits):
718 class FloatTrait(HasTraits):
719
719
720 value = Float(99.0)
720 value = Float(99.0)
721
721
722 class TestFloat(TraitTestBase):
722 class TestFloat(TraitTestBase):
723
723
724 obj = FloatTrait()
724 obj = FloatTrait()
725
725
726 _default_value = 99.0
726 _default_value = 99.0
727 _good_values = [10, -10, 10.1, -10.1]
727 _good_values = [10, -10, 10.1, -10.1]
728 _bad_values = ['ten', u'ten', [10], {'ten': 10},(10,), None,
728 _bad_values = ['ten', u'ten', [10], {'ten': 10},(10,), None,
729 1j, '10', '-10', '10L', '-10L', '10.1', '-10.1', u'10',
729 1j, '10', '-10', '10L', '-10L', '10.1', '-10.1', u'10',
730 u'-10', u'10L', u'-10L', u'10.1', u'-10.1']
730 u'-10', u'10L', u'-10L', u'10.1', u'-10.1']
731 if not py3compat.PY3:
731 if not py3compat.PY3:
732 _bad_values.extend([10L, -10L])
732 _bad_values.extend([10L, -10L])
733
733
734
734
735 class ComplexTrait(HasTraits):
735 class ComplexTrait(HasTraits):
736
736
737 value = Complex(99.0-99.0j)
737 value = Complex(99.0-99.0j)
738
738
739 class TestComplex(TraitTestBase):
739 class TestComplex(TraitTestBase):
740
740
741 obj = ComplexTrait()
741 obj = ComplexTrait()
742
742
743 _default_value = 99.0-99.0j
743 _default_value = 99.0-99.0j
744 _good_values = [10, -10, 10.1, -10.1, 10j, 10+10j, 10-10j,
744 _good_values = [10, -10, 10.1, -10.1, 10j, 10+10j, 10-10j,
745 10.1j, 10.1+10.1j, 10.1-10.1j]
745 10.1j, 10.1+10.1j, 10.1-10.1j]
746 _bad_values = [u'10L', u'-10L', 'ten', [10], {'ten': 10},(10,), None]
746 _bad_values = [u'10L', u'-10L', 'ten', [10], {'ten': 10},(10,), None]
747 if not py3compat.PY3:
747 if not py3compat.PY3:
748 _bad_values.extend([10L, -10L])
748 _bad_values.extend([10L, -10L])
749
749
750
750
751 class BytesTrait(HasTraits):
751 class BytesTrait(HasTraits):
752
752
753 value = Bytes(b'string')
753 value = Bytes(b'string')
754
754
755 class TestBytes(TraitTestBase):
755 class TestBytes(TraitTestBase):
756
756
757 obj = BytesTrait()
757 obj = BytesTrait()
758
758
759 _default_value = b'string'
759 _default_value = b'string'
760 _good_values = [b'10', b'-10', b'10L',
760 _good_values = [b'10', b'-10', b'10L',
761 b'-10L', b'10.1', b'-10.1', b'string']
761 b'-10L', b'10.1', b'-10.1', b'string']
762 _bad_values = [10, -10, 10L, -10L, 10.1, -10.1, 1j, [10],
762 _bad_values = [10, -10, 10L, -10L, 10.1, -10.1, 1j, [10],
763 ['ten'],{'ten': 10},(10,), None, u'string']
763 ['ten'],{'ten': 10},(10,), None, u'string']
764
764
765
765
766 class UnicodeTrait(HasTraits):
766 class UnicodeTrait(HasTraits):
767
767
768 value = Unicode(u'unicode')
768 value = Unicode(u'unicode')
769
769
770 class TestUnicode(TraitTestBase):
770 class TestUnicode(TraitTestBase):
771
771
772 obj = UnicodeTrait()
772 obj = UnicodeTrait()
773
773
774 _default_value = u'unicode'
774 _default_value = u'unicode'
775 _good_values = ['10', '-10', '10L', '-10L', '10.1',
775 _good_values = ['10', '-10', '10L', '-10L', '10.1',
776 '-10.1', '', u'', 'string', u'string', u"€"]
776 '-10.1', '', u'', 'string', u'string', u"€"]
777 _bad_values = [10, -10, 10L, -10L, 10.1, -10.1, 1j,
777 _bad_values = [10, -10, 10L, -10L, 10.1, -10.1, 1j,
778 [10], ['ten'], [u'ten'], {'ten': 10},(10,), None]
778 [10], ['ten'], [u'ten'], {'ten': 10},(10,), None]
779
779
780
780
781 class ObjectNameTrait(HasTraits):
781 class ObjectNameTrait(HasTraits):
782 value = ObjectName("abc")
782 value = ObjectName("abc")
783
783
784 class TestObjectName(TraitTestBase):
784 class TestObjectName(TraitTestBase):
785 obj = ObjectNameTrait()
785 obj = ObjectNameTrait()
786
786
787 _default_value = "abc"
787 _default_value = "abc"
788 _good_values = ["a", "gh", "g9", "g_", "_G", u"a345_"]
788 _good_values = ["a", "gh", "g9", "g_", "_G", u"a345_"]
789 _bad_values = [1, "", u"€", "9g", "!", "#abc", "aj@", "a.b", "a()", "a[0]",
789 _bad_values = [1, "", u"€", "9g", "!", "#abc", "aj@", "a.b", "a()", "a[0]",
790 object(), object]
790 object(), object]
791 if sys.version_info[0] < 3:
791 if sys.version_info[0] < 3:
792 _bad_values.append(u"þ")
792 _bad_values.append(u"þ")
793 else:
793 else:
794 _good_values.append(u"þ") # þ=1 is valid in Python 3 (PEP 3131).
794 _good_values.append(u"þ") # þ=1 is valid in Python 3 (PEP 3131).
795
795
796
796
797 class DottedObjectNameTrait(HasTraits):
797 class DottedObjectNameTrait(HasTraits):
798 value = DottedObjectName("a.b")
798 value = DottedObjectName("a.b")
799
799
800 class TestDottedObjectName(TraitTestBase):
800 class TestDottedObjectName(TraitTestBase):
801 obj = DottedObjectNameTrait()
801 obj = DottedObjectNameTrait()
802
802
803 _default_value = "a.b"
803 _default_value = "a.b"
804 _good_values = ["A", "y.t", "y765.__repr__", "os.path.join", u"os.path.join"]
804 _good_values = ["A", "y.t", "y765.__repr__", "os.path.join", u"os.path.join"]
805 _bad_values = [1, u"abc.€", "_.@", ".", ".abc", "abc.", ".abc."]
805 _bad_values = [1, u"abc.€", "_.@", ".", ".abc", "abc.", ".abc."]
806 if sys.version_info[0] < 3:
806 if sys.version_info[0] < 3:
807 _bad_values.append(u"t.þ")
807 _bad_values.append(u"t.þ")
808 else:
808 else:
809 _good_values.append(u"t.þ")
809 _good_values.append(u"t.þ")
810
810
811
811
812 class TCPAddressTrait(HasTraits):
812 class TCPAddressTrait(HasTraits):
813
813
814 value = TCPAddress()
814 value = TCPAddress()
815
815
816 class TestTCPAddress(TraitTestBase):
816 class TestTCPAddress(TraitTestBase):
817
817
818 obj = TCPAddressTrait()
818 obj = TCPAddressTrait()
819
819
820 _default_value = ('127.0.0.1',0)
820 _default_value = ('127.0.0.1',0)
821 _good_values = [('localhost',0),('192.168.0.1',1000),('www.google.com',80)]
821 _good_values = [('localhost',0),('192.168.0.1',1000),('www.google.com',80)]
822 _bad_values = [(0,0),('localhost',10.0),('localhost',-1)]
822 _bad_values = [(0,0),('localhost',10.0),('localhost',-1)]
823
823
824 class ListTrait(HasTraits):
824 class ListTrait(HasTraits):
825
825
826 value = List(Int)
826 value = List(Int)
827
827
828 class TestList(TraitTestBase):
828 class TestList(TraitTestBase):
829
829
830 obj = ListTrait()
830 obj = ListTrait()
831
831
832 _default_value = []
832 _default_value = []
833 _good_values = [[], [1], range(10)]
833 _good_values = [[], [1], range(10)]
834 _bad_values = [10, [1,'a'], 'a', (1,2)]
834 _bad_values = [10, [1,'a'], 'a', (1,2)]
835
835
836 class LenListTrait(HasTraits):
836 class LenListTrait(HasTraits):
837
837
838 value = List(Int, [0], minlen=1, maxlen=2)
838 value = List(Int, [0], minlen=1, maxlen=2)
839
839
840 class TestLenList(TraitTestBase):
840 class TestLenList(TraitTestBase):
841
841
842 obj = LenListTrait()
842 obj = LenListTrait()
843
843
844 _default_value = [0]
844 _default_value = [0]
845 _good_values = [[1], range(2)]
845 _good_values = [[1], range(2)]
846 _bad_values = [10, [1,'a'], 'a', (1,2), [], range(3)]
846 _bad_values = [10, [1,'a'], 'a', (1,2), [], range(3)]
847
847
848 class TupleTrait(HasTraits):
848 class TupleTrait(HasTraits):
849
849
850 value = Tuple(Int)
850 value = Tuple(Int)
851
851
852 class TestTupleTrait(TraitTestBase):
852 class TestTupleTrait(TraitTestBase):
853
853
854 obj = TupleTrait()
854 obj = TupleTrait()
855
855
856 _default_value = None
856 _default_value = None
857 _good_values = [(1,), None,(0,)]
857 _good_values = [(1,), None,(0,)]
858 _bad_values = [10, (1,2), [1],('a'), ()]
858 _bad_values = [10, (1,2), [1],('a'), ()]
859
859
860 def test_invalid_args(self):
860 def test_invalid_args(self):
861 self.assertRaises(TypeError, Tuple, 5)
861 self.assertRaises(TypeError, Tuple, 5)
862 self.assertRaises(TypeError, Tuple, default_value='hello')
862 self.assertRaises(TypeError, Tuple, default_value='hello')
863 t = Tuple(Int, CBytes, default_value=(1,5))
863 t = Tuple(Int, CBytes, default_value=(1,5))
864
864
865 class LooseTupleTrait(HasTraits):
865 class LooseTupleTrait(HasTraits):
866
866
867 value = Tuple((1,2,3))
867 value = Tuple((1,2,3))
868
868
869 class TestLooseTupleTrait(TraitTestBase):
869 class TestLooseTupleTrait(TraitTestBase):
870
870
871 obj = LooseTupleTrait()
871 obj = LooseTupleTrait()
872
872
873 _default_value = (1,2,3)
873 _default_value = (1,2,3)
874 _good_values = [(1,), None, (0,), tuple(range(5)), tuple('hello'), ('a',5), ()]
874 _good_values = [(1,), None, (0,), tuple(range(5)), tuple('hello'), ('a',5), ()]
875 _bad_values = [10, 'hello', [1], []]
875 _bad_values = [10, 'hello', [1], []]
876
876
877 def test_invalid_args(self):
877 def test_invalid_args(self):
878 self.assertRaises(TypeError, Tuple, 5)
878 self.assertRaises(TypeError, Tuple, 5)
879 self.assertRaises(TypeError, Tuple, default_value='hello')
879 self.assertRaises(TypeError, Tuple, default_value='hello')
880 t = Tuple(Int, CBytes, default_value=(1,5))
880 t = Tuple(Int, CBytes, default_value=(1,5))
881
881
882
882
883 class MultiTupleTrait(HasTraits):
883 class MultiTupleTrait(HasTraits):
884
884
885 value = Tuple(Int, Bytes, default_value=[99,b'bottles'])
885 value = Tuple(Int, Bytes, default_value=[99,b'bottles'])
886
886
887 class TestMultiTuple(TraitTestBase):
887 class TestMultiTuple(TraitTestBase):
888
888
889 obj = MultiTupleTrait()
889 obj = MultiTupleTrait()
890
890
891 _default_value = (99,b'bottles')
891 _default_value = (99,b'bottles')
892 _good_values = [(1,b'a'), (2,b'b')]
892 _good_values = [(1,b'a'), (2,b'b')]
893 _bad_values = ((),10, b'a', (1,b'a',3), (b'a',1), (1, u'a'))
893 _bad_values = ((),10, b'a', (1,b'a',3), (b'a',1), (1, u'a'))
894
894
895 class CRegExpTrait(HasTraits):
895 class CRegExpTrait(HasTraits):
896
896
897 value = CRegExp(r'')
897 value = CRegExp(r'')
898
898
899 class TestCRegExp(TraitTestBase):
899 class TestCRegExp(TraitTestBase):
900
900
901 def coerce(self, value):
901 def coerce(self, value):
902 return re.compile(value)
902 return re.compile(value)
903
903
904 obj = CRegExpTrait()
904 obj = CRegExpTrait()
905
905
906 _default_value = re.compile(r'')
906 _default_value = re.compile(r'')
907 _good_values = [r'\d+', re.compile(r'\d+')]
907 _good_values = [r'\d+', re.compile(r'\d+')]
908 _bad_values = [r'(', None, ()]
908 _bad_values = [r'(', None, ()]
General Comments 0
You need to be logged in to leave comments. Login now