##// END OF EJS Templates
s/assertEquals/assertEqual/
Bradley M. Froehle -
Show More
@@ -1,175 +1,175 b''
1 """
1 """
2 Tests for IPython.config.application.Application
2 Tests for IPython.config.application.Application
3
3
4 Authors:
4 Authors:
5
5
6 * Brian Granger
6 * Brian Granger
7 """
7 """
8
8
9 #-----------------------------------------------------------------------------
9 #-----------------------------------------------------------------------------
10 # Copyright (C) 2008-2011 The IPython Development Team
10 # Copyright (C) 2008-2011 The IPython Development Team
11 #
11 #
12 # Distributed under the terms of the BSD License. The full license is in
12 # Distributed under the terms of the BSD License. The full license is in
13 # the file COPYING, distributed as part of this software.
13 # the file COPYING, distributed as part of this software.
14 #-----------------------------------------------------------------------------
14 #-----------------------------------------------------------------------------
15
15
16 #-----------------------------------------------------------------------------
16 #-----------------------------------------------------------------------------
17 # Imports
17 # Imports
18 #-----------------------------------------------------------------------------
18 #-----------------------------------------------------------------------------
19
19
20 import logging
20 import logging
21 from unittest import TestCase
21 from unittest import TestCase
22
22
23 from IPython.config.configurable import Configurable
23 from IPython.config.configurable import Configurable
24 from IPython.config.loader import Config
24 from IPython.config.loader import Config
25
25
26 from IPython.config.application import (
26 from IPython.config.application import (
27 Application
27 Application
28 )
28 )
29
29
30 from IPython.utils.traitlets import (
30 from IPython.utils.traitlets import (
31 Bool, Unicode, Integer, Float, List, Dict
31 Bool, Unicode, Integer, Float, List, Dict
32 )
32 )
33
33
34 #-----------------------------------------------------------------------------
34 #-----------------------------------------------------------------------------
35 # Code
35 # Code
36 #-----------------------------------------------------------------------------
36 #-----------------------------------------------------------------------------
37
37
38 class Foo(Configurable):
38 class Foo(Configurable):
39
39
40 i = Integer(0, config=True, help="The integer i.")
40 i = Integer(0, config=True, help="The integer i.")
41 j = Integer(1, config=True, help="The integer j.")
41 j = Integer(1, config=True, help="The integer j.")
42 name = Unicode(u'Brian', config=True, help="First name.")
42 name = Unicode(u'Brian', config=True, help="First name.")
43
43
44
44
45 class Bar(Configurable):
45 class Bar(Configurable):
46
46
47 b = Integer(0, config=True, help="The integer b.")
47 b = Integer(0, config=True, help="The integer b.")
48 enabled = Bool(True, config=True, help="Enable bar.")
48 enabled = Bool(True, config=True, help="Enable bar.")
49
49
50
50
51 class MyApp(Application):
51 class MyApp(Application):
52
52
53 name = Unicode(u'myapp')
53 name = Unicode(u'myapp')
54 running = Bool(False, config=True,
54 running = Bool(False, config=True,
55 help="Is the app running?")
55 help="Is the app running?")
56 classes = List([Bar, Foo])
56 classes = List([Bar, Foo])
57 config_file = Unicode(u'', config=True,
57 config_file = Unicode(u'', config=True,
58 help="Load this config file")
58 help="Load this config file")
59
59
60 aliases = Dict({
60 aliases = Dict({
61 'i' : 'Foo.i',
61 'i' : 'Foo.i',
62 'j' : 'Foo.j',
62 'j' : 'Foo.j',
63 'name' : 'Foo.name',
63 'name' : 'Foo.name',
64 'enabled' : 'Bar.enabled',
64 'enabled' : 'Bar.enabled',
65 'log-level' : 'Application.log_level',
65 'log-level' : 'Application.log_level',
66 })
66 })
67
67
68 flags = Dict(dict(enable=({'Bar': {'enabled' : True}}, "Set Bar.enabled to True"),
68 flags = Dict(dict(enable=({'Bar': {'enabled' : True}}, "Set Bar.enabled to True"),
69 disable=({'Bar': {'enabled' : False}}, "Set Bar.enabled to False"),
69 disable=({'Bar': {'enabled' : False}}, "Set Bar.enabled to False"),
70 crit=({'Application' : {'log_level' : logging.CRITICAL}},
70 crit=({'Application' : {'log_level' : logging.CRITICAL}},
71 "set level=CRITICAL"),
71 "set level=CRITICAL"),
72 ))
72 ))
73
73
74 def init_foo(self):
74 def init_foo(self):
75 self.foo = Foo(config=self.config)
75 self.foo = Foo(config=self.config)
76
76
77 def init_bar(self):
77 def init_bar(self):
78 self.bar = Bar(config=self.config)
78 self.bar = Bar(config=self.config)
79
79
80
80
81 class TestApplication(TestCase):
81 class TestApplication(TestCase):
82
82
83 def test_basic(self):
83 def test_basic(self):
84 app = MyApp()
84 app = MyApp()
85 self.assertEquals(app.name, u'myapp')
85 self.assertEqual(app.name, u'myapp')
86 self.assertEquals(app.running, False)
86 self.assertEqual(app.running, False)
87 self.assertEquals(app.classes, [MyApp,Bar,Foo])
87 self.assertEqual(app.classes, [MyApp,Bar,Foo])
88 self.assertEquals(app.config_file, u'')
88 self.assertEqual(app.config_file, u'')
89
89
90 def test_config(self):
90 def test_config(self):
91 app = MyApp()
91 app = MyApp()
92 app.parse_command_line(["--i=10","--Foo.j=10","--enabled=False","--log-level=50"])
92 app.parse_command_line(["--i=10","--Foo.j=10","--enabled=False","--log-level=50"])
93 config = app.config
93 config = app.config
94 self.assertEquals(config.Foo.i, 10)
94 self.assertEqual(config.Foo.i, 10)
95 self.assertEquals(config.Foo.j, 10)
95 self.assertEqual(config.Foo.j, 10)
96 self.assertEquals(config.Bar.enabled, False)
96 self.assertEqual(config.Bar.enabled, False)
97 self.assertEquals(config.MyApp.log_level,50)
97 self.assertEqual(config.MyApp.log_level,50)
98
98
99 def test_config_propagation(self):
99 def test_config_propagation(self):
100 app = MyApp()
100 app = MyApp()
101 app.parse_command_line(["--i=10","--Foo.j=10","--enabled=False","--log-level=50"])
101 app.parse_command_line(["--i=10","--Foo.j=10","--enabled=False","--log-level=50"])
102 app.init_foo()
102 app.init_foo()
103 app.init_bar()
103 app.init_bar()
104 self.assertEquals(app.foo.i, 10)
104 self.assertEqual(app.foo.i, 10)
105 self.assertEquals(app.foo.j, 10)
105 self.assertEqual(app.foo.j, 10)
106 self.assertEquals(app.bar.enabled, False)
106 self.assertEqual(app.bar.enabled, False)
107
107
108 def test_flags(self):
108 def test_flags(self):
109 app = MyApp()
109 app = MyApp()
110 app.parse_command_line(["--disable"])
110 app.parse_command_line(["--disable"])
111 app.init_bar()
111 app.init_bar()
112 self.assertEquals(app.bar.enabled, False)
112 self.assertEqual(app.bar.enabled, False)
113 app.parse_command_line(["--enable"])
113 app.parse_command_line(["--enable"])
114 app.init_bar()
114 app.init_bar()
115 self.assertEquals(app.bar.enabled, True)
115 self.assertEqual(app.bar.enabled, True)
116
116
117 def test_aliases(self):
117 def test_aliases(self):
118 app = MyApp()
118 app = MyApp()
119 app.parse_command_line(["--i=5", "--j=10"])
119 app.parse_command_line(["--i=5", "--j=10"])
120 app.init_foo()
120 app.init_foo()
121 self.assertEquals(app.foo.i, 5)
121 self.assertEqual(app.foo.i, 5)
122 app.init_foo()
122 app.init_foo()
123 self.assertEquals(app.foo.j, 10)
123 self.assertEqual(app.foo.j, 10)
124
124
125 def test_flag_clobber(self):
125 def test_flag_clobber(self):
126 """test that setting flags doesn't clobber existing settings"""
126 """test that setting flags doesn't clobber existing settings"""
127 app = MyApp()
127 app = MyApp()
128 app.parse_command_line(["--Bar.b=5", "--disable"])
128 app.parse_command_line(["--Bar.b=5", "--disable"])
129 app.init_bar()
129 app.init_bar()
130 self.assertEquals(app.bar.enabled, False)
130 self.assertEqual(app.bar.enabled, False)
131 self.assertEquals(app.bar.b, 5)
131 self.assertEqual(app.bar.b, 5)
132 app.parse_command_line(["--enable", "--Bar.b=10"])
132 app.parse_command_line(["--enable", "--Bar.b=10"])
133 app.init_bar()
133 app.init_bar()
134 self.assertEquals(app.bar.enabled, True)
134 self.assertEqual(app.bar.enabled, True)
135 self.assertEquals(app.bar.b, 10)
135 self.assertEqual(app.bar.b, 10)
136
136
137 def test_flatten_flags(self):
137 def test_flatten_flags(self):
138 cfg = Config()
138 cfg = Config()
139 cfg.MyApp.log_level = logging.WARN
139 cfg.MyApp.log_level = logging.WARN
140 app = MyApp()
140 app = MyApp()
141 app.update_config(cfg)
141 app.update_config(cfg)
142 self.assertEquals(app.log_level, logging.WARN)
142 self.assertEqual(app.log_level, logging.WARN)
143 self.assertEquals(app.config.MyApp.log_level, logging.WARN)
143 self.assertEqual(app.config.MyApp.log_level, logging.WARN)
144 app.initialize(["--crit"])
144 app.initialize(["--crit"])
145 self.assertEquals(app.log_level, logging.CRITICAL)
145 self.assertEqual(app.log_level, logging.CRITICAL)
146 # this would be app.config.Application.log_level if it failed:
146 # this would be app.config.Application.log_level if it failed:
147 self.assertEquals(app.config.MyApp.log_level, logging.CRITICAL)
147 self.assertEqual(app.config.MyApp.log_level, logging.CRITICAL)
148
148
149 def test_flatten_aliases(self):
149 def test_flatten_aliases(self):
150 cfg = Config()
150 cfg = Config()
151 cfg.MyApp.log_level = logging.WARN
151 cfg.MyApp.log_level = logging.WARN
152 app = MyApp()
152 app = MyApp()
153 app.update_config(cfg)
153 app.update_config(cfg)
154 self.assertEquals(app.log_level, logging.WARN)
154 self.assertEqual(app.log_level, logging.WARN)
155 self.assertEquals(app.config.MyApp.log_level, logging.WARN)
155 self.assertEqual(app.config.MyApp.log_level, logging.WARN)
156 app.initialize(["--log-level", "CRITICAL"])
156 app.initialize(["--log-level", "CRITICAL"])
157 self.assertEquals(app.log_level, logging.CRITICAL)
157 self.assertEqual(app.log_level, logging.CRITICAL)
158 # this would be app.config.Application.log_level if it failed:
158 # this would be app.config.Application.log_level if it failed:
159 self.assertEquals(app.config.MyApp.log_level, "CRITICAL")
159 self.assertEqual(app.config.MyApp.log_level, "CRITICAL")
160
160
161 def test_extra_args(self):
161 def test_extra_args(self):
162 app = MyApp()
162 app = MyApp()
163 app.parse_command_line(["--Bar.b=5", 'extra', "--disable", 'args'])
163 app.parse_command_line(["--Bar.b=5", 'extra', "--disable", 'args'])
164 app.init_bar()
164 app.init_bar()
165 self.assertEquals(app.bar.enabled, False)
165 self.assertEqual(app.bar.enabled, False)
166 self.assertEquals(app.bar.b, 5)
166 self.assertEqual(app.bar.b, 5)
167 self.assertEquals(app.extra_args, ['extra', 'args'])
167 self.assertEqual(app.extra_args, ['extra', 'args'])
168 app = MyApp()
168 app = MyApp()
169 app.parse_command_line(["--Bar.b=5", '--', 'extra', "--disable", 'args'])
169 app.parse_command_line(["--Bar.b=5", '--', 'extra', "--disable", 'args'])
170 app.init_bar()
170 app.init_bar()
171 self.assertEquals(app.bar.enabled, True)
171 self.assertEqual(app.bar.enabled, True)
172 self.assertEquals(app.bar.b, 5)
172 self.assertEqual(app.bar.b, 5)
173 self.assertEquals(app.extra_args, ['extra', '--disable', 'args'])
173 self.assertEqual(app.extra_args, ['extra', '--disable', 'args'])
174
174
175
175
@@ -1,183 +1,183 b''
1 # encoding: utf-8
1 # encoding: utf-8
2 """
2 """
3 Tests for IPython.config.configurable
3 Tests for IPython.config.configurable
4
4
5 Authors:
5 Authors:
6
6
7 * Brian Granger
7 * Brian Granger
8 * Fernando Perez (design help)
8 * Fernando Perez (design help)
9 """
9 """
10
10
11 #-----------------------------------------------------------------------------
11 #-----------------------------------------------------------------------------
12 # Copyright (C) 2008-2011 The IPython Development Team
12 # Copyright (C) 2008-2011 The IPython Development Team
13 #
13 #
14 # Distributed under the terms of the BSD License. The full license is in
14 # Distributed under the terms of the BSD License. The full license is in
15 # the file COPYING, distributed as part of this software.
15 # the file COPYING, distributed as part of this software.
16 #-----------------------------------------------------------------------------
16 #-----------------------------------------------------------------------------
17
17
18 #-----------------------------------------------------------------------------
18 #-----------------------------------------------------------------------------
19 # Imports
19 # Imports
20 #-----------------------------------------------------------------------------
20 #-----------------------------------------------------------------------------
21
21
22 from unittest import TestCase
22 from unittest import TestCase
23
23
24 from IPython.config.configurable import (
24 from IPython.config.configurable import (
25 Configurable,
25 Configurable,
26 SingletonConfigurable
26 SingletonConfigurable
27 )
27 )
28
28
29 from IPython.utils.traitlets import (
29 from IPython.utils.traitlets import (
30 Integer, Float, Unicode
30 Integer, Float, Unicode
31 )
31 )
32
32
33 from IPython.config.loader import Config
33 from IPython.config.loader import Config
34 from IPython.utils.py3compat import PY3
34 from IPython.utils.py3compat import PY3
35
35
36 #-----------------------------------------------------------------------------
36 #-----------------------------------------------------------------------------
37 # Test cases
37 # Test cases
38 #-----------------------------------------------------------------------------
38 #-----------------------------------------------------------------------------
39
39
40
40
41 class MyConfigurable(Configurable):
41 class MyConfigurable(Configurable):
42 a = Integer(1, config=True, help="The integer a.")
42 a = Integer(1, config=True, help="The integer a.")
43 b = Float(1.0, config=True, help="The integer b.")
43 b = Float(1.0, config=True, help="The integer b.")
44 c = Unicode('no config')
44 c = Unicode('no config')
45
45
46
46
47 mc_help=u"""MyConfigurable options
47 mc_help=u"""MyConfigurable options
48 ----------------------
48 ----------------------
49 --MyConfigurable.a=<Integer>
49 --MyConfigurable.a=<Integer>
50 Default: 1
50 Default: 1
51 The integer a.
51 The integer a.
52 --MyConfigurable.b=<Float>
52 --MyConfigurable.b=<Float>
53 Default: 1.0
53 Default: 1.0
54 The integer b."""
54 The integer b."""
55
55
56 mc_help_inst=u"""MyConfigurable options
56 mc_help_inst=u"""MyConfigurable options
57 ----------------------
57 ----------------------
58 --MyConfigurable.a=<Integer>
58 --MyConfigurable.a=<Integer>
59 Current: 5
59 Current: 5
60 The integer a.
60 The integer a.
61 --MyConfigurable.b=<Float>
61 --MyConfigurable.b=<Float>
62 Current: 4.0
62 Current: 4.0
63 The integer b."""
63 The integer b."""
64
64
65 # On Python 3, the Integer trait is a synonym for Int
65 # On Python 3, the Integer trait is a synonym for Int
66 if PY3:
66 if PY3:
67 mc_help = mc_help.replace(u"<Integer>", u"<Int>")
67 mc_help = mc_help.replace(u"<Integer>", u"<Int>")
68 mc_help_inst = mc_help_inst.replace(u"<Integer>", u"<Int>")
68 mc_help_inst = mc_help_inst.replace(u"<Integer>", u"<Int>")
69
69
70 class Foo(Configurable):
70 class Foo(Configurable):
71 a = Integer(0, config=True, help="The integer a.")
71 a = Integer(0, config=True, help="The integer a.")
72 b = Unicode('nope', config=True)
72 b = Unicode('nope', config=True)
73
73
74
74
75 class Bar(Foo):
75 class Bar(Foo):
76 b = Unicode('gotit', config=False, help="The string b.")
76 b = Unicode('gotit', config=False, help="The string b.")
77 c = Float(config=True, help="The string c.")
77 c = Float(config=True, help="The string c.")
78
78
79
79
80 class TestConfigurable(TestCase):
80 class TestConfigurable(TestCase):
81
81
82 def test_default(self):
82 def test_default(self):
83 c1 = Configurable()
83 c1 = Configurable()
84 c2 = Configurable(config=c1.config)
84 c2 = Configurable(config=c1.config)
85 c3 = Configurable(config=c2.config)
85 c3 = Configurable(config=c2.config)
86 self.assertEquals(c1.config, c2.config)
86 self.assertEqual(c1.config, c2.config)
87 self.assertEquals(c2.config, c3.config)
87 self.assertEqual(c2.config, c3.config)
88
88
89 def test_custom(self):
89 def test_custom(self):
90 config = Config()
90 config = Config()
91 config.foo = 'foo'
91 config.foo = 'foo'
92 config.bar = 'bar'
92 config.bar = 'bar'
93 c1 = Configurable(config=config)
93 c1 = Configurable(config=config)
94 c2 = Configurable(config=c1.config)
94 c2 = Configurable(config=c1.config)
95 c3 = Configurable(config=c2.config)
95 c3 = Configurable(config=c2.config)
96 self.assertEquals(c1.config, config)
96 self.assertEqual(c1.config, config)
97 self.assertEquals(c2.config, config)
97 self.assertEqual(c2.config, config)
98 self.assertEquals(c3.config, config)
98 self.assertEqual(c3.config, config)
99 # Test that copies are not made
99 # Test that copies are not made
100 self.assert_(c1.config is config)
100 self.assert_(c1.config is config)
101 self.assert_(c2.config is config)
101 self.assert_(c2.config is config)
102 self.assert_(c3.config is config)
102 self.assert_(c3.config is config)
103 self.assert_(c1.config is c2.config)
103 self.assert_(c1.config is c2.config)
104 self.assert_(c2.config is c3.config)
104 self.assert_(c2.config is c3.config)
105
105
106 def test_inheritance(self):
106 def test_inheritance(self):
107 config = Config()
107 config = Config()
108 config.MyConfigurable.a = 2
108 config.MyConfigurable.a = 2
109 config.MyConfigurable.b = 2.0
109 config.MyConfigurable.b = 2.0
110 c1 = MyConfigurable(config=config)
110 c1 = MyConfigurable(config=config)
111 c2 = MyConfigurable(config=c1.config)
111 c2 = MyConfigurable(config=c1.config)
112 self.assertEquals(c1.a, config.MyConfigurable.a)
112 self.assertEqual(c1.a, config.MyConfigurable.a)
113 self.assertEquals(c1.b, config.MyConfigurable.b)
113 self.assertEqual(c1.b, config.MyConfigurable.b)
114 self.assertEquals(c2.a, config.MyConfigurable.a)
114 self.assertEqual(c2.a, config.MyConfigurable.a)
115 self.assertEquals(c2.b, config.MyConfigurable.b)
115 self.assertEqual(c2.b, config.MyConfigurable.b)
116
116
117 def test_parent(self):
117 def test_parent(self):
118 config = Config()
118 config = Config()
119 config.Foo.a = 10
119 config.Foo.a = 10
120 config.Foo.b = "wow"
120 config.Foo.b = "wow"
121 config.Bar.b = 'later'
121 config.Bar.b = 'later'
122 config.Bar.c = 100.0
122 config.Bar.c = 100.0
123 f = Foo(config=config)
123 f = Foo(config=config)
124 b = Bar(config=f.config)
124 b = Bar(config=f.config)
125 self.assertEquals(f.a, 10)
125 self.assertEqual(f.a, 10)
126 self.assertEquals(f.b, 'wow')
126 self.assertEqual(f.b, 'wow')
127 self.assertEquals(b.b, 'gotit')
127 self.assertEqual(b.b, 'gotit')
128 self.assertEquals(b.c, 100.0)
128 self.assertEqual(b.c, 100.0)
129
129
130 def test_override1(self):
130 def test_override1(self):
131 config = Config()
131 config = Config()
132 config.MyConfigurable.a = 2
132 config.MyConfigurable.a = 2
133 config.MyConfigurable.b = 2.0
133 config.MyConfigurable.b = 2.0
134 c = MyConfigurable(a=3, config=config)
134 c = MyConfigurable(a=3, config=config)
135 self.assertEquals(c.a, 3)
135 self.assertEqual(c.a, 3)
136 self.assertEquals(c.b, config.MyConfigurable.b)
136 self.assertEqual(c.b, config.MyConfigurable.b)
137 self.assertEquals(c.c, 'no config')
137 self.assertEqual(c.c, 'no config')
138
138
139 def test_override2(self):
139 def test_override2(self):
140 config = Config()
140 config = Config()
141 config.Foo.a = 1
141 config.Foo.a = 1
142 config.Bar.b = 'or' # Up above b is config=False, so this won't do it.
142 config.Bar.b = 'or' # Up above b is config=False, so this won't do it.
143 config.Bar.c = 10.0
143 config.Bar.c = 10.0
144 c = Bar(config=config)
144 c = Bar(config=config)
145 self.assertEquals(c.a, config.Foo.a)
145 self.assertEqual(c.a, config.Foo.a)
146 self.assertEquals(c.b, 'gotit')
146 self.assertEqual(c.b, 'gotit')
147 self.assertEquals(c.c, config.Bar.c)
147 self.assertEqual(c.c, config.Bar.c)
148 c = Bar(a=2, b='and', c=20.0, config=config)
148 c = Bar(a=2, b='and', c=20.0, config=config)
149 self.assertEquals(c.a, 2)
149 self.assertEqual(c.a, 2)
150 self.assertEquals(c.b, 'and')
150 self.assertEqual(c.b, 'and')
151 self.assertEquals(c.c, 20.0)
151 self.assertEqual(c.c, 20.0)
152
152
153 def test_help(self):
153 def test_help(self):
154 self.assertEquals(MyConfigurable.class_get_help(), mc_help)
154 self.assertEqual(MyConfigurable.class_get_help(), mc_help)
155
155
156 def test_help_inst(self):
156 def test_help_inst(self):
157 inst = MyConfigurable(a=5, b=4)
157 inst = MyConfigurable(a=5, b=4)
158 self.assertEquals(MyConfigurable.class_get_help(inst), mc_help_inst)
158 self.assertEqual(MyConfigurable.class_get_help(inst), mc_help_inst)
159
159
160
160
161 class TestSingletonConfigurable(TestCase):
161 class TestSingletonConfigurable(TestCase):
162
162
163 def test_instance(self):
163 def test_instance(self):
164 from IPython.config.configurable import SingletonConfigurable
164 from IPython.config.configurable import SingletonConfigurable
165 class Foo(SingletonConfigurable): pass
165 class Foo(SingletonConfigurable): pass
166 self.assertEquals(Foo.initialized(), False)
166 self.assertEqual(Foo.initialized(), False)
167 foo = Foo.instance()
167 foo = Foo.instance()
168 self.assertEquals(Foo.initialized(), True)
168 self.assertEqual(Foo.initialized(), True)
169 self.assertEquals(foo, Foo.instance())
169 self.assertEqual(foo, Foo.instance())
170 self.assertEquals(SingletonConfigurable._instance, None)
170 self.assertEqual(SingletonConfigurable._instance, None)
171
171
172 def test_inheritance(self):
172 def test_inheritance(self):
173 class Bar(SingletonConfigurable): pass
173 class Bar(SingletonConfigurable): pass
174 class Bam(Bar): pass
174 class Bam(Bar): pass
175 self.assertEquals(Bar.initialized(), False)
175 self.assertEqual(Bar.initialized(), False)
176 self.assertEquals(Bam.initialized(), False)
176 self.assertEqual(Bam.initialized(), False)
177 bam = Bam.instance()
177 bam = Bam.instance()
178 bam == Bar.instance()
178 bam == Bar.instance()
179 self.assertEquals(Bar.initialized(), True)
179 self.assertEqual(Bar.initialized(), True)
180 self.assertEquals(Bam.initialized(), True)
180 self.assertEqual(Bam.initialized(), True)
181 self.assertEquals(bam, Bam._instance)
181 self.assertEqual(bam, Bam._instance)
182 self.assertEquals(bam, Bar._instance)
182 self.assertEqual(bam, Bar._instance)
183 self.assertEquals(SingletonConfigurable._instance, None)
183 self.assertEqual(SingletonConfigurable._instance, None)
@@ -1,263 +1,263 b''
1 # encoding: utf-8
1 # encoding: utf-8
2 """
2 """
3 Tests for IPython.config.loader
3 Tests for IPython.config.loader
4
4
5 Authors:
5 Authors:
6
6
7 * Brian Granger
7 * Brian Granger
8 * Fernando Perez (design help)
8 * Fernando Perez (design help)
9 """
9 """
10
10
11 #-----------------------------------------------------------------------------
11 #-----------------------------------------------------------------------------
12 # Copyright (C) 2008-2011 The IPython Development Team
12 # Copyright (C) 2008-2011 The IPython Development Team
13 #
13 #
14 # Distributed under the terms of the BSD License. The full license is in
14 # Distributed under the terms of the BSD License. The full license is in
15 # the file COPYING, distributed as part of this software.
15 # the file COPYING, distributed as part of this software.
16 #-----------------------------------------------------------------------------
16 #-----------------------------------------------------------------------------
17
17
18 #-----------------------------------------------------------------------------
18 #-----------------------------------------------------------------------------
19 # Imports
19 # Imports
20 #-----------------------------------------------------------------------------
20 #-----------------------------------------------------------------------------
21
21
22 import os
22 import os
23 import sys
23 import sys
24 from tempfile import mkstemp
24 from tempfile import mkstemp
25 from unittest import TestCase
25 from unittest import TestCase
26
26
27 from nose import SkipTest
27 from nose import SkipTest
28
28
29 from IPython.testing.tools import mute_warn
29 from IPython.testing.tools import mute_warn
30
30
31 from IPython.utils.traitlets import Unicode
31 from IPython.utils.traitlets import Unicode
32 from IPython.config.configurable import Configurable
32 from IPython.config.configurable import Configurable
33 from IPython.config.loader import (
33 from IPython.config.loader import (
34 Config,
34 Config,
35 PyFileConfigLoader,
35 PyFileConfigLoader,
36 KeyValueConfigLoader,
36 KeyValueConfigLoader,
37 ArgParseConfigLoader,
37 ArgParseConfigLoader,
38 KVArgParseConfigLoader,
38 KVArgParseConfigLoader,
39 ConfigError
39 ConfigError
40 )
40 )
41
41
42 #-----------------------------------------------------------------------------
42 #-----------------------------------------------------------------------------
43 # Actual tests
43 # Actual tests
44 #-----------------------------------------------------------------------------
44 #-----------------------------------------------------------------------------
45
45
46
46
47 pyfile = """
47 pyfile = """
48 c = get_config()
48 c = get_config()
49 c.a=10
49 c.a=10
50 c.b=20
50 c.b=20
51 c.Foo.Bar.value=10
51 c.Foo.Bar.value=10
52 c.Foo.Bam.value=list(range(10)) # list() is just so it's the same on Python 3
52 c.Foo.Bam.value=list(range(10)) # list() is just so it's the same on Python 3
53 c.D.C.value='hi there'
53 c.D.C.value='hi there'
54 """
54 """
55
55
56 class TestPyFileCL(TestCase):
56 class TestPyFileCL(TestCase):
57
57
58 def test_basic(self):
58 def test_basic(self):
59 fd, fname = mkstemp('.py')
59 fd, fname = mkstemp('.py')
60 f = os.fdopen(fd, 'w')
60 f = os.fdopen(fd, 'w')
61 f.write(pyfile)
61 f.write(pyfile)
62 f.close()
62 f.close()
63 # Unlink the file
63 # Unlink the file
64 cl = PyFileConfigLoader(fname)
64 cl = PyFileConfigLoader(fname)
65 config = cl.load_config()
65 config = cl.load_config()
66 self.assertEquals(config.a, 10)
66 self.assertEqual(config.a, 10)
67 self.assertEquals(config.b, 20)
67 self.assertEqual(config.b, 20)
68 self.assertEquals(config.Foo.Bar.value, 10)
68 self.assertEqual(config.Foo.Bar.value, 10)
69 self.assertEquals(config.Foo.Bam.value, range(10))
69 self.assertEqual(config.Foo.Bam.value, range(10))
70 self.assertEquals(config.D.C.value, 'hi there')
70 self.assertEqual(config.D.C.value, 'hi there')
71
71
72 class MyLoader1(ArgParseConfigLoader):
72 class MyLoader1(ArgParseConfigLoader):
73 def _add_arguments(self, aliases=None, flags=None):
73 def _add_arguments(self, aliases=None, flags=None):
74 p = self.parser
74 p = self.parser
75 p.add_argument('-f', '--foo', dest='Global.foo', type=str)
75 p.add_argument('-f', '--foo', dest='Global.foo', type=str)
76 p.add_argument('-b', dest='MyClass.bar', type=int)
76 p.add_argument('-b', dest='MyClass.bar', type=int)
77 p.add_argument('-n', dest='n', action='store_true')
77 p.add_argument('-n', dest='n', action='store_true')
78 p.add_argument('Global.bam', type=str)
78 p.add_argument('Global.bam', type=str)
79
79
80 class MyLoader2(ArgParseConfigLoader):
80 class MyLoader2(ArgParseConfigLoader):
81 def _add_arguments(self, aliases=None, flags=None):
81 def _add_arguments(self, aliases=None, flags=None):
82 subparsers = self.parser.add_subparsers(dest='subparser_name')
82 subparsers = self.parser.add_subparsers(dest='subparser_name')
83 subparser1 = subparsers.add_parser('1')
83 subparser1 = subparsers.add_parser('1')
84 subparser1.add_argument('-x',dest='Global.x')
84 subparser1.add_argument('-x',dest='Global.x')
85 subparser2 = subparsers.add_parser('2')
85 subparser2 = subparsers.add_parser('2')
86 subparser2.add_argument('y')
86 subparser2.add_argument('y')
87
87
88 class TestArgParseCL(TestCase):
88 class TestArgParseCL(TestCase):
89
89
90 def test_basic(self):
90 def test_basic(self):
91 cl = MyLoader1()
91 cl = MyLoader1()
92 config = cl.load_config('-f hi -b 10 -n wow'.split())
92 config = cl.load_config('-f hi -b 10 -n wow'.split())
93 self.assertEquals(config.Global.foo, 'hi')
93 self.assertEqual(config.Global.foo, 'hi')
94 self.assertEquals(config.MyClass.bar, 10)
94 self.assertEqual(config.MyClass.bar, 10)
95 self.assertEquals(config.n, True)
95 self.assertEqual(config.n, True)
96 self.assertEquals(config.Global.bam, 'wow')
96 self.assertEqual(config.Global.bam, 'wow')
97 config = cl.load_config(['wow'])
97 config = cl.load_config(['wow'])
98 self.assertEquals(config.keys(), ['Global'])
98 self.assertEqual(config.keys(), ['Global'])
99 self.assertEquals(config.Global.keys(), ['bam'])
99 self.assertEqual(config.Global.keys(), ['bam'])
100 self.assertEquals(config.Global.bam, 'wow')
100 self.assertEqual(config.Global.bam, 'wow')
101
101
102 def test_add_arguments(self):
102 def test_add_arguments(self):
103 cl = MyLoader2()
103 cl = MyLoader2()
104 config = cl.load_config('2 frobble'.split())
104 config = cl.load_config('2 frobble'.split())
105 self.assertEquals(config.subparser_name, '2')
105 self.assertEqual(config.subparser_name, '2')
106 self.assertEquals(config.y, 'frobble')
106 self.assertEqual(config.y, 'frobble')
107 config = cl.load_config('1 -x frobble'.split())
107 config = cl.load_config('1 -x frobble'.split())
108 self.assertEquals(config.subparser_name, '1')
108 self.assertEqual(config.subparser_name, '1')
109 self.assertEquals(config.Global.x, 'frobble')
109 self.assertEqual(config.Global.x, 'frobble')
110
110
111 def test_argv(self):
111 def test_argv(self):
112 cl = MyLoader1(argv='-f hi -b 10 -n wow'.split())
112 cl = MyLoader1(argv='-f hi -b 10 -n wow'.split())
113 config = cl.load_config()
113 config = cl.load_config()
114 self.assertEquals(config.Global.foo, 'hi')
114 self.assertEqual(config.Global.foo, 'hi')
115 self.assertEquals(config.MyClass.bar, 10)
115 self.assertEqual(config.MyClass.bar, 10)
116 self.assertEquals(config.n, True)
116 self.assertEqual(config.n, True)
117 self.assertEquals(config.Global.bam, 'wow')
117 self.assertEqual(config.Global.bam, 'wow')
118
118
119
119
120 class TestKeyValueCL(TestCase):
120 class TestKeyValueCL(TestCase):
121 klass = KeyValueConfigLoader
121 klass = KeyValueConfigLoader
122
122
123 def test_basic(self):
123 def test_basic(self):
124 cl = self.klass()
124 cl = self.klass()
125 argv = ['--'+s.strip('c.') for s in pyfile.split('\n')[2:-1]]
125 argv = ['--'+s.strip('c.') for s in pyfile.split('\n')[2:-1]]
126 with mute_warn():
126 with mute_warn():
127 config = cl.load_config(argv)
127 config = cl.load_config(argv)
128 self.assertEquals(config.a, 10)
128 self.assertEqual(config.a, 10)
129 self.assertEquals(config.b, 20)
129 self.assertEqual(config.b, 20)
130 self.assertEquals(config.Foo.Bar.value, 10)
130 self.assertEqual(config.Foo.Bar.value, 10)
131 self.assertEquals(config.Foo.Bam.value, range(10))
131 self.assertEqual(config.Foo.Bam.value, range(10))
132 self.assertEquals(config.D.C.value, 'hi there')
132 self.assertEqual(config.D.C.value, 'hi there')
133
133
134 def test_expanduser(self):
134 def test_expanduser(self):
135 cl = self.klass()
135 cl = self.klass()
136 argv = ['--a=~/1/2/3', '--b=~', '--c=~/', '--d="~/"']
136 argv = ['--a=~/1/2/3', '--b=~', '--c=~/', '--d="~/"']
137 with mute_warn():
137 with mute_warn():
138 config = cl.load_config(argv)
138 config = cl.load_config(argv)
139 self.assertEquals(config.a, os.path.expanduser('~/1/2/3'))
139 self.assertEqual(config.a, os.path.expanduser('~/1/2/3'))
140 self.assertEquals(config.b, os.path.expanduser('~'))
140 self.assertEqual(config.b, os.path.expanduser('~'))
141 self.assertEquals(config.c, os.path.expanduser('~/'))
141 self.assertEqual(config.c, os.path.expanduser('~/'))
142 self.assertEquals(config.d, '~/')
142 self.assertEqual(config.d, '~/')
143
143
144 def test_extra_args(self):
144 def test_extra_args(self):
145 cl = self.klass()
145 cl = self.klass()
146 with mute_warn():
146 with mute_warn():
147 config = cl.load_config(['--a=5', 'b', '--c=10', 'd'])
147 config = cl.load_config(['--a=5', 'b', '--c=10', 'd'])
148 self.assertEquals(cl.extra_args, ['b', 'd'])
148 self.assertEqual(cl.extra_args, ['b', 'd'])
149 self.assertEquals(config.a, 5)
149 self.assertEqual(config.a, 5)
150 self.assertEquals(config.c, 10)
150 self.assertEqual(config.c, 10)
151 with mute_warn():
151 with mute_warn():
152 config = cl.load_config(['--', '--a=5', '--c=10'])
152 config = cl.load_config(['--', '--a=5', '--c=10'])
153 self.assertEquals(cl.extra_args, ['--a=5', '--c=10'])
153 self.assertEqual(cl.extra_args, ['--a=5', '--c=10'])
154
154
155 def test_unicode_args(self):
155 def test_unicode_args(self):
156 cl = self.klass()
156 cl = self.klass()
157 argv = [u'--a=épsîlön']
157 argv = [u'--a=épsîlön']
158 with mute_warn():
158 with mute_warn():
159 config = cl.load_config(argv)
159 config = cl.load_config(argv)
160 self.assertEquals(config.a, u'épsîlön')
160 self.assertEqual(config.a, u'épsîlön')
161
161
162 def test_unicode_bytes_args(self):
162 def test_unicode_bytes_args(self):
163 uarg = u'--a=é'
163 uarg = u'--a=é'
164 try:
164 try:
165 barg = uarg.encode(sys.stdin.encoding)
165 barg = uarg.encode(sys.stdin.encoding)
166 except (TypeError, UnicodeEncodeError):
166 except (TypeError, UnicodeEncodeError):
167 raise SkipTest("sys.stdin.encoding can't handle 'é'")
167 raise SkipTest("sys.stdin.encoding can't handle 'é'")
168
168
169 cl = self.klass()
169 cl = self.klass()
170 with mute_warn():
170 with mute_warn():
171 config = cl.load_config([barg])
171 config = cl.load_config([barg])
172 self.assertEquals(config.a, u'é')
172 self.assertEqual(config.a, u'é')
173
173
174 def test_unicode_alias(self):
174 def test_unicode_alias(self):
175 cl = self.klass()
175 cl = self.klass()
176 argv = [u'--a=épsîlön']
176 argv = [u'--a=épsîlön']
177 with mute_warn():
177 with mute_warn():
178 config = cl.load_config(argv, aliases=dict(a='A.a'))
178 config = cl.load_config(argv, aliases=dict(a='A.a'))
179 self.assertEquals(config.A.a, u'épsîlön')
179 self.assertEqual(config.A.a, u'épsîlön')
180
180
181
181
182 class TestArgParseKVCL(TestKeyValueCL):
182 class TestArgParseKVCL(TestKeyValueCL):
183 klass = KVArgParseConfigLoader
183 klass = KVArgParseConfigLoader
184
184
185 def test_expanduser2(self):
185 def test_expanduser2(self):
186 cl = self.klass()
186 cl = self.klass()
187 argv = ['-a', '~/1/2/3', '--b', "'~/1/2/3'"]
187 argv = ['-a', '~/1/2/3', '--b', "'~/1/2/3'"]
188 with mute_warn():
188 with mute_warn():
189 config = cl.load_config(argv, aliases=dict(a='A.a', b='A.b'))
189 config = cl.load_config(argv, aliases=dict(a='A.a', b='A.b'))
190 self.assertEquals(config.A.a, os.path.expanduser('~/1/2/3'))
190 self.assertEqual(config.A.a, os.path.expanduser('~/1/2/3'))
191 self.assertEquals(config.A.b, '~/1/2/3')
191 self.assertEqual(config.A.b, '~/1/2/3')
192
192
193 def test_eval(self):
193 def test_eval(self):
194 cl = self.klass()
194 cl = self.klass()
195 argv = ['-c', 'a=5']
195 argv = ['-c', 'a=5']
196 with mute_warn():
196 with mute_warn():
197 config = cl.load_config(argv, aliases=dict(c='A.c'))
197 config = cl.load_config(argv, aliases=dict(c='A.c'))
198 self.assertEquals(config.A.c, u"a=5")
198 self.assertEqual(config.A.c, u"a=5")
199
199
200
200
201 class TestConfig(TestCase):
201 class TestConfig(TestCase):
202
202
203 def test_setget(self):
203 def test_setget(self):
204 c = Config()
204 c = Config()
205 c.a = 10
205 c.a = 10
206 self.assertEquals(c.a, 10)
206 self.assertEqual(c.a, 10)
207 self.assertEquals('b' in c, False)
207 self.assertEqual('b' in c, False)
208
208
209 def test_auto_section(self):
209 def test_auto_section(self):
210 c = Config()
210 c = Config()
211 self.assertEquals('A' in c, True)
211 self.assertEqual('A' in c, True)
212 self.assertEquals(c._has_section('A'), False)
212 self.assertEqual(c._has_section('A'), False)
213 A = c.A
213 A = c.A
214 A.foo = 'hi there'
214 A.foo = 'hi there'
215 self.assertEquals(c._has_section('A'), True)
215 self.assertEqual(c._has_section('A'), True)
216 self.assertEquals(c.A.foo, 'hi there')
216 self.assertEqual(c.A.foo, 'hi there')
217 del c.A
217 del c.A
218 self.assertEquals(len(c.A.keys()),0)
218 self.assertEqual(len(c.A.keys()),0)
219
219
220 def test_merge_doesnt_exist(self):
220 def test_merge_doesnt_exist(self):
221 c1 = Config()
221 c1 = Config()
222 c2 = Config()
222 c2 = Config()
223 c2.bar = 10
223 c2.bar = 10
224 c2.Foo.bar = 10
224 c2.Foo.bar = 10
225 c1._merge(c2)
225 c1._merge(c2)
226 self.assertEquals(c1.Foo.bar, 10)
226 self.assertEqual(c1.Foo.bar, 10)
227 self.assertEquals(c1.bar, 10)
227 self.assertEqual(c1.bar, 10)
228 c2.Bar.bar = 10
228 c2.Bar.bar = 10
229 c1._merge(c2)
229 c1._merge(c2)
230 self.assertEquals(c1.Bar.bar, 10)
230 self.assertEqual(c1.Bar.bar, 10)
231
231
232 def test_merge_exists(self):
232 def test_merge_exists(self):
233 c1 = Config()
233 c1 = Config()
234 c2 = Config()
234 c2 = Config()
235 c1.Foo.bar = 10
235 c1.Foo.bar = 10
236 c1.Foo.bam = 30
236 c1.Foo.bam = 30
237 c2.Foo.bar = 20
237 c2.Foo.bar = 20
238 c2.Foo.wow = 40
238 c2.Foo.wow = 40
239 c1._merge(c2)
239 c1._merge(c2)
240 self.assertEquals(c1.Foo.bam, 30)
240 self.assertEqual(c1.Foo.bam, 30)
241 self.assertEquals(c1.Foo.bar, 20)
241 self.assertEqual(c1.Foo.bar, 20)
242 self.assertEquals(c1.Foo.wow, 40)
242 self.assertEqual(c1.Foo.wow, 40)
243 c2.Foo.Bam.bam = 10
243 c2.Foo.Bam.bam = 10
244 c1._merge(c2)
244 c1._merge(c2)
245 self.assertEquals(c1.Foo.Bam.bam, 10)
245 self.assertEqual(c1.Foo.Bam.bam, 10)
246
246
247 def test_deepcopy(self):
247 def test_deepcopy(self):
248 c1 = Config()
248 c1 = Config()
249 c1.Foo.bar = 10
249 c1.Foo.bar = 10
250 c1.Foo.bam = 30
250 c1.Foo.bam = 30
251 c1.a = 'asdf'
251 c1.a = 'asdf'
252 c1.b = range(10)
252 c1.b = range(10)
253 import copy
253 import copy
254 c2 = copy.deepcopy(c1)
254 c2 = copy.deepcopy(c1)
255 self.assertEquals(c1, c2)
255 self.assertEqual(c1, c2)
256 self.assert_(c1 is not c2)
256 self.assert_(c1 is not c2)
257 self.assert_(c1.Foo is not c2.Foo)
257 self.assert_(c1.Foo is not c2.Foo)
258
258
259 def test_builtin(self):
259 def test_builtin(self):
260 c1 = Config()
260 c1 = Config()
261 exec 'foo = True' in c1
261 exec 'foo = True' in c1
262 self.assertEquals(c1.foo, True)
262 self.assertEqual(c1.foo, True)
263 self.assertRaises(ConfigError, setattr, c1, 'ValueError', 10)
263 self.assertRaises(ConfigError, setattr, c1, 'ValueError', 10)
@@ -1,407 +1,407 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2 """Tests for the key interactiveshell module.
2 """Tests for the key interactiveshell module.
3
3
4 Historically the main classes in interactiveshell have been under-tested. This
4 Historically the main classes in interactiveshell have been under-tested. This
5 module should grow as many single-method tests as possible to trap many of the
5 module should grow as many single-method tests as possible to trap many of the
6 recurring bugs we seem to encounter with high-level interaction.
6 recurring bugs we seem to encounter with high-level interaction.
7
7
8 Authors
8 Authors
9 -------
9 -------
10 * Fernando Perez
10 * Fernando Perez
11 """
11 """
12 #-----------------------------------------------------------------------------
12 #-----------------------------------------------------------------------------
13 # Copyright (C) 2011 The IPython Development Team
13 # Copyright (C) 2011 The IPython Development Team
14 #
14 #
15 # Distributed under the terms of the BSD License. The full license is in
15 # Distributed under the terms of the BSD License. The full license is in
16 # the file COPYING, distributed as part of this software.
16 # the file COPYING, distributed as part of this software.
17 #-----------------------------------------------------------------------------
17 #-----------------------------------------------------------------------------
18
18
19 #-----------------------------------------------------------------------------
19 #-----------------------------------------------------------------------------
20 # Imports
20 # Imports
21 #-----------------------------------------------------------------------------
21 #-----------------------------------------------------------------------------
22 # stdlib
22 # stdlib
23 import os
23 import os
24 import shutil
24 import shutil
25 import sys
25 import sys
26 import tempfile
26 import tempfile
27 import unittest
27 import unittest
28 from os.path import join
28 from os.path import join
29 from StringIO import StringIO
29 from StringIO import StringIO
30
30
31 # third-party
31 # third-party
32 import nose.tools as nt
32 import nose.tools as nt
33
33
34 # Our own
34 # Our own
35 from IPython.testing.decorators import skipif
35 from IPython.testing.decorators import skipif
36 from IPython.utils import io
36 from IPython.utils import io
37
37
38 #-----------------------------------------------------------------------------
38 #-----------------------------------------------------------------------------
39 # Globals
39 # Globals
40 #-----------------------------------------------------------------------------
40 #-----------------------------------------------------------------------------
41 # This is used by every single test, no point repeating it ad nauseam
41 # This is used by every single test, no point repeating it ad nauseam
42 ip = get_ipython()
42 ip = get_ipython()
43
43
44 #-----------------------------------------------------------------------------
44 #-----------------------------------------------------------------------------
45 # Tests
45 # Tests
46 #-----------------------------------------------------------------------------
46 #-----------------------------------------------------------------------------
47
47
48 class InteractiveShellTestCase(unittest.TestCase):
48 class InteractiveShellTestCase(unittest.TestCase):
49 def test_naked_string_cells(self):
49 def test_naked_string_cells(self):
50 """Test that cells with only naked strings are fully executed"""
50 """Test that cells with only naked strings are fully executed"""
51 # First, single-line inputs
51 # First, single-line inputs
52 ip.run_cell('"a"\n')
52 ip.run_cell('"a"\n')
53 self.assertEquals(ip.user_ns['_'], 'a')
53 self.assertEqual(ip.user_ns['_'], 'a')
54 # And also multi-line cells
54 # And also multi-line cells
55 ip.run_cell('"""a\nb"""\n')
55 ip.run_cell('"""a\nb"""\n')
56 self.assertEquals(ip.user_ns['_'], 'a\nb')
56 self.assertEqual(ip.user_ns['_'], 'a\nb')
57
57
58 def test_run_empty_cell(self):
58 def test_run_empty_cell(self):
59 """Just make sure we don't get a horrible error with a blank
59 """Just make sure we don't get a horrible error with a blank
60 cell of input. Yes, I did overlook that."""
60 cell of input. Yes, I did overlook that."""
61 old_xc = ip.execution_count
61 old_xc = ip.execution_count
62 ip.run_cell('')
62 ip.run_cell('')
63 self.assertEquals(ip.execution_count, old_xc)
63 self.assertEqual(ip.execution_count, old_xc)
64
64
65 def test_run_cell_multiline(self):
65 def test_run_cell_multiline(self):
66 """Multi-block, multi-line cells must execute correctly.
66 """Multi-block, multi-line cells must execute correctly.
67 """
67 """
68 src = '\n'.join(["x=1",
68 src = '\n'.join(["x=1",
69 "y=2",
69 "y=2",
70 "if 1:",
70 "if 1:",
71 " x += 1",
71 " x += 1",
72 " y += 1",])
72 " y += 1",])
73 ip.run_cell(src)
73 ip.run_cell(src)
74 self.assertEquals(ip.user_ns['x'], 2)
74 self.assertEqual(ip.user_ns['x'], 2)
75 self.assertEquals(ip.user_ns['y'], 3)
75 self.assertEqual(ip.user_ns['y'], 3)
76
76
77 def test_multiline_string_cells(self):
77 def test_multiline_string_cells(self):
78 "Code sprinkled with multiline strings should execute (GH-306)"
78 "Code sprinkled with multiline strings should execute (GH-306)"
79 ip.run_cell('tmp=0')
79 ip.run_cell('tmp=0')
80 self.assertEquals(ip.user_ns['tmp'], 0)
80 self.assertEqual(ip.user_ns['tmp'], 0)
81 ip.run_cell('tmp=1;"""a\nb"""\n')
81 ip.run_cell('tmp=1;"""a\nb"""\n')
82 self.assertEquals(ip.user_ns['tmp'], 1)
82 self.assertEqual(ip.user_ns['tmp'], 1)
83
83
84 def test_dont_cache_with_semicolon(self):
84 def test_dont_cache_with_semicolon(self):
85 "Ending a line with semicolon should not cache the returned object (GH-307)"
85 "Ending a line with semicolon should not cache the returned object (GH-307)"
86 oldlen = len(ip.user_ns['Out'])
86 oldlen = len(ip.user_ns['Out'])
87 a = ip.run_cell('1;', store_history=True)
87 a = ip.run_cell('1;', store_history=True)
88 newlen = len(ip.user_ns['Out'])
88 newlen = len(ip.user_ns['Out'])
89 self.assertEquals(oldlen, newlen)
89 self.assertEqual(oldlen, newlen)
90 #also test the default caching behavior
90 #also test the default caching behavior
91 ip.run_cell('1', store_history=True)
91 ip.run_cell('1', store_history=True)
92 newlen = len(ip.user_ns['Out'])
92 newlen = len(ip.user_ns['Out'])
93 self.assertEquals(oldlen+1, newlen)
93 self.assertEqual(oldlen+1, newlen)
94
94
95 def test_In_variable(self):
95 def test_In_variable(self):
96 "Verify that In variable grows with user input (GH-284)"
96 "Verify that In variable grows with user input (GH-284)"
97 oldlen = len(ip.user_ns['In'])
97 oldlen = len(ip.user_ns['In'])
98 ip.run_cell('1;', store_history=True)
98 ip.run_cell('1;', store_history=True)
99 newlen = len(ip.user_ns['In'])
99 newlen = len(ip.user_ns['In'])
100 self.assertEquals(oldlen+1, newlen)
100 self.assertEqual(oldlen+1, newlen)
101 self.assertEquals(ip.user_ns['In'][-1],'1;')
101 self.assertEqual(ip.user_ns['In'][-1],'1;')
102
102
103 def test_magic_names_in_string(self):
103 def test_magic_names_in_string(self):
104 ip.run_cell('a = """\n%exit\n"""')
104 ip.run_cell('a = """\n%exit\n"""')
105 self.assertEquals(ip.user_ns['a'], '\n%exit\n')
105 self.assertEqual(ip.user_ns['a'], '\n%exit\n')
106
106
107 def test_alias_crash(self):
107 def test_alias_crash(self):
108 """Errors in prefilter can't crash IPython"""
108 """Errors in prefilter can't crash IPython"""
109 ip.run_cell('%alias parts echo first %s second %s')
109 ip.run_cell('%alias parts echo first %s second %s')
110 # capture stderr:
110 # capture stderr:
111 save_err = io.stderr
111 save_err = io.stderr
112 io.stderr = StringIO()
112 io.stderr = StringIO()
113 ip.run_cell('parts 1')
113 ip.run_cell('parts 1')
114 err = io.stderr.getvalue()
114 err = io.stderr.getvalue()
115 io.stderr = save_err
115 io.stderr = save_err
116 self.assertEquals(err.split(':')[0], 'ERROR')
116 self.assertEqual(err.split(':')[0], 'ERROR')
117
117
118 def test_trailing_newline(self):
118 def test_trailing_newline(self):
119 """test that running !(command) does not raise a SyntaxError"""
119 """test that running !(command) does not raise a SyntaxError"""
120 ip.run_cell('!(true)\n', False)
120 ip.run_cell('!(true)\n', False)
121 ip.run_cell('!(true)\n\n\n', False)
121 ip.run_cell('!(true)\n\n\n', False)
122
122
123 def test_gh_597(self):
123 def test_gh_597(self):
124 """Pretty-printing lists of objects with non-ascii reprs may cause
124 """Pretty-printing lists of objects with non-ascii reprs may cause
125 problems."""
125 problems."""
126 class Spam(object):
126 class Spam(object):
127 def __repr__(self):
127 def __repr__(self):
128 return "\xe9"*50
128 return "\xe9"*50
129 import IPython.core.formatters
129 import IPython.core.formatters
130 f = IPython.core.formatters.PlainTextFormatter()
130 f = IPython.core.formatters.PlainTextFormatter()
131 f([Spam(),Spam()])
131 f([Spam(),Spam()])
132
132
133
133
134 def test_future_flags(self):
134 def test_future_flags(self):
135 """Check that future flags are used for parsing code (gh-777)"""
135 """Check that future flags are used for parsing code (gh-777)"""
136 ip.run_cell('from __future__ import print_function')
136 ip.run_cell('from __future__ import print_function')
137 try:
137 try:
138 ip.run_cell('prfunc_return_val = print(1,2, sep=" ")')
138 ip.run_cell('prfunc_return_val = print(1,2, sep=" ")')
139 assert 'prfunc_return_val' in ip.user_ns
139 assert 'prfunc_return_val' in ip.user_ns
140 finally:
140 finally:
141 # Reset compiler flags so we don't mess up other tests.
141 # Reset compiler flags so we don't mess up other tests.
142 ip.compile.reset_compiler_flags()
142 ip.compile.reset_compiler_flags()
143
143
144 def test_future_unicode(self):
144 def test_future_unicode(self):
145 """Check that unicode_literals is imported from __future__ (gh #786)"""
145 """Check that unicode_literals is imported from __future__ (gh #786)"""
146 try:
146 try:
147 ip.run_cell(u'byte_str = "a"')
147 ip.run_cell(u'byte_str = "a"')
148 assert isinstance(ip.user_ns['byte_str'], str) # string literals are byte strings by default
148 assert isinstance(ip.user_ns['byte_str'], str) # string literals are byte strings by default
149 ip.run_cell('from __future__ import unicode_literals')
149 ip.run_cell('from __future__ import unicode_literals')
150 ip.run_cell(u'unicode_str = "a"')
150 ip.run_cell(u'unicode_str = "a"')
151 assert isinstance(ip.user_ns['unicode_str'], unicode) # strings literals are now unicode
151 assert isinstance(ip.user_ns['unicode_str'], unicode) # strings literals are now unicode
152 finally:
152 finally:
153 # Reset compiler flags so we don't mess up other tests.
153 # Reset compiler flags so we don't mess up other tests.
154 ip.compile.reset_compiler_flags()
154 ip.compile.reset_compiler_flags()
155
155
156 def test_can_pickle(self):
156 def test_can_pickle(self):
157 "Can we pickle objects defined interactively (GH-29)"
157 "Can we pickle objects defined interactively (GH-29)"
158 ip = get_ipython()
158 ip = get_ipython()
159 ip.reset()
159 ip.reset()
160 ip.run_cell(("class Mylist(list):\n"
160 ip.run_cell(("class Mylist(list):\n"
161 " def __init__(self,x=[]):\n"
161 " def __init__(self,x=[]):\n"
162 " list.__init__(self,x)"))
162 " list.__init__(self,x)"))
163 ip.run_cell("w=Mylist([1,2,3])")
163 ip.run_cell("w=Mylist([1,2,3])")
164
164
165 from cPickle import dumps
165 from cPickle import dumps
166
166
167 # We need to swap in our main module - this is only necessary
167 # We need to swap in our main module - this is only necessary
168 # inside the test framework, because IPython puts the interactive module
168 # inside the test framework, because IPython puts the interactive module
169 # in place (but the test framework undoes this).
169 # in place (but the test framework undoes this).
170 _main = sys.modules['__main__']
170 _main = sys.modules['__main__']
171 sys.modules['__main__'] = ip.user_module
171 sys.modules['__main__'] = ip.user_module
172 try:
172 try:
173 res = dumps(ip.user_ns["w"])
173 res = dumps(ip.user_ns["w"])
174 finally:
174 finally:
175 sys.modules['__main__'] = _main
175 sys.modules['__main__'] = _main
176 self.assertTrue(isinstance(res, bytes))
176 self.assertTrue(isinstance(res, bytes))
177
177
178 def test_global_ns(self):
178 def test_global_ns(self):
179 "Code in functions must be able to access variables outside them."
179 "Code in functions must be able to access variables outside them."
180 ip = get_ipython()
180 ip = get_ipython()
181 ip.run_cell("a = 10")
181 ip.run_cell("a = 10")
182 ip.run_cell(("def f(x):\n"
182 ip.run_cell(("def f(x):\n"
183 " return x + a"))
183 " return x + a"))
184 ip.run_cell("b = f(12)")
184 ip.run_cell("b = f(12)")
185 self.assertEqual(ip.user_ns["b"], 22)
185 self.assertEqual(ip.user_ns["b"], 22)
186
186
187 def test_bad_custom_tb(self):
187 def test_bad_custom_tb(self):
188 """Check that InteractiveShell is protected from bad custom exception handlers"""
188 """Check that InteractiveShell is protected from bad custom exception handlers"""
189 from IPython.utils import io
189 from IPython.utils import io
190 save_stderr = io.stderr
190 save_stderr = io.stderr
191 try:
191 try:
192 # capture stderr
192 # capture stderr
193 io.stderr = StringIO()
193 io.stderr = StringIO()
194 ip.set_custom_exc((IOError,), lambda etype,value,tb: 1/0)
194 ip.set_custom_exc((IOError,), lambda etype,value,tb: 1/0)
195 self.assertEquals(ip.custom_exceptions, (IOError,))
195 self.assertEqual(ip.custom_exceptions, (IOError,))
196 ip.run_cell(u'raise IOError("foo")')
196 ip.run_cell(u'raise IOError("foo")')
197 self.assertEquals(ip.custom_exceptions, ())
197 self.assertEqual(ip.custom_exceptions, ())
198 self.assertTrue("Custom TB Handler failed" in io.stderr.getvalue())
198 self.assertTrue("Custom TB Handler failed" in io.stderr.getvalue())
199 finally:
199 finally:
200 io.stderr = save_stderr
200 io.stderr = save_stderr
201
201
202 def test_bad_custom_tb_return(self):
202 def test_bad_custom_tb_return(self):
203 """Check that InteractiveShell is protected from bad return types in custom exception handlers"""
203 """Check that InteractiveShell is protected from bad return types in custom exception handlers"""
204 from IPython.utils import io
204 from IPython.utils import io
205 save_stderr = io.stderr
205 save_stderr = io.stderr
206 try:
206 try:
207 # capture stderr
207 # capture stderr
208 io.stderr = StringIO()
208 io.stderr = StringIO()
209 ip.set_custom_exc((NameError,),lambda etype,value,tb, tb_offset=None: 1)
209 ip.set_custom_exc((NameError,),lambda etype,value,tb, tb_offset=None: 1)
210 self.assertEquals(ip.custom_exceptions, (NameError,))
210 self.assertEqual(ip.custom_exceptions, (NameError,))
211 ip.run_cell(u'a=abracadabra')
211 ip.run_cell(u'a=abracadabra')
212 self.assertEquals(ip.custom_exceptions, ())
212 self.assertEqual(ip.custom_exceptions, ())
213 self.assertTrue("Custom TB Handler failed" in io.stderr.getvalue())
213 self.assertTrue("Custom TB Handler failed" in io.stderr.getvalue())
214 finally:
214 finally:
215 io.stderr = save_stderr
215 io.stderr = save_stderr
216
216
217 def test_drop_by_id(self):
217 def test_drop_by_id(self):
218 myvars = {"a":object(), "b":object(), "c": object()}
218 myvars = {"a":object(), "b":object(), "c": object()}
219 ip.push(myvars, interactive=False)
219 ip.push(myvars, interactive=False)
220 for name in myvars:
220 for name in myvars:
221 assert name in ip.user_ns, name
221 assert name in ip.user_ns, name
222 assert name in ip.user_ns_hidden, name
222 assert name in ip.user_ns_hidden, name
223 ip.user_ns['b'] = 12
223 ip.user_ns['b'] = 12
224 ip.drop_by_id(myvars)
224 ip.drop_by_id(myvars)
225 for name in ["a", "c"]:
225 for name in ["a", "c"]:
226 assert name not in ip.user_ns, name
226 assert name not in ip.user_ns, name
227 assert name not in ip.user_ns_hidden, name
227 assert name not in ip.user_ns_hidden, name
228 assert ip.user_ns['b'] == 12
228 assert ip.user_ns['b'] == 12
229 ip.reset()
229 ip.reset()
230
230
231 def test_var_expand(self):
231 def test_var_expand(self):
232 ip.user_ns['f'] = u'Ca\xf1o'
232 ip.user_ns['f'] = u'Ca\xf1o'
233 self.assertEqual(ip.var_expand(u'echo $f'), u'echo Ca\xf1o')
233 self.assertEqual(ip.var_expand(u'echo $f'), u'echo Ca\xf1o')
234 self.assertEqual(ip.var_expand(u'echo {f}'), u'echo Ca\xf1o')
234 self.assertEqual(ip.var_expand(u'echo {f}'), u'echo Ca\xf1o')
235 self.assertEqual(ip.var_expand(u'echo {f[:-1]}'), u'echo Ca\xf1')
235 self.assertEqual(ip.var_expand(u'echo {f[:-1]}'), u'echo Ca\xf1')
236 self.assertEqual(ip.var_expand(u'echo {1*2}'), u'echo 2')
236 self.assertEqual(ip.var_expand(u'echo {1*2}'), u'echo 2')
237
237
238 ip.user_ns['f'] = b'Ca\xc3\xb1o'
238 ip.user_ns['f'] = b'Ca\xc3\xb1o'
239 # This should not raise any exception:
239 # This should not raise any exception:
240 ip.var_expand(u'echo $f')
240 ip.var_expand(u'echo $f')
241
241
242 def test_var_expand_local(self):
242 def test_var_expand_local(self):
243 """Test local variable expansion in !system and %magic calls"""
243 """Test local variable expansion in !system and %magic calls"""
244 # !system
244 # !system
245 ip.run_cell('def test():\n'
245 ip.run_cell('def test():\n'
246 ' lvar = "ttt"\n'
246 ' lvar = "ttt"\n'
247 ' ret = !echo {lvar}\n'
247 ' ret = !echo {lvar}\n'
248 ' return ret[0]\n')
248 ' return ret[0]\n')
249 res = ip.user_ns['test']()
249 res = ip.user_ns['test']()
250 nt.assert_in('ttt', res)
250 nt.assert_in('ttt', res)
251
251
252 # %magic
252 # %magic
253 ip.run_cell('def makemacro():\n'
253 ip.run_cell('def makemacro():\n'
254 ' macroname = "macro_var_expand_locals"\n'
254 ' macroname = "macro_var_expand_locals"\n'
255 ' %macro {macroname} codestr\n')
255 ' %macro {macroname} codestr\n')
256 ip.user_ns['codestr'] = "str(12)"
256 ip.user_ns['codestr'] = "str(12)"
257 ip.run_cell('makemacro()')
257 ip.run_cell('makemacro()')
258 nt.assert_in('macro_var_expand_locals', ip.user_ns)
258 nt.assert_in('macro_var_expand_locals', ip.user_ns)
259
259
260 def test_bad_var_expand(self):
260 def test_bad_var_expand(self):
261 """var_expand on invalid formats shouldn't raise"""
261 """var_expand on invalid formats shouldn't raise"""
262 # SyntaxError
262 # SyntaxError
263 self.assertEqual(ip.var_expand(u"{'a':5}"), u"{'a':5}")
263 self.assertEqual(ip.var_expand(u"{'a':5}"), u"{'a':5}")
264 # NameError
264 # NameError
265 self.assertEqual(ip.var_expand(u"{asdf}"), u"{asdf}")
265 self.assertEqual(ip.var_expand(u"{asdf}"), u"{asdf}")
266 # ZeroDivisionError
266 # ZeroDivisionError
267 self.assertEqual(ip.var_expand(u"{1/0}"), u"{1/0}")
267 self.assertEqual(ip.var_expand(u"{1/0}"), u"{1/0}")
268
268
269 def test_silent_nopostexec(self):
269 def test_silent_nopostexec(self):
270 """run_cell(silent=True) doesn't invoke post-exec funcs"""
270 """run_cell(silent=True) doesn't invoke post-exec funcs"""
271 d = dict(called=False)
271 d = dict(called=False)
272 def set_called():
272 def set_called():
273 d['called'] = True
273 d['called'] = True
274
274
275 ip.register_post_execute(set_called)
275 ip.register_post_execute(set_called)
276 ip.run_cell("1", silent=True)
276 ip.run_cell("1", silent=True)
277 self.assertFalse(d['called'])
277 self.assertFalse(d['called'])
278 # double-check that non-silent exec did what we expected
278 # double-check that non-silent exec did what we expected
279 # silent to avoid
279 # silent to avoid
280 ip.run_cell("1")
280 ip.run_cell("1")
281 self.assertTrue(d['called'])
281 self.assertTrue(d['called'])
282 # remove post-exec
282 # remove post-exec
283 ip._post_execute.pop(set_called)
283 ip._post_execute.pop(set_called)
284
284
285 def test_silent_noadvance(self):
285 def test_silent_noadvance(self):
286 """run_cell(silent=True) doesn't advance execution_count"""
286 """run_cell(silent=True) doesn't advance execution_count"""
287 ec = ip.execution_count
287 ec = ip.execution_count
288 # silent should force store_history=False
288 # silent should force store_history=False
289 ip.run_cell("1", store_history=True, silent=True)
289 ip.run_cell("1", store_history=True, silent=True)
290
290
291 self.assertEquals(ec, ip.execution_count)
291 self.assertEqual(ec, ip.execution_count)
292 # double-check that non-silent exec did what we expected
292 # double-check that non-silent exec did what we expected
293 # silent to avoid
293 # silent to avoid
294 ip.run_cell("1", store_history=True)
294 ip.run_cell("1", store_history=True)
295 self.assertEquals(ec+1, ip.execution_count)
295 self.assertEqual(ec+1, ip.execution_count)
296
296
297 def test_silent_nodisplayhook(self):
297 def test_silent_nodisplayhook(self):
298 """run_cell(silent=True) doesn't trigger displayhook"""
298 """run_cell(silent=True) doesn't trigger displayhook"""
299 d = dict(called=False)
299 d = dict(called=False)
300
300
301 trap = ip.display_trap
301 trap = ip.display_trap
302 save_hook = trap.hook
302 save_hook = trap.hook
303
303
304 def failing_hook(*args, **kwargs):
304 def failing_hook(*args, **kwargs):
305 d['called'] = True
305 d['called'] = True
306
306
307 try:
307 try:
308 trap.hook = failing_hook
308 trap.hook = failing_hook
309 ip.run_cell("1", silent=True)
309 ip.run_cell("1", silent=True)
310 self.assertFalse(d['called'])
310 self.assertFalse(d['called'])
311 # double-check that non-silent exec did what we expected
311 # double-check that non-silent exec did what we expected
312 # silent to avoid
312 # silent to avoid
313 ip.run_cell("1")
313 ip.run_cell("1")
314 self.assertTrue(d['called'])
314 self.assertTrue(d['called'])
315 finally:
315 finally:
316 trap.hook = save_hook
316 trap.hook = save_hook
317
317
318 @skipif(sys.version_info[0] >= 3, "softspace removed in py3")
318 @skipif(sys.version_info[0] >= 3, "softspace removed in py3")
319 def test_print_softspace(self):
319 def test_print_softspace(self):
320 """Verify that softspace is handled correctly when executing multiple
320 """Verify that softspace is handled correctly when executing multiple
321 statements.
321 statements.
322
322
323 In [1]: print 1; print 2
323 In [1]: print 1; print 2
324 1
324 1
325 2
325 2
326
326
327 In [2]: print 1,; print 2
327 In [2]: print 1,; print 2
328 1 2
328 1 2
329 """
329 """
330
330
331 def test_ofind_line_magic(self):
331 def test_ofind_line_magic(self):
332 from IPython.core.magic import register_line_magic
332 from IPython.core.magic import register_line_magic
333
333
334 @register_line_magic
334 @register_line_magic
335 def lmagic(line):
335 def lmagic(line):
336 "A line magic"
336 "A line magic"
337
337
338 # Get info on line magic
338 # Get info on line magic
339 lfind = ip._ofind('lmagic')
339 lfind = ip._ofind('lmagic')
340 info = dict(found=True, isalias=False, ismagic=True,
340 info = dict(found=True, isalias=False, ismagic=True,
341 namespace = 'IPython internal', obj= lmagic.__wrapped__,
341 namespace = 'IPython internal', obj= lmagic.__wrapped__,
342 parent = None)
342 parent = None)
343 nt.assert_equal(lfind, info)
343 nt.assert_equal(lfind, info)
344
344
345 def test_ofind_cell_magic(self):
345 def test_ofind_cell_magic(self):
346 from IPython.core.magic import register_cell_magic
346 from IPython.core.magic import register_cell_magic
347
347
348 @register_cell_magic
348 @register_cell_magic
349 def cmagic(line, cell):
349 def cmagic(line, cell):
350 "A cell magic"
350 "A cell magic"
351
351
352 # Get info on cell magic
352 # Get info on cell magic
353 find = ip._ofind('cmagic')
353 find = ip._ofind('cmagic')
354 info = dict(found=True, isalias=False, ismagic=True,
354 info = dict(found=True, isalias=False, ismagic=True,
355 namespace = 'IPython internal', obj= cmagic.__wrapped__,
355 namespace = 'IPython internal', obj= cmagic.__wrapped__,
356 parent = None)
356 parent = None)
357 nt.assert_equal(find, info)
357 nt.assert_equal(find, info)
358
358
359 def test_custom_exception(self):
359 def test_custom_exception(self):
360 called = []
360 called = []
361 def my_handler(shell, etype, value, tb, tb_offset=None):
361 def my_handler(shell, etype, value, tb, tb_offset=None):
362 called.append(etype)
362 called.append(etype)
363 shell.showtraceback((etype, value, tb), tb_offset=tb_offset)
363 shell.showtraceback((etype, value, tb), tb_offset=tb_offset)
364
364
365 ip.set_custom_exc((ValueError,), my_handler)
365 ip.set_custom_exc((ValueError,), my_handler)
366 try:
366 try:
367 ip.run_cell("raise ValueError('test')")
367 ip.run_cell("raise ValueError('test')")
368 # Check that this was called, and only once.
368 # Check that this was called, and only once.
369 self.assertEqual(called, [ValueError])
369 self.assertEqual(called, [ValueError])
370 finally:
370 finally:
371 # Reset the custom exception hook
371 # Reset the custom exception hook
372 ip.set_custom_exc((), None)
372 ip.set_custom_exc((), None)
373
373
374
374
375 class TestSafeExecfileNonAsciiPath(unittest.TestCase):
375 class TestSafeExecfileNonAsciiPath(unittest.TestCase):
376
376
377 def setUp(self):
377 def setUp(self):
378 self.BASETESTDIR = tempfile.mkdtemp()
378 self.BASETESTDIR = tempfile.mkdtemp()
379 self.TESTDIR = join(self.BASETESTDIR, u"åäö")
379 self.TESTDIR = join(self.BASETESTDIR, u"åäö")
380 os.mkdir(self.TESTDIR)
380 os.mkdir(self.TESTDIR)
381 with open(join(self.TESTDIR, u"åäötestscript.py"), "w") as sfile:
381 with open(join(self.TESTDIR, u"åäötestscript.py"), "w") as sfile:
382 sfile.write("pass\n")
382 sfile.write("pass\n")
383 self.oldpath = os.getcwdu()
383 self.oldpath = os.getcwdu()
384 os.chdir(self.TESTDIR)
384 os.chdir(self.TESTDIR)
385 self.fname = u"åäötestscript.py"
385 self.fname = u"åäötestscript.py"
386
386
387 def tearDown(self):
387 def tearDown(self):
388 os.chdir(self.oldpath)
388 os.chdir(self.oldpath)
389 shutil.rmtree(self.BASETESTDIR)
389 shutil.rmtree(self.BASETESTDIR)
390
390
391 def test_1(self):
391 def test_1(self):
392 """Test safe_execfile with non-ascii path
392 """Test safe_execfile with non-ascii path
393 """
393 """
394 ip.safe_execfile(self.fname, {}, raise_exceptions=True)
394 ip.safe_execfile(self.fname, {}, raise_exceptions=True)
395
395
396
396
397 class TestSystemRaw(unittest.TestCase):
397 class TestSystemRaw(unittest.TestCase):
398 def test_1(self):
398 def test_1(self):
399 """Test system_raw with non-ascii cmd
399 """Test system_raw with non-ascii cmd
400 """
400 """
401 cmd = ur'''python -c "'åäö'" '''
401 cmd = ur'''python -c "'åäö'" '''
402 ip.system_raw(cmd)
402 ip.system_raw(cmd)
403
403
404
404
405 def test__IPYTHON__():
405 def test__IPYTHON__():
406 # This shouldn't raise a NameError, that's all
406 # This shouldn't raise a NameError, that's all
407 __IPYTHON__
407 __IPYTHON__
@@ -1,46 +1,46 b''
1 """Tests for plugin.py"""
1 """Tests for plugin.py"""
2
2
3 #-----------------------------------------------------------------------------
3 #-----------------------------------------------------------------------------
4 # Imports
4 # Imports
5 #-----------------------------------------------------------------------------
5 #-----------------------------------------------------------------------------
6
6
7 from unittest import TestCase
7 from unittest import TestCase
8
8
9 from IPython.core.plugin import Plugin, PluginManager
9 from IPython.core.plugin import Plugin, PluginManager
10
10
11 #-----------------------------------------------------------------------------
11 #-----------------------------------------------------------------------------
12 # Tests
12 # Tests
13 #-----------------------------------------------------------------------------
13 #-----------------------------------------------------------------------------
14
14
15 class FooPlugin(Plugin):
15 class FooPlugin(Plugin):
16 pass
16 pass
17
17
18
18
19 class BarPlugin(Plugin):
19 class BarPlugin(Plugin):
20 pass
20 pass
21
21
22
22
23 class BadPlugin(object):
23 class BadPlugin(object):
24 pass
24 pass
25
25
26
26
27 class PluginTest(TestCase):
27 class PluginTest(TestCase):
28
28
29 def setUp(self):
29 def setUp(self):
30 self.manager = PluginManager()
30 self.manager = PluginManager()
31
31
32 def test_register_get(self):
32 def test_register_get(self):
33 self.assertEquals(None, self.manager.get_plugin('foo'))
33 self.assertEqual(None, self.manager.get_plugin('foo'))
34 foo = FooPlugin()
34 foo = FooPlugin()
35 self.manager.register_plugin('foo', foo)
35 self.manager.register_plugin('foo', foo)
36 self.assertEquals(foo, self.manager.get_plugin('foo'))
36 self.assertEqual(foo, self.manager.get_plugin('foo'))
37 bar = BarPlugin()
37 bar = BarPlugin()
38 self.assertRaises(KeyError, self.manager.register_plugin, 'foo', bar)
38 self.assertRaises(KeyError, self.manager.register_plugin, 'foo', bar)
39 bad = BadPlugin()
39 bad = BadPlugin()
40 self.assertRaises(TypeError, self.manager.register_plugin, 'bad')
40 self.assertRaises(TypeError, self.manager.register_plugin, 'bad')
41
41
42 def test_unregister(self):
42 def test_unregister(self):
43 foo = FooPlugin()
43 foo = FooPlugin()
44 self.manager.register_plugin('foo', foo)
44 self.manager.register_plugin('foo', foo)
45 self.manager.unregister_plugin('foo')
45 self.manager.unregister_plugin('foo')
46 self.assertEquals(None, self.manager.get_plugin('foo'))
46 self.assertEqual(None, self.manager.get_plugin('foo'))
@@ -1,111 +1,111 b''
1 # -*- coding: utf-8
1 # -*- coding: utf-8
2 """Tests for prompt generation."""
2 """Tests for prompt generation."""
3
3
4 import unittest
4 import unittest
5
5
6 import os
6 import os
7 import nose.tools as nt
7 import nose.tools as nt
8
8
9 from IPython.testing import tools as tt, decorators as dec
9 from IPython.testing import tools as tt, decorators as dec
10 from IPython.core.prompts import PromptManager, LazyEvaluate
10 from IPython.core.prompts import PromptManager, LazyEvaluate
11 from IPython.testing.globalipapp import get_ipython
11 from IPython.testing.globalipapp import get_ipython
12 from IPython.utils import py3compat
12 from IPython.utils import py3compat
13 from IPython.utils.tempdir import TemporaryDirectory
13 from IPython.utils.tempdir import TemporaryDirectory
14
14
15 ip = get_ipython()
15 ip = get_ipython()
16
16
17
17
18 class PromptTests(unittest.TestCase):
18 class PromptTests(unittest.TestCase):
19 def setUp(self):
19 def setUp(self):
20 self.pm = PromptManager(shell=ip, config=ip.config)
20 self.pm = PromptManager(shell=ip, config=ip.config)
21
21
22 def test_multiline_prompt(self):
22 def test_multiline_prompt(self):
23 self.pm.in_template = "[In]\n>>>"
23 self.pm.in_template = "[In]\n>>>"
24 self.pm.render('in')
24 self.pm.render('in')
25 self.assertEqual(self.pm.width, 3)
25 self.assertEqual(self.pm.width, 3)
26 self.assertEqual(self.pm.txtwidth, 3)
26 self.assertEqual(self.pm.txtwidth, 3)
27
27
28 self.pm.in_template = '[In]\n'
28 self.pm.in_template = '[In]\n'
29 self.pm.render('in')
29 self.pm.render('in')
30 self.assertEqual(self.pm.width, 0)
30 self.assertEqual(self.pm.width, 0)
31 self.assertEqual(self.pm.txtwidth, 0)
31 self.assertEqual(self.pm.txtwidth, 0)
32
32
33 def test_translate_abbreviations(self):
33 def test_translate_abbreviations(self):
34 def do_translate(template):
34 def do_translate(template):
35 self.pm.in_template = template
35 self.pm.in_template = template
36 return self.pm.templates['in']
36 return self.pm.templates['in']
37
37
38 pairs = [(r'%n>', '{color.number}{count}{color.prompt}>'),
38 pairs = [(r'%n>', '{color.number}{count}{color.prompt}>'),
39 (r'\T', '{time}'),
39 (r'\T', '{time}'),
40 (r'\n', '\n')
40 (r'\n', '\n')
41 ]
41 ]
42
42
43 tt.check_pairs(do_translate, pairs)
43 tt.check_pairs(do_translate, pairs)
44
44
45 def test_user_ns(self):
45 def test_user_ns(self):
46 self.pm.color_scheme = 'NoColor'
46 self.pm.color_scheme = 'NoColor'
47 ip.ex("foo='bar'")
47 ip.ex("foo='bar'")
48 self.pm.in_template = "In [{foo}]"
48 self.pm.in_template = "In [{foo}]"
49 prompt = self.pm.render('in')
49 prompt = self.pm.render('in')
50 self.assertEquals(prompt, u'In [bar]')
50 self.assertEqual(prompt, u'In [bar]')
51
51
52 def test_builtins(self):
52 def test_builtins(self):
53 self.pm.color_scheme = 'NoColor'
53 self.pm.color_scheme = 'NoColor'
54 self.pm.in_template = "In [{int}]"
54 self.pm.in_template = "In [{int}]"
55 prompt = self.pm.render('in')
55 prompt = self.pm.render('in')
56 self.assertEquals(prompt, u"In [%r]" % int)
56 self.assertEqual(prompt, u"In [%r]" % int)
57
57
58 def test_undefined(self):
58 def test_undefined(self):
59 self.pm.color_scheme = 'NoColor'
59 self.pm.color_scheme = 'NoColor'
60 self.pm.in_template = "In [{foo_dne}]"
60 self.pm.in_template = "In [{foo_dne}]"
61 prompt = self.pm.render('in')
61 prompt = self.pm.render('in')
62 self.assertEquals(prompt, u"In [<ERROR: 'foo_dne' not found>]")
62 self.assertEqual(prompt, u"In [<ERROR: 'foo_dne' not found>]")
63
63
64 def test_render(self):
64 def test_render(self):
65 self.pm.in_template = r'\#>'
65 self.pm.in_template = r'\#>'
66 self.assertEqual(self.pm.render('in',color=False), '%d>' % ip.execution_count)
66 self.assertEqual(self.pm.render('in',color=False), '%d>' % ip.execution_count)
67
67
68 def test_render_unicode_cwd(self):
68 def test_render_unicode_cwd(self):
69 save = os.getcwdu()
69 save = os.getcwdu()
70 with TemporaryDirectory(u'ünicødé') as td:
70 with TemporaryDirectory(u'ünicødé') as td:
71 os.chdir(td)
71 os.chdir(td)
72 self.pm.in_template = r'\w [\#]'
72 self.pm.in_template = r'\w [\#]'
73 p = self.pm.render('in', color=False)
73 p = self.pm.render('in', color=False)
74 self.assertEquals(p, u"%s [%i]" % (os.getcwdu(), ip.execution_count))
74 self.assertEqual(p, u"%s [%i]" % (os.getcwdu(), ip.execution_count))
75 os.chdir(save)
75 os.chdir(save)
76
76
77 def test_lazy_eval_unicode(self):
77 def test_lazy_eval_unicode(self):
78 u = u'ünicødé'
78 u = u'ünicødé'
79 lz = LazyEvaluate(lambda : u)
79 lz = LazyEvaluate(lambda : u)
80 # str(lz) would fail
80 # str(lz) would fail
81 self.assertEquals(unicode(lz), u)
81 self.assertEqual(unicode(lz), u)
82 self.assertEquals(format(lz), u)
82 self.assertEqual(format(lz), u)
83
83
84 def test_lazy_eval_nonascii_bytes(self):
84 def test_lazy_eval_nonascii_bytes(self):
85 u = u'ünicødé'
85 u = u'ünicødé'
86 b = u.encode('utf8')
86 b = u.encode('utf8')
87 lz = LazyEvaluate(lambda : b)
87 lz = LazyEvaluate(lambda : b)
88 # unicode(lz) would fail
88 # unicode(lz) would fail
89 self.assertEquals(str(lz), str(b))
89 self.assertEqual(str(lz), str(b))
90 self.assertEquals(format(lz), str(b))
90 self.assertEqual(format(lz), str(b))
91
91
92 def test_lazy_eval_float(self):
92 def test_lazy_eval_float(self):
93 f = 0.503
93 f = 0.503
94 lz = LazyEvaluate(lambda : f)
94 lz = LazyEvaluate(lambda : f)
95
95
96 self.assertEquals(str(lz), str(f))
96 self.assertEqual(str(lz), str(f))
97 self.assertEquals(unicode(lz), unicode(f))
97 self.assertEqual(unicode(lz), unicode(f))
98 self.assertEquals(format(lz), str(f))
98 self.assertEqual(format(lz), str(f))
99 self.assertEquals(format(lz, '.1'), '0.5')
99 self.assertEqual(format(lz, '.1'), '0.5')
100
100
101 @dec.skip_win32
101 @dec.skip_win32
102 def test_cwd_x(self):
102 def test_cwd_x(self):
103 self.pm.in_template = r"\X0"
103 self.pm.in_template = r"\X0"
104 save = os.getcwdu()
104 save = os.getcwdu()
105 os.chdir(os.path.expanduser('~'))
105 os.chdir(os.path.expanduser('~'))
106 p = self.pm.render('in', color=False)
106 p = self.pm.render('in', color=False)
107 try:
107 try:
108 self.assertEquals(p, '~')
108 self.assertEqual(p, '~')
109 finally:
109 finally:
110 os.chdir(save)
110 os.chdir(save)
111
111
@@ -1,27 +1,27 b''
1 """Tests for the notebook kernel and session manager."""
1 """Tests for the notebook kernel and session manager."""
2
2
3 from unittest import TestCase
3 from unittest import TestCase
4
4
5 from IPython.frontend.html.notebook.kernelmanager import MultiKernelManager
5 from IPython.frontend.html.notebook.kernelmanager import MultiKernelManager
6
6
7 class TestKernelManager(TestCase):
7 class TestKernelManager(TestCase):
8
8
9 def test_km_lifecycle(self):
9 def test_km_lifecycle(self):
10 km = MultiKernelManager()
10 km = MultiKernelManager()
11 kid = km.start_kernel()
11 kid = km.start_kernel()
12 self.assert_(kid in km)
12 self.assert_(kid in km)
13 self.assertEquals(len(km),1)
13 self.assertEqual(len(km),1)
14 km.kill_kernel(kid)
14 km.kill_kernel(kid)
15 self.assert_(not kid in km)
15 self.assert_(not kid in km)
16
16
17 kid = km.start_kernel()
17 kid = km.start_kernel()
18 self.assertEquals('127.0.0.1',km.get_kernel_ip(kid))
18 self.assertEqual('127.0.0.1',km.get_kernel_ip(kid))
19 port_dict = km.get_kernel_ports(kid)
19 port_dict = km.get_kernel_ports(kid)
20 self.assert_('stdin_port' in port_dict)
20 self.assert_('stdin_port' in port_dict)
21 self.assert_('iopub_port' in port_dict)
21 self.assert_('iopub_port' in port_dict)
22 self.assert_('shell_port' in port_dict)
22 self.assert_('shell_port' in port_dict)
23 self.assert_('hb_port' in port_dict)
23 self.assert_('hb_port' in port_dict)
24 km.get_kernel(kid)
24 km.get_kernel(kid)
25 km.kill_kernel(kid)
25 km.kill_kernel(kid)
26
26
27
27
@@ -1,34 +1,34 b''
1 """Tests for the notebook manager."""
1 """Tests for the notebook manager."""
2
2
3 import os
3 import os
4 from unittest import TestCase
4 from unittest import TestCase
5 from tempfile import NamedTemporaryFile
5 from tempfile import NamedTemporaryFile
6
6
7 from IPython.utils.tempdir import TemporaryDirectory
7 from IPython.utils.tempdir import TemporaryDirectory
8 from IPython.utils.traitlets import TraitError
8 from IPython.utils.traitlets import TraitError
9
9
10 from IPython.frontend.html.notebook.notebookmanager import NotebookManager
10 from IPython.frontend.html.notebook.notebookmanager import NotebookManager
11
11
12 class TestNotebookManager(TestCase):
12 class TestNotebookManager(TestCase):
13
13
14 def test_nb_dir(self):
14 def test_nb_dir(self):
15 with TemporaryDirectory() as td:
15 with TemporaryDirectory() as td:
16 km = NotebookManager(notebook_dir=td)
16 km = NotebookManager(notebook_dir=td)
17 self.assertEquals(km.notebook_dir, td)
17 self.assertEqual(km.notebook_dir, td)
18
18
19 def test_create_nb_dir(self):
19 def test_create_nb_dir(self):
20 with TemporaryDirectory() as td:
20 with TemporaryDirectory() as td:
21 nbdir = os.path.join(td, 'notebooks')
21 nbdir = os.path.join(td, 'notebooks')
22 km = NotebookManager(notebook_dir=nbdir)
22 km = NotebookManager(notebook_dir=nbdir)
23 self.assertEquals(km.notebook_dir, nbdir)
23 self.assertEqual(km.notebook_dir, nbdir)
24
24
25 def test_missing_nb_dir(self):
25 def test_missing_nb_dir(self):
26 with TemporaryDirectory() as td:
26 with TemporaryDirectory() as td:
27 nbdir = os.path.join(td, 'notebook', 'dir', 'is', 'missing')
27 nbdir = os.path.join(td, 'notebook', 'dir', 'is', 'missing')
28 self.assertRaises(TraitError, NotebookManager, notebook_dir=nbdir)
28 self.assertRaises(TraitError, NotebookManager, notebook_dir=nbdir)
29
29
30 def test_invalid_nb_dir(self):
30 def test_invalid_nb_dir(self):
31 with NamedTemporaryFile() as tf:
31 with NamedTemporaryFile() as tf:
32 self.assertRaises(TraitError, NotebookManager, notebook_dir=tf.name)
32 self.assertRaises(TraitError, NotebookManager, notebook_dir=tf.name)
33
33
34
34
@@ -1,171 +1,171 b''
1 # Standard library imports
1 # Standard library imports
2 import unittest
2 import unittest
3
3
4 # Local imports
4 # Local imports
5 from IPython.frontend.qt.console.ansi_code_processor import AnsiCodeProcessor
5 from IPython.frontend.qt.console.ansi_code_processor import AnsiCodeProcessor
6
6
7
7
8 class TestAnsiCodeProcessor(unittest.TestCase):
8 class TestAnsiCodeProcessor(unittest.TestCase):
9
9
10 def setUp(self):
10 def setUp(self):
11 self.processor = AnsiCodeProcessor()
11 self.processor = AnsiCodeProcessor()
12
12
13 def test_clear(self):
13 def test_clear(self):
14 """ Do control sequences for clearing the console work?
14 """ Do control sequences for clearing the console work?
15 """
15 """
16 string = '\x1b[2J\x1b[K'
16 string = '\x1b[2J\x1b[K'
17 i = -1
17 i = -1
18 for i, substring in enumerate(self.processor.split_string(string)):
18 for i, substring in enumerate(self.processor.split_string(string)):
19 if i == 0:
19 if i == 0:
20 self.assertEquals(len(self.processor.actions), 1)
20 self.assertEqual(len(self.processor.actions), 1)
21 action = self.processor.actions[0]
21 action = self.processor.actions[0]
22 self.assertEquals(action.action, 'erase')
22 self.assertEqual(action.action, 'erase')
23 self.assertEquals(action.area, 'screen')
23 self.assertEqual(action.area, 'screen')
24 self.assertEquals(action.erase_to, 'all')
24 self.assertEqual(action.erase_to, 'all')
25 elif i == 1:
25 elif i == 1:
26 self.assertEquals(len(self.processor.actions), 1)
26 self.assertEqual(len(self.processor.actions), 1)
27 action = self.processor.actions[0]
27 action = self.processor.actions[0]
28 self.assertEquals(action.action, 'erase')
28 self.assertEqual(action.action, 'erase')
29 self.assertEquals(action.area, 'line')
29 self.assertEqual(action.area, 'line')
30 self.assertEquals(action.erase_to, 'end')
30 self.assertEqual(action.erase_to, 'end')
31 else:
31 else:
32 self.fail('Too many substrings.')
32 self.fail('Too many substrings.')
33 self.assertEquals(i, 1, 'Too few substrings.')
33 self.assertEqual(i, 1, 'Too few substrings.')
34
34
35 def test_colors(self):
35 def test_colors(self):
36 """ Do basic controls sequences for colors work?
36 """ Do basic controls sequences for colors work?
37 """
37 """
38 string = 'first\x1b[34mblue\x1b[0mlast'
38 string = 'first\x1b[34mblue\x1b[0mlast'
39 i = -1
39 i = -1
40 for i, substring in enumerate(self.processor.split_string(string)):
40 for i, substring in enumerate(self.processor.split_string(string)):
41 if i == 0:
41 if i == 0:
42 self.assertEquals(substring, 'first')
42 self.assertEqual(substring, 'first')
43 self.assertEquals(self.processor.foreground_color, None)
43 self.assertEqual(self.processor.foreground_color, None)
44 elif i == 1:
44 elif i == 1:
45 self.assertEquals(substring, 'blue')
45 self.assertEqual(substring, 'blue')
46 self.assertEquals(self.processor.foreground_color, 4)
46 self.assertEqual(self.processor.foreground_color, 4)
47 elif i == 2:
47 elif i == 2:
48 self.assertEquals(substring, 'last')
48 self.assertEqual(substring, 'last')
49 self.assertEquals(self.processor.foreground_color, None)
49 self.assertEqual(self.processor.foreground_color, None)
50 else:
50 else:
51 self.fail('Too many substrings.')
51 self.fail('Too many substrings.')
52 self.assertEquals(i, 2, 'Too few substrings.')
52 self.assertEqual(i, 2, 'Too few substrings.')
53
53
54 def test_colors_xterm(self):
54 def test_colors_xterm(self):
55 """ Do xterm-specific control sequences for colors work?
55 """ Do xterm-specific control sequences for colors work?
56 """
56 """
57 string = '\x1b]4;20;rgb:ff/ff/ff\x1b' \
57 string = '\x1b]4;20;rgb:ff/ff/ff\x1b' \
58 '\x1b]4;25;rgbi:1.0/1.0/1.0\x1b'
58 '\x1b]4;25;rgbi:1.0/1.0/1.0\x1b'
59 substrings = list(self.processor.split_string(string))
59 substrings = list(self.processor.split_string(string))
60 desired = { 20 : (255, 255, 255),
60 desired = { 20 : (255, 255, 255),
61 25 : (255, 255, 255) }
61 25 : (255, 255, 255) }
62 self.assertEquals(self.processor.color_map, desired)
62 self.assertEqual(self.processor.color_map, desired)
63
63
64 string = '\x1b[38;5;20m\x1b[48;5;25m'
64 string = '\x1b[38;5;20m\x1b[48;5;25m'
65 substrings = list(self.processor.split_string(string))
65 substrings = list(self.processor.split_string(string))
66 self.assertEquals(self.processor.foreground_color, 20)
66 self.assertEqual(self.processor.foreground_color, 20)
67 self.assertEquals(self.processor.background_color, 25)
67 self.assertEqual(self.processor.background_color, 25)
68
68
69 def test_scroll(self):
69 def test_scroll(self):
70 """ Do control sequences for scrolling the buffer work?
70 """ Do control sequences for scrolling the buffer work?
71 """
71 """
72 string = '\x1b[5S\x1b[T'
72 string = '\x1b[5S\x1b[T'
73 i = -1
73 i = -1
74 for i, substring in enumerate(self.processor.split_string(string)):
74 for i, substring in enumerate(self.processor.split_string(string)):
75 if i == 0:
75 if i == 0:
76 self.assertEquals(len(self.processor.actions), 1)
76 self.assertEqual(len(self.processor.actions), 1)
77 action = self.processor.actions[0]
77 action = self.processor.actions[0]
78 self.assertEquals(action.action, 'scroll')
78 self.assertEqual(action.action, 'scroll')
79 self.assertEquals(action.dir, 'up')
79 self.assertEqual(action.dir, 'up')
80 self.assertEquals(action.unit, 'line')
80 self.assertEqual(action.unit, 'line')
81 self.assertEquals(action.count, 5)
81 self.assertEqual(action.count, 5)
82 elif i == 1:
82 elif i == 1:
83 self.assertEquals(len(self.processor.actions), 1)
83 self.assertEqual(len(self.processor.actions), 1)
84 action = self.processor.actions[0]
84 action = self.processor.actions[0]
85 self.assertEquals(action.action, 'scroll')
85 self.assertEqual(action.action, 'scroll')
86 self.assertEquals(action.dir, 'down')
86 self.assertEqual(action.dir, 'down')
87 self.assertEquals(action.unit, 'line')
87 self.assertEqual(action.unit, 'line')
88 self.assertEquals(action.count, 1)
88 self.assertEqual(action.count, 1)
89 else:
89 else:
90 self.fail('Too many substrings.')
90 self.fail('Too many substrings.')
91 self.assertEquals(i, 1, 'Too few substrings.')
91 self.assertEqual(i, 1, 'Too few substrings.')
92
92
93 def test_formfeed(self):
93 def test_formfeed(self):
94 """ Are formfeed characters processed correctly?
94 """ Are formfeed characters processed correctly?
95 """
95 """
96 string = '\f' # form feed
96 string = '\f' # form feed
97 self.assertEquals(list(self.processor.split_string(string)), [''])
97 self.assertEqual(list(self.processor.split_string(string)), [''])
98 self.assertEquals(len(self.processor.actions), 1)
98 self.assertEqual(len(self.processor.actions), 1)
99 action = self.processor.actions[0]
99 action = self.processor.actions[0]
100 self.assertEquals(action.action, 'scroll')
100 self.assertEqual(action.action, 'scroll')
101 self.assertEquals(action.dir, 'down')
101 self.assertEqual(action.dir, 'down')
102 self.assertEquals(action.unit, 'page')
102 self.assertEqual(action.unit, 'page')
103 self.assertEquals(action.count, 1)
103 self.assertEqual(action.count, 1)
104
104
105 def test_carriage_return(self):
105 def test_carriage_return(self):
106 """ Are carriage return characters processed correctly?
106 """ Are carriage return characters processed correctly?
107 """
107 """
108 string = 'foo\rbar' # carriage return
108 string = 'foo\rbar' # carriage return
109 splits = []
109 splits = []
110 actions = []
110 actions = []
111 for split in self.processor.split_string(string):
111 for split in self.processor.split_string(string):
112 splits.append(split)
112 splits.append(split)
113 actions.append([action.action for action in self.processor.actions])
113 actions.append([action.action for action in self.processor.actions])
114 self.assertEquals(splits, ['foo', None, 'bar'])
114 self.assertEqual(splits, ['foo', None, 'bar'])
115 self.assertEquals(actions, [[], ['carriage-return'], []])
115 self.assertEqual(actions, [[], ['carriage-return'], []])
116
116
117 def test_carriage_return_newline(self):
117 def test_carriage_return_newline(self):
118 """transform CRLF to LF"""
118 """transform CRLF to LF"""
119 string = 'foo\rbar\r\ncat\r\n\n' # carriage return and newline
119 string = 'foo\rbar\r\ncat\r\n\n' # carriage return and newline
120 # only one CR action should occur, and '\r\n' should transform to '\n'
120 # only one CR action should occur, and '\r\n' should transform to '\n'
121 splits = []
121 splits = []
122 actions = []
122 actions = []
123 for split in self.processor.split_string(string):
123 for split in self.processor.split_string(string):
124 splits.append(split)
124 splits.append(split)
125 actions.append([action.action for action in self.processor.actions])
125 actions.append([action.action for action in self.processor.actions])
126 self.assertEquals(splits, ['foo', None, 'bar', '\r\n', 'cat', '\r\n', '\n'])
126 self.assertEqual(splits, ['foo', None, 'bar', '\r\n', 'cat', '\r\n', '\n'])
127 self.assertEquals(actions, [[], ['carriage-return'], [], ['newline'], [], ['newline'], ['newline']])
127 self.assertEqual(actions, [[], ['carriage-return'], [], ['newline'], [], ['newline'], ['newline']])
128
128
129 def test_beep(self):
129 def test_beep(self):
130 """ Are beep characters processed correctly?
130 """ Are beep characters processed correctly?
131 """
131 """
132 string = 'foo\abar' # bell
132 string = 'foo\abar' # bell
133 splits = []
133 splits = []
134 actions = []
134 actions = []
135 for split in self.processor.split_string(string):
135 for split in self.processor.split_string(string):
136 splits.append(split)
136 splits.append(split)
137 actions.append([action.action for action in self.processor.actions])
137 actions.append([action.action for action in self.processor.actions])
138 self.assertEquals(splits, ['foo', None, 'bar'])
138 self.assertEqual(splits, ['foo', None, 'bar'])
139 self.assertEquals(actions, [[], ['beep'], []])
139 self.assertEqual(actions, [[], ['beep'], []])
140
140
141 def test_backspace(self):
141 def test_backspace(self):
142 """ Are backspace characters processed correctly?
142 """ Are backspace characters processed correctly?
143 """
143 """
144 string = 'foo\bbar' # backspace
144 string = 'foo\bbar' # backspace
145 splits = []
145 splits = []
146 actions = []
146 actions = []
147 for split in self.processor.split_string(string):
147 for split in self.processor.split_string(string):
148 splits.append(split)
148 splits.append(split)
149 actions.append([action.action for action in self.processor.actions])
149 actions.append([action.action for action in self.processor.actions])
150 self.assertEquals(splits, ['foo', None, 'bar'])
150 self.assertEqual(splits, ['foo', None, 'bar'])
151 self.assertEquals(actions, [[], ['backspace'], []])
151 self.assertEqual(actions, [[], ['backspace'], []])
152
152
153 def test_combined(self):
153 def test_combined(self):
154 """ Are CR and BS characters processed correctly in combination?
154 """ Are CR and BS characters processed correctly in combination?
155
155
156 BS is treated as a change in print position, rather than a
156 BS is treated as a change in print position, rather than a
157 backwards character deletion. Therefore a BS at EOL is
157 backwards character deletion. Therefore a BS at EOL is
158 effectively ignored.
158 effectively ignored.
159 """
159 """
160 string = 'abc\rdef\b' # CR and backspace
160 string = 'abc\rdef\b' # CR and backspace
161 splits = []
161 splits = []
162 actions = []
162 actions = []
163 for split in self.processor.split_string(string):
163 for split in self.processor.split_string(string):
164 splits.append(split)
164 splits.append(split)
165 actions.append([action.action for action in self.processor.actions])
165 actions.append([action.action for action in self.processor.actions])
166 self.assertEquals(splits, ['abc', None, 'def', None])
166 self.assertEqual(splits, ['abc', None, 'def', None])
167 self.assertEquals(actions, [[], ['carriage-return'], [], ['backspace']])
167 self.assertEqual(actions, [[], ['carriage-return'], [], ['backspace']])
168
168
169
169
170 if __name__ == '__main__':
170 if __name__ == '__main__':
171 unittest.main()
171 unittest.main()
@@ -1,47 +1,47 b''
1 # Standard library imports
1 # Standard library imports
2 import unittest
2 import unittest
3
3
4 # System library imports
4 # System library imports
5 from pygments.lexers import CLexer, CppLexer, PythonLexer
5 from pygments.lexers import CLexer, CppLexer, PythonLexer
6
6
7 # Local imports
7 # Local imports
8 from IPython.frontend.qt.console.completion_lexer import CompletionLexer
8 from IPython.frontend.qt.console.completion_lexer import CompletionLexer
9
9
10
10
11 class TestCompletionLexer(unittest.TestCase):
11 class TestCompletionLexer(unittest.TestCase):
12
12
13 def testPython(self):
13 def testPython(self):
14 """ Does the CompletionLexer work for Python?
14 """ Does the CompletionLexer work for Python?
15 """
15 """
16 lexer = CompletionLexer(PythonLexer())
16 lexer = CompletionLexer(PythonLexer())
17
17
18 # Test simplest case.
18 # Test simplest case.
19 self.assertEquals(lexer.get_context("foo.bar.baz"),
19 self.assertEqual(lexer.get_context("foo.bar.baz"),
20 [ "foo", "bar", "baz" ])
20 [ "foo", "bar", "baz" ])
21
21
22 # Test trailing period.
22 # Test trailing period.
23 self.assertEquals(lexer.get_context("foo.bar."), [ "foo", "bar", "" ])
23 self.assertEqual(lexer.get_context("foo.bar."), [ "foo", "bar", "" ])
24
24
25 # Test with prompt present.
25 # Test with prompt present.
26 self.assertEquals(lexer.get_context(">>> foo.bar.baz"),
26 self.assertEqual(lexer.get_context(">>> foo.bar.baz"),
27 [ "foo", "bar", "baz" ])
27 [ "foo", "bar", "baz" ])
28
28
29 # Test spacing in name.
29 # Test spacing in name.
30 self.assertEquals(lexer.get_context("foo.bar. baz"), [ "baz" ])
30 self.assertEqual(lexer.get_context("foo.bar. baz"), [ "baz" ])
31
31
32 # Test parenthesis.
32 # Test parenthesis.
33 self.assertEquals(lexer.get_context("foo("), [])
33 self.assertEqual(lexer.get_context("foo("), [])
34
34
35 def testC(self):
35 def testC(self):
36 """ Does the CompletionLexer work for C/C++?
36 """ Does the CompletionLexer work for C/C++?
37 """
37 """
38 lexer = CompletionLexer(CLexer())
38 lexer = CompletionLexer(CLexer())
39 self.assertEquals(lexer.get_context("foo.bar"), [ "foo", "bar" ])
39 self.assertEqual(lexer.get_context("foo.bar"), [ "foo", "bar" ])
40 self.assertEquals(lexer.get_context("foo->bar"), [ "foo", "bar" ])
40 self.assertEqual(lexer.get_context("foo->bar"), [ "foo", "bar" ])
41
41
42 lexer = CompletionLexer(CppLexer())
42 lexer = CompletionLexer(CppLexer())
43 self.assertEquals(lexer.get_context("Foo::Bar"), [ "Foo", "Bar" ])
43 self.assertEqual(lexer.get_context("Foo::Bar"), [ "Foo", "Bar" ])
44
44
45
45
46 if __name__ == '__main__':
46 if __name__ == '__main__':
47 unittest.main()
47 unittest.main()
@@ -1,42 +1,42 b''
1 # Standard library imports
1 # Standard library imports
2 import unittest
2 import unittest
3
3
4 # System library imports
4 # System library imports
5 from IPython.external.qt import QtGui
5 from IPython.external.qt import QtGui
6
6
7 # Local imports
7 # Local imports
8 from IPython.frontend.qt.console.console_widget import ConsoleWidget
8 from IPython.frontend.qt.console.console_widget import ConsoleWidget
9
9
10
10
11 class TestConsoleWidget(unittest.TestCase):
11 class TestConsoleWidget(unittest.TestCase):
12
12
13 @classmethod
13 @classmethod
14 def setUpClass(cls):
14 def setUpClass(cls):
15 """ Create the application for the test case.
15 """ Create the application for the test case.
16 """
16 """
17 cls._app = QtGui.QApplication.instance()
17 cls._app = QtGui.QApplication.instance()
18 if cls._app is None:
18 if cls._app is None:
19 cls._app = QtGui.QApplication([])
19 cls._app = QtGui.QApplication([])
20 cls._app.setQuitOnLastWindowClosed(False)
20 cls._app.setQuitOnLastWindowClosed(False)
21
21
22 @classmethod
22 @classmethod
23 def tearDownClass(cls):
23 def tearDownClass(cls):
24 """ Exit the application.
24 """ Exit the application.
25 """
25 """
26 QtGui.QApplication.quit()
26 QtGui.QApplication.quit()
27
27
28 def test_special_characters(self):
28 def test_special_characters(self):
29 """ Are special characters displayed correctly?
29 """ Are special characters displayed correctly?
30 """
30 """
31 w = ConsoleWidget()
31 w = ConsoleWidget()
32 cursor = w._get_prompt_cursor()
32 cursor = w._get_prompt_cursor()
33
33
34 test_inputs = ['xyz\b\b=\n', 'foo\b\nbar\n', 'foo\b\nbar\r\n', 'abc\rxyz\b\b=']
34 test_inputs = ['xyz\b\b=\n', 'foo\b\nbar\n', 'foo\b\nbar\r\n', 'abc\rxyz\b\b=']
35 expected_outputs = [u'x=z\u2029', u'foo\u2029bar\u2029', u'foo\u2029bar\u2029', 'x=z']
35 expected_outputs = [u'x=z\u2029', u'foo\u2029bar\u2029', u'foo\u2029bar\u2029', 'x=z']
36 for i, text in enumerate(test_inputs):
36 for i, text in enumerate(test_inputs):
37 w._insert_plain_text(cursor, text)
37 w._insert_plain_text(cursor, text)
38 cursor.select(cursor.Document)
38 cursor.select(cursor.Document)
39 selection = cursor.selectedText()
39 selection = cursor.selectedText()
40 self.assertEquals(expected_outputs[i], selection)
40 self.assertEqual(expected_outputs[i], selection)
41 # clear all the text
41 # clear all the text
42 cursor.insertText('')
42 cursor.insertText('')
@@ -1,171 +1,171 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2 """Tests for the key interactiveshell module.
2 """Tests for the key interactiveshell module.
3
3
4 Authors
4 Authors
5 -------
5 -------
6 * Julian Taylor
6 * Julian Taylor
7 """
7 """
8 #-----------------------------------------------------------------------------
8 #-----------------------------------------------------------------------------
9 # Copyright (C) 2011 The IPython Development Team
9 # Copyright (C) 2011 The IPython Development Team
10 #
10 #
11 # Distributed under the terms of the BSD License. The full license is in
11 # Distributed under the terms of the BSD License. The full license is in
12 # the file COPYING, distributed as part of this software.
12 # the file COPYING, distributed as part of this software.
13 #-----------------------------------------------------------------------------
13 #-----------------------------------------------------------------------------
14
14
15 #-----------------------------------------------------------------------------
15 #-----------------------------------------------------------------------------
16 # Imports
16 # Imports
17 #-----------------------------------------------------------------------------
17 #-----------------------------------------------------------------------------
18 # stdlib
18 # stdlib
19 import sys
19 import sys
20 import unittest
20 import unittest
21
21
22 from IPython.testing.decorators import skipif
22 from IPython.testing.decorators import skipif
23 from IPython.utils import py3compat
23 from IPython.utils import py3compat
24
24
25 class InteractiveShellTestCase(unittest.TestCase):
25 class InteractiveShellTestCase(unittest.TestCase):
26 def rl_hist_entries(self, rl, n):
26 def rl_hist_entries(self, rl, n):
27 """Get last n readline history entries as a list"""
27 """Get last n readline history entries as a list"""
28 return [rl.get_history_item(rl.get_current_history_length() - x)
28 return [rl.get_history_item(rl.get_current_history_length() - x)
29 for x in range(n - 1, -1, -1)]
29 for x in range(n - 1, -1, -1)]
30
30
31 def test_runs_without_rl(self):
31 def test_runs_without_rl(self):
32 """Test that function does not throw without readline"""
32 """Test that function does not throw without readline"""
33 ip = get_ipython()
33 ip = get_ipython()
34 ip.has_readline = False
34 ip.has_readline = False
35 ip.readline = None
35 ip.readline = None
36 ip._replace_rlhist_multiline(u'source', 0)
36 ip._replace_rlhist_multiline(u'source', 0)
37
37
38 @skipif(not get_ipython().has_readline, 'no readline')
38 @skipif(not get_ipython().has_readline, 'no readline')
39 def test_runs_without_remove_history_item(self):
39 def test_runs_without_remove_history_item(self):
40 """Test that function does not throw on windows without
40 """Test that function does not throw on windows without
41 remove_history_item"""
41 remove_history_item"""
42 ip = get_ipython()
42 ip = get_ipython()
43 if hasattr(ip.readline, 'remove_history_item'):
43 if hasattr(ip.readline, 'remove_history_item'):
44 del ip.readline.remove_history_item
44 del ip.readline.remove_history_item
45 ip._replace_rlhist_multiline(u'source', 0)
45 ip._replace_rlhist_multiline(u'source', 0)
46
46
47 @skipif(not get_ipython().has_readline, 'no readline')
47 @skipif(not get_ipython().has_readline, 'no readline')
48 @skipif(not hasattr(get_ipython().readline, 'remove_history_item'),
48 @skipif(not hasattr(get_ipython().readline, 'remove_history_item'),
49 'no remove_history_item')
49 'no remove_history_item')
50 def test_replace_multiline_hist_disabled(self):
50 def test_replace_multiline_hist_disabled(self):
51 """Test that multiline replace does nothing if disabled"""
51 """Test that multiline replace does nothing if disabled"""
52 ip = get_ipython()
52 ip = get_ipython()
53 ip.multiline_history = False
53 ip.multiline_history = False
54
54
55 ghist = [u'line1', u'line2']
55 ghist = [u'line1', u'line2']
56 for h in ghist:
56 for h in ghist:
57 ip.readline.add_history(h)
57 ip.readline.add_history(h)
58 hlen_b4_cell = ip.readline.get_current_history_length()
58 hlen_b4_cell = ip.readline.get_current_history_length()
59 hlen_b4_cell = ip._replace_rlhist_multiline(u'sourc€\nsource2',
59 hlen_b4_cell = ip._replace_rlhist_multiline(u'sourc€\nsource2',
60 hlen_b4_cell)
60 hlen_b4_cell)
61
61
62 self.assertEquals(ip.readline.get_current_history_length(),
62 self.assertEqual(ip.readline.get_current_history_length(),
63 hlen_b4_cell)
63 hlen_b4_cell)
64 hist = self.rl_hist_entries(ip.readline, 2)
64 hist = self.rl_hist_entries(ip.readline, 2)
65 self.assertEquals(hist, ghist)
65 self.assertEqual(hist, ghist)
66
66
67 @skipif(not get_ipython().has_readline, 'no readline')
67 @skipif(not get_ipython().has_readline, 'no readline')
68 @skipif(not hasattr(get_ipython().readline, 'remove_history_item'),
68 @skipif(not hasattr(get_ipython().readline, 'remove_history_item'),
69 'no remove_history_item')
69 'no remove_history_item')
70 def test_replace_multiline_hist_adds(self):
70 def test_replace_multiline_hist_adds(self):
71 """Test that multiline replace function adds history"""
71 """Test that multiline replace function adds history"""
72 ip = get_ipython()
72 ip = get_ipython()
73
73
74 hlen_b4_cell = ip.readline.get_current_history_length()
74 hlen_b4_cell = ip.readline.get_current_history_length()
75 hlen_b4_cell = ip._replace_rlhist_multiline(u'sourc€', hlen_b4_cell)
75 hlen_b4_cell = ip._replace_rlhist_multiline(u'sourc€', hlen_b4_cell)
76
76
77 self.assertEquals(hlen_b4_cell,
77 self.assertEqual(hlen_b4_cell,
78 ip.readline.get_current_history_length())
78 ip.readline.get_current_history_length())
79
79
80 @skipif(not get_ipython().has_readline, 'no readline')
80 @skipif(not get_ipython().has_readline, 'no readline')
81 @skipif(not hasattr(get_ipython().readline, 'remove_history_item'),
81 @skipif(not hasattr(get_ipython().readline, 'remove_history_item'),
82 'no remove_history_item')
82 'no remove_history_item')
83 def test_replace_multiline_hist_keeps_history(self):
83 def test_replace_multiline_hist_keeps_history(self):
84 """Test that multiline replace does not delete history"""
84 """Test that multiline replace does not delete history"""
85 ip = get_ipython()
85 ip = get_ipython()
86 ip.multiline_history = True
86 ip.multiline_history = True
87
87
88 ghist = [u'line1', u'line2']
88 ghist = [u'line1', u'line2']
89 for h in ghist:
89 for h in ghist:
90 ip.readline.add_history(h)
90 ip.readline.add_history(h)
91
91
92 #start cell
92 #start cell
93 hlen_b4_cell = ip.readline.get_current_history_length()
93 hlen_b4_cell = ip.readline.get_current_history_length()
94 # nothing added to rl history, should do nothing
94 # nothing added to rl history, should do nothing
95 hlen_b4_cell = ip._replace_rlhist_multiline(u'sourc€\nsource2',
95 hlen_b4_cell = ip._replace_rlhist_multiline(u'sourc€\nsource2',
96 hlen_b4_cell)
96 hlen_b4_cell)
97
97
98 self.assertEquals(ip.readline.get_current_history_length(),
98 self.assertEqual(ip.readline.get_current_history_length(),
99 hlen_b4_cell)
99 hlen_b4_cell)
100 hist = self.rl_hist_entries(ip.readline, 2)
100 hist = self.rl_hist_entries(ip.readline, 2)
101 self.assertEquals(hist, ghist)
101 self.assertEqual(hist, ghist)
102
102
103
103
104 @skipif(not get_ipython().has_readline, 'no readline')
104 @skipif(not get_ipython().has_readline, 'no readline')
105 @skipif(not hasattr(get_ipython().readline, 'remove_history_item'),
105 @skipif(not hasattr(get_ipython().readline, 'remove_history_item'),
106 'no remove_history_item')
106 'no remove_history_item')
107 def test_replace_multiline_hist_replaces_twice(self):
107 def test_replace_multiline_hist_replaces_twice(self):
108 """Test that multiline entries are replaced twice"""
108 """Test that multiline entries are replaced twice"""
109 ip = get_ipython()
109 ip = get_ipython()
110 ip.multiline_history = True
110 ip.multiline_history = True
111
111
112 ip.readline.add_history(u'line0')
112 ip.readline.add_history(u'line0')
113 #start cell
113 #start cell
114 hlen_b4_cell = ip.readline.get_current_history_length()
114 hlen_b4_cell = ip.readline.get_current_history_length()
115 ip.readline.add_history('l€ne1')
115 ip.readline.add_history('l€ne1')
116 ip.readline.add_history('line2')
116 ip.readline.add_history('line2')
117 #replace cell with single line
117 #replace cell with single line
118 hlen_b4_cell = ip._replace_rlhist_multiline(u'l€ne1\nline2',
118 hlen_b4_cell = ip._replace_rlhist_multiline(u'l€ne1\nline2',
119 hlen_b4_cell)
119 hlen_b4_cell)
120 ip.readline.add_history('l€ne3')
120 ip.readline.add_history('l€ne3')
121 ip.readline.add_history('line4')
121 ip.readline.add_history('line4')
122 #replace cell with single line
122 #replace cell with single line
123 hlen_b4_cell = ip._replace_rlhist_multiline(u'l€ne3\nline4',
123 hlen_b4_cell = ip._replace_rlhist_multiline(u'l€ne3\nline4',
124 hlen_b4_cell)
124 hlen_b4_cell)
125
125
126 self.assertEquals(ip.readline.get_current_history_length(),
126 self.assertEqual(ip.readline.get_current_history_length(),
127 hlen_b4_cell)
127 hlen_b4_cell)
128 hist = self.rl_hist_entries(ip.readline, 3)
128 hist = self.rl_hist_entries(ip.readline, 3)
129 expected = [u'line0', u'l€ne1\nline2', u'l€ne3\nline4']
129 expected = [u'line0', u'l€ne1\nline2', u'l€ne3\nline4']
130 # perform encoding, in case of casting due to ASCII locale
130 # perform encoding, in case of casting due to ASCII locale
131 enc = sys.stdin.encoding or "utf-8"
131 enc = sys.stdin.encoding or "utf-8"
132 expected = [ py3compat.unicode_to_str(e, enc) for e in expected ]
132 expected = [ py3compat.unicode_to_str(e, enc) for e in expected ]
133 self.assertEquals(hist, expected)
133 self.assertEqual(hist, expected)
134
134
135
135
136 @skipif(not get_ipython().has_readline, 'no readline')
136 @skipif(not get_ipython().has_readline, 'no readline')
137 @skipif(not hasattr(get_ipython().readline, 'remove_history_item'),
137 @skipif(not hasattr(get_ipython().readline, 'remove_history_item'),
138 'no remove_history_item')
138 'no remove_history_item')
139 def test_replace_multiline_hist_replaces_empty_line(self):
139 def test_replace_multiline_hist_replaces_empty_line(self):
140 """Test that multiline history skips empty line cells"""
140 """Test that multiline history skips empty line cells"""
141 ip = get_ipython()
141 ip = get_ipython()
142 ip.multiline_history = True
142 ip.multiline_history = True
143
143
144 ip.readline.add_history(u'line0')
144 ip.readline.add_history(u'line0')
145 #start cell
145 #start cell
146 hlen_b4_cell = ip.readline.get_current_history_length()
146 hlen_b4_cell = ip.readline.get_current_history_length()
147 ip.readline.add_history('l€ne1')
147 ip.readline.add_history('l€ne1')
148 ip.readline.add_history('line2')
148 ip.readline.add_history('line2')
149 hlen_b4_cell = ip._replace_rlhist_multiline(u'l€ne1\nline2',
149 hlen_b4_cell = ip._replace_rlhist_multiline(u'l€ne1\nline2',
150 hlen_b4_cell)
150 hlen_b4_cell)
151 ip.readline.add_history('')
151 ip.readline.add_history('')
152 hlen_b4_cell = ip._replace_rlhist_multiline(u'', hlen_b4_cell)
152 hlen_b4_cell = ip._replace_rlhist_multiline(u'', hlen_b4_cell)
153 ip.readline.add_history('l€ne3')
153 ip.readline.add_history('l€ne3')
154 hlen_b4_cell = ip._replace_rlhist_multiline(u'l€ne3', hlen_b4_cell)
154 hlen_b4_cell = ip._replace_rlhist_multiline(u'l€ne3', hlen_b4_cell)
155 ip.readline.add_history(' ')
155 ip.readline.add_history(' ')
156 hlen_b4_cell = ip._replace_rlhist_multiline(' ', hlen_b4_cell)
156 hlen_b4_cell = ip._replace_rlhist_multiline(' ', hlen_b4_cell)
157 ip.readline.add_history('\t')
157 ip.readline.add_history('\t')
158 ip.readline.add_history('\t ')
158 ip.readline.add_history('\t ')
159 hlen_b4_cell = ip._replace_rlhist_multiline('\t', hlen_b4_cell)
159 hlen_b4_cell = ip._replace_rlhist_multiline('\t', hlen_b4_cell)
160 ip.readline.add_history('line4')
160 ip.readline.add_history('line4')
161 hlen_b4_cell = ip._replace_rlhist_multiline(u'line4', hlen_b4_cell)
161 hlen_b4_cell = ip._replace_rlhist_multiline(u'line4', hlen_b4_cell)
162
162
163 self.assertEquals(ip.readline.get_current_history_length(),
163 self.assertEqual(ip.readline.get_current_history_length(),
164 hlen_b4_cell)
164 hlen_b4_cell)
165 hist = self.rl_hist_entries(ip.readline, 4)
165 hist = self.rl_hist_entries(ip.readline, 4)
166 # expect no empty cells in history
166 # expect no empty cells in history
167 expected = [u'line0', u'l€ne1\nline2', u'l€ne3', u'line4']
167 expected = [u'line0', u'l€ne1\nline2', u'l€ne3', u'line4']
168 # perform encoding, in case of casting due to ASCII locale
168 # perform encoding, in case of casting due to ASCII locale
169 enc = sys.stdin.encoding or "utf-8"
169 enc = sys.stdin.encoding or "utf-8"
170 expected = [ py3compat.unicode_to_str(e, enc) for e in expected ]
170 expected = [ py3compat.unicode_to_str(e, enc) for e in expected ]
171 self.assertEquals(hist, expected)
171 self.assertEqual(hist, expected)
@@ -1,14 +1,14 b''
1 from unittest import TestCase
1 from unittest import TestCase
2
2
3 from ..nbjson import reads, writes
3 from ..nbjson import reads, writes
4 from .nbexamples import nb0
4 from .nbexamples import nb0
5
5
6
6
7 class TestJSON(TestCase):
7 class TestJSON(TestCase):
8
8
9 def test_roundtrip(self):
9 def test_roundtrip(self):
10 s = writes(nb0)
10 s = writes(nb0)
11 self.assertEquals(reads(s),nb0)
11 self.assertEqual(reads(s),nb0)
12
12
13
13
14
14
@@ -1,41 +1,41 b''
1 from unittest import TestCase
1 from unittest import TestCase
2
2
3 from ..nbbase import (
3 from ..nbbase import (
4 NotebookNode,
4 NotebookNode,
5 new_code_cell, new_text_cell, new_notebook
5 new_code_cell, new_text_cell, new_notebook
6 )
6 )
7
7
8 class TestCell(TestCase):
8 class TestCell(TestCase):
9
9
10 def test_empty_code_cell(self):
10 def test_empty_code_cell(self):
11 cc = new_code_cell()
11 cc = new_code_cell()
12 self.assertEquals(cc.cell_type,'code')
12 self.assertEqual(cc.cell_type,'code')
13 self.assertEquals('code' not in cc, True)
13 self.assertEqual('code' not in cc, True)
14 self.assertEquals('prompt_number' not in cc, True)
14 self.assertEqual('prompt_number' not in cc, True)
15
15
16 def test_code_cell(self):
16 def test_code_cell(self):
17 cc = new_code_cell(code='a=10', prompt_number=0)
17 cc = new_code_cell(code='a=10', prompt_number=0)
18 self.assertEquals(cc.code, u'a=10')
18 self.assertEqual(cc.code, u'a=10')
19 self.assertEquals(cc.prompt_number, 0)
19 self.assertEqual(cc.prompt_number, 0)
20
20
21 def test_empty_text_cell(self):
21 def test_empty_text_cell(self):
22 tc = new_text_cell()
22 tc = new_text_cell()
23 self.assertEquals(tc.cell_type, 'text')
23 self.assertEqual(tc.cell_type, 'text')
24 self.assertEquals('text' not in tc, True)
24 self.assertEqual('text' not in tc, True)
25
25
26 def test_text_cell(self):
26 def test_text_cell(self):
27 tc = new_text_cell('hi')
27 tc = new_text_cell('hi')
28 self.assertEquals(tc.text, u'hi')
28 self.assertEqual(tc.text, u'hi')
29
29
30
30
31 class TestNotebook(TestCase):
31 class TestNotebook(TestCase):
32
32
33 def test_empty_notebook(self):
33 def test_empty_notebook(self):
34 nb = new_notebook()
34 nb = new_notebook()
35 self.assertEquals(nb.cells, [])
35 self.assertEqual(nb.cells, [])
36
36
37 def test_notebooke(self):
37 def test_notebooke(self):
38 cells = [new_code_cell(),new_text_cell()]
38 cells = [new_code_cell(),new_text_cell()]
39 nb = new_notebook(cells=cells)
39 nb = new_notebook(cells=cells)
40 self.assertEquals(nb.cells,cells)
40 self.assertEqual(nb.cells,cells)
41
41
@@ -1,34 +1,34 b''
1 import pprint
1 import pprint
2 from unittest import TestCase
2 from unittest import TestCase
3
3
4 from ..nbjson import reads, writes
4 from ..nbjson import reads, writes
5 from .nbexamples import nb0
5 from .nbexamples import nb0
6
6
7
7
8 class TestJSON(TestCase):
8 class TestJSON(TestCase):
9
9
10 def test_roundtrip(self):
10 def test_roundtrip(self):
11 s = writes(nb0)
11 s = writes(nb0)
12 # print
12 # print
13 # print pprint.pformat(nb0,indent=2)
13 # print pprint.pformat(nb0,indent=2)
14 # print
14 # print
15 # print pprint.pformat(reads(s),indent=2)
15 # print pprint.pformat(reads(s),indent=2)
16 # print
16 # print
17 # print s
17 # print s
18 self.assertEquals(reads(s),nb0)
18 self.assertEqual(reads(s),nb0)
19
19
20 def test_roundtrip_nosplit(self):
20 def test_roundtrip_nosplit(self):
21 """Ensure that multiline blobs are still readable"""
21 """Ensure that multiline blobs are still readable"""
22 # ensures that notebooks written prior to splitlines change
22 # ensures that notebooks written prior to splitlines change
23 # are still readable.
23 # are still readable.
24 s = writes(nb0, split_lines=False)
24 s = writes(nb0, split_lines=False)
25 self.assertEquals(reads(s),nb0)
25 self.assertEqual(reads(s),nb0)
26
26
27 def test_roundtrip_split(self):
27 def test_roundtrip_split(self):
28 """Ensure that splitting multiline blocks is safe"""
28 """Ensure that splitting multiline blocks is safe"""
29 # This won't differ from test_roundtrip unless the default changes
29 # This won't differ from test_roundtrip unless the default changes
30 s = writes(nb0, split_lines=True)
30 s = writes(nb0, split_lines=True)
31 self.assertEquals(reads(s),nb0)
31 self.assertEqual(reads(s),nb0)
32
32
33
33
34
34
@@ -1,113 +1,113 b''
1 from unittest import TestCase
1 from unittest import TestCase
2
2
3 from ..nbbase import (
3 from ..nbbase import (
4 NotebookNode,
4 NotebookNode,
5 new_code_cell, new_text_cell, new_worksheet, new_notebook, new_output,
5 new_code_cell, new_text_cell, new_worksheet, new_notebook, new_output,
6 new_author, new_metadata
6 new_author, new_metadata
7 )
7 )
8
8
9 class TestCell(TestCase):
9 class TestCell(TestCase):
10
10
11 def test_empty_code_cell(self):
11 def test_empty_code_cell(self):
12 cc = new_code_cell()
12 cc = new_code_cell()
13 self.assertEquals(cc.cell_type,u'code')
13 self.assertEqual(cc.cell_type,u'code')
14 self.assertEquals(u'input' not in cc, True)
14 self.assertEqual(u'input' not in cc, True)
15 self.assertEquals(u'prompt_number' not in cc, True)
15 self.assertEqual(u'prompt_number' not in cc, True)
16 self.assertEquals(cc.outputs, [])
16 self.assertEqual(cc.outputs, [])
17 self.assertEquals(cc.collapsed, False)
17 self.assertEqual(cc.collapsed, False)
18
18
19 def test_code_cell(self):
19 def test_code_cell(self):
20 cc = new_code_cell(input='a=10', prompt_number=0, collapsed=True)
20 cc = new_code_cell(input='a=10', prompt_number=0, collapsed=True)
21 cc.outputs = [new_output(output_type=u'pyout',
21 cc.outputs = [new_output(output_type=u'pyout',
22 output_svg=u'foo',output_text=u'10',prompt_number=0)]
22 output_svg=u'foo',output_text=u'10',prompt_number=0)]
23 self.assertEquals(cc.input, u'a=10')
23 self.assertEqual(cc.input, u'a=10')
24 self.assertEquals(cc.prompt_number, 0)
24 self.assertEqual(cc.prompt_number, 0)
25 self.assertEquals(cc.language, u'python')
25 self.assertEqual(cc.language, u'python')
26 self.assertEquals(cc.outputs[0].svg, u'foo')
26 self.assertEqual(cc.outputs[0].svg, u'foo')
27 self.assertEquals(cc.outputs[0].text, u'10')
27 self.assertEqual(cc.outputs[0].text, u'10')
28 self.assertEquals(cc.outputs[0].prompt_number, 0)
28 self.assertEqual(cc.outputs[0].prompt_number, 0)
29 self.assertEquals(cc.collapsed, True)
29 self.assertEqual(cc.collapsed, True)
30
30
31 def test_pyerr(self):
31 def test_pyerr(self):
32 o = new_output(output_type=u'pyerr', etype=u'NameError',
32 o = new_output(output_type=u'pyerr', etype=u'NameError',
33 evalue=u'Name not found', traceback=[u'frame 0', u'frame 1', u'frame 2']
33 evalue=u'Name not found', traceback=[u'frame 0', u'frame 1', u'frame 2']
34 )
34 )
35 self.assertEquals(o.output_type, u'pyerr')
35 self.assertEqual(o.output_type, u'pyerr')
36 self.assertEquals(o.etype, u'NameError')
36 self.assertEqual(o.etype, u'NameError')
37 self.assertEquals(o.evalue, u'Name not found')
37 self.assertEqual(o.evalue, u'Name not found')
38 self.assertEquals(o.traceback, [u'frame 0', u'frame 1', u'frame 2'])
38 self.assertEqual(o.traceback, [u'frame 0', u'frame 1', u'frame 2'])
39
39
40 def test_empty_html_cell(self):
40 def test_empty_html_cell(self):
41 tc = new_text_cell(u'html')
41 tc = new_text_cell(u'html')
42 self.assertEquals(tc.cell_type, u'html')
42 self.assertEqual(tc.cell_type, u'html')
43 self.assertEquals(u'source' not in tc, True)
43 self.assertEqual(u'source' not in tc, True)
44 self.assertEquals(u'rendered' not in tc, True)
44 self.assertEqual(u'rendered' not in tc, True)
45
45
46 def test_html_cell(self):
46 def test_html_cell(self):
47 tc = new_text_cell(u'html', 'hi', 'hi')
47 tc = new_text_cell(u'html', 'hi', 'hi')
48 self.assertEquals(tc.source, u'hi')
48 self.assertEqual(tc.source, u'hi')
49 self.assertEquals(tc.rendered, u'hi')
49 self.assertEqual(tc.rendered, u'hi')
50
50
51 def test_empty_markdown_cell(self):
51 def test_empty_markdown_cell(self):
52 tc = new_text_cell(u'markdown')
52 tc = new_text_cell(u'markdown')
53 self.assertEquals(tc.cell_type, u'markdown')
53 self.assertEqual(tc.cell_type, u'markdown')
54 self.assertEquals(u'source' not in tc, True)
54 self.assertEqual(u'source' not in tc, True)
55 self.assertEquals(u'rendered' not in tc, True)
55 self.assertEqual(u'rendered' not in tc, True)
56
56
57 def test_markdown_cell(self):
57 def test_markdown_cell(self):
58 tc = new_text_cell(u'markdown', 'hi', 'hi')
58 tc = new_text_cell(u'markdown', 'hi', 'hi')
59 self.assertEquals(tc.source, u'hi')
59 self.assertEqual(tc.source, u'hi')
60 self.assertEquals(tc.rendered, u'hi')
60 self.assertEqual(tc.rendered, u'hi')
61
61
62
62
63 class TestWorksheet(TestCase):
63 class TestWorksheet(TestCase):
64
64
65 def test_empty_worksheet(self):
65 def test_empty_worksheet(self):
66 ws = new_worksheet()
66 ws = new_worksheet()
67 self.assertEquals(ws.cells,[])
67 self.assertEqual(ws.cells,[])
68 self.assertEquals(u'name' not in ws, True)
68 self.assertEqual(u'name' not in ws, True)
69
69
70 def test_worksheet(self):
70 def test_worksheet(self):
71 cells = [new_code_cell(), new_text_cell(u'html')]
71 cells = [new_code_cell(), new_text_cell(u'html')]
72 ws = new_worksheet(cells=cells,name=u'foo')
72 ws = new_worksheet(cells=cells,name=u'foo')
73 self.assertEquals(ws.cells,cells)
73 self.assertEqual(ws.cells,cells)
74 self.assertEquals(ws.name,u'foo')
74 self.assertEqual(ws.name,u'foo')
75
75
76 class TestNotebook(TestCase):
76 class TestNotebook(TestCase):
77
77
78 def test_empty_notebook(self):
78 def test_empty_notebook(self):
79 nb = new_notebook()
79 nb = new_notebook()
80 self.assertEquals(nb.worksheets, [])
80 self.assertEqual(nb.worksheets, [])
81 self.assertEquals(nb.metadata, NotebookNode())
81 self.assertEqual(nb.metadata, NotebookNode())
82 self.assertEquals(nb.nbformat,2)
82 self.assertEqual(nb.nbformat,2)
83
83
84 def test_notebook(self):
84 def test_notebook(self):
85 worksheets = [new_worksheet(),new_worksheet()]
85 worksheets = [new_worksheet(),new_worksheet()]
86 metadata = new_metadata(name=u'foo')
86 metadata = new_metadata(name=u'foo')
87 nb = new_notebook(metadata=metadata,worksheets=worksheets)
87 nb = new_notebook(metadata=metadata,worksheets=worksheets)
88 self.assertEquals(nb.metadata.name,u'foo')
88 self.assertEqual(nb.metadata.name,u'foo')
89 self.assertEquals(nb.worksheets,worksheets)
89 self.assertEqual(nb.worksheets,worksheets)
90 self.assertEquals(nb.nbformat,2)
90 self.assertEqual(nb.nbformat,2)
91
91
92 class TestMetadata(TestCase):
92 class TestMetadata(TestCase):
93
93
94 def test_empty_metadata(self):
94 def test_empty_metadata(self):
95 md = new_metadata()
95 md = new_metadata()
96 self.assertEquals(u'name' not in md, True)
96 self.assertEqual(u'name' not in md, True)
97 self.assertEquals(u'authors' not in md, True)
97 self.assertEqual(u'authors' not in md, True)
98 self.assertEquals(u'license' not in md, True)
98 self.assertEqual(u'license' not in md, True)
99 self.assertEquals(u'saved' not in md, True)
99 self.assertEqual(u'saved' not in md, True)
100 self.assertEquals(u'modified' not in md, True)
100 self.assertEqual(u'modified' not in md, True)
101 self.assertEquals(u'gistid' not in md, True)
101 self.assertEqual(u'gistid' not in md, True)
102
102
103 def test_metadata(self):
103 def test_metadata(self):
104 authors = [new_author(name='Bart Simpson',email='bsimpson@fox.com')]
104 authors = [new_author(name='Bart Simpson',email='bsimpson@fox.com')]
105 md = new_metadata(name=u'foo',license=u'BSD',created=u'today',
105 md = new_metadata(name=u'foo',license=u'BSD',created=u'today',
106 modified=u'now',gistid=u'21341231',authors=authors)
106 modified=u'now',gistid=u'21341231',authors=authors)
107 self.assertEquals(md.name, u'foo')
107 self.assertEqual(md.name, u'foo')
108 self.assertEquals(md.license, u'BSD')
108 self.assertEqual(md.license, u'BSD')
109 self.assertEquals(md.created, u'today')
109 self.assertEqual(md.created, u'today')
110 self.assertEquals(md.modified, u'now')
110 self.assertEqual(md.modified, u'now')
111 self.assertEquals(md.gistid, u'21341231')
111 self.assertEqual(md.gistid, u'21341231')
112 self.assertEquals(md.authors, authors)
112 self.assertEqual(md.authors, authors)
113
113
@@ -1,17 +1,17 b''
1 from unittest import TestCase
1 from unittest import TestCase
2
2
3 from ..nbbase import (
3 from ..nbbase import (
4 NotebookNode,
4 NotebookNode,
5 new_code_cell, new_text_cell, new_worksheet, new_notebook
5 new_code_cell, new_text_cell, new_worksheet, new_notebook
6 )
6 )
7
7
8 from ..nbpy import reads, writes
8 from ..nbpy import reads, writes
9 from .nbexamples import nb0, nb0_py
9 from .nbexamples import nb0, nb0_py
10
10
11
11
12 class TestPy(TestCase):
12 class TestPy(TestCase):
13
13
14 def test_write(self):
14 def test_write(self):
15 s = writes(nb0)
15 s = writes(nb0)
16 self.assertEquals(s,nb0_py)
16 self.assertEqual(s,nb0_py)
17
17
@@ -1,63 +1,63 b''
1 # -*- coding: utf8 -*-
1 # -*- coding: utf8 -*-
2 import io
2 import io
3 import os
3 import os
4 import shutil
4 import shutil
5 import tempfile
5 import tempfile
6
6
7 pjoin = os.path.join
7 pjoin = os.path.join
8
8
9 from ..nbbase import (
9 from ..nbbase import (
10 NotebookNode,
10 NotebookNode,
11 new_code_cell, new_text_cell, new_worksheet, new_notebook
11 new_code_cell, new_text_cell, new_worksheet, new_notebook
12 )
12 )
13
13
14 from ..nbpy import reads, writes, read, write
14 from ..nbpy import reads, writes, read, write
15 from .nbexamples import nb0, nb0_py
15 from .nbexamples import nb0, nb0_py
16
16
17
17
18 def open_utf8(fname, mode):
18 def open_utf8(fname, mode):
19 return io.open(fname, mode=mode, encoding='utf-8')
19 return io.open(fname, mode=mode, encoding='utf-8')
20
20
21 class NBFormatTest:
21 class NBFormatTest:
22 """Mixin for writing notebook format tests"""
22 """Mixin for writing notebook format tests"""
23
23
24 # override with appropriate values in subclasses
24 # override with appropriate values in subclasses
25 nb0_ref = None
25 nb0_ref = None
26 ext = None
26 ext = None
27 mod = None
27 mod = None
28
28
29 def setUp(self):
29 def setUp(self):
30 self.wd = tempfile.mkdtemp()
30 self.wd = tempfile.mkdtemp()
31
31
32 def tearDown(self):
32 def tearDown(self):
33 shutil.rmtree(self.wd)
33 shutil.rmtree(self.wd)
34
34
35 def assertNBEquals(self, nba, nbb):
35 def assertNBEquals(self, nba, nbb):
36 self.assertEquals(nba, nbb)
36 self.assertEqual(nba, nbb)
37
37
38 def test_writes(self):
38 def test_writes(self):
39 s = self.mod.writes(nb0)
39 s = self.mod.writes(nb0)
40 if self.nb0_ref:
40 if self.nb0_ref:
41 self.assertEquals(s, self.nb0_ref)
41 self.assertEqual(s, self.nb0_ref)
42
42
43 def test_reads(self):
43 def test_reads(self):
44 s = self.mod.writes(nb0)
44 s = self.mod.writes(nb0)
45 nb = self.mod.reads(s)
45 nb = self.mod.reads(s)
46
46
47 def test_roundtrip(self):
47 def test_roundtrip(self):
48 s = self.mod.writes(nb0)
48 s = self.mod.writes(nb0)
49 self.assertNBEquals(self.mod.reads(s),nb0)
49 self.assertNBEquals(self.mod.reads(s),nb0)
50
50
51 def test_write_file(self):
51 def test_write_file(self):
52 with open_utf8(pjoin(self.wd, "nb0.%s" % self.ext), 'w') as f:
52 with open_utf8(pjoin(self.wd, "nb0.%s" % self.ext), 'w') as f:
53 self.mod.write(nb0, f)
53 self.mod.write(nb0, f)
54
54
55 def test_read_file(self):
55 def test_read_file(self):
56 with open_utf8(pjoin(self.wd, "nb0.%s" % self.ext), 'w') as f:
56 with open_utf8(pjoin(self.wd, "nb0.%s" % self.ext), 'w') as f:
57 self.mod.write(nb0, f)
57 self.mod.write(nb0, f)
58
58
59 with open_utf8(pjoin(self.wd, "nb0.%s" % self.ext), 'r') as f:
59 with open_utf8(pjoin(self.wd, "nb0.%s" % self.ext), 'r') as f:
60 nb = self.mod.read(f)
60 nb = self.mod.read(f)
61
61
62
62
63
63
@@ -1,33 +1,33 b''
1 import pprint
1 import pprint
2 from unittest import TestCase
2 from unittest import TestCase
3
3
4 from ..nbjson import reads, writes
4 from ..nbjson import reads, writes
5 from .. import nbjson
5 from .. import nbjson
6 from .nbexamples import nb0
6 from .nbexamples import nb0
7
7
8 from . import formattest
8 from . import formattest
9
9
10 from .nbexamples import nb0
10 from .nbexamples import nb0
11
11
12
12
13 class TestJSON(formattest.NBFormatTest, TestCase):
13 class TestJSON(formattest.NBFormatTest, TestCase):
14
14
15 nb0_ref = None
15 nb0_ref = None
16 ext = 'ipynb'
16 ext = 'ipynb'
17 mod = nbjson
17 mod = nbjson
18
18
19 def test_roundtrip_nosplit(self):
19 def test_roundtrip_nosplit(self):
20 """Ensure that multiline blobs are still readable"""
20 """Ensure that multiline blobs are still readable"""
21 # ensures that notebooks written prior to splitlines change
21 # ensures that notebooks written prior to splitlines change
22 # are still readable.
22 # are still readable.
23 s = writes(nb0, split_lines=False)
23 s = writes(nb0, split_lines=False)
24 self.assertEquals(nbjson.reads(s),nb0)
24 self.assertEqual(nbjson.reads(s),nb0)
25
25
26 def test_roundtrip_split(self):
26 def test_roundtrip_split(self):
27 """Ensure that splitting multiline blocks is safe"""
27 """Ensure that splitting multiline blocks is safe"""
28 # This won't differ from test_roundtrip unless the default changes
28 # This won't differ from test_roundtrip unless the default changes
29 s = writes(nb0, split_lines=True)
29 s = writes(nb0, split_lines=True)
30 self.assertEquals(nbjson.reads(s),nb0)
30 self.assertEqual(nbjson.reads(s),nb0)
31
31
32
32
33
33
@@ -1,143 +1,143 b''
1 from unittest import TestCase
1 from unittest import TestCase
2
2
3 from ..nbbase import (
3 from ..nbbase import (
4 NotebookNode,
4 NotebookNode,
5 new_code_cell, new_text_cell, new_worksheet, new_notebook, new_output,
5 new_code_cell, new_text_cell, new_worksheet, new_notebook, new_output,
6 new_author, new_metadata, new_heading_cell, nbformat
6 new_author, new_metadata, new_heading_cell, nbformat
7 )
7 )
8
8
9 class TestCell(TestCase):
9 class TestCell(TestCase):
10
10
11 def test_empty_code_cell(self):
11 def test_empty_code_cell(self):
12 cc = new_code_cell()
12 cc = new_code_cell()
13 self.assertEquals(cc.cell_type,u'code')
13 self.assertEqual(cc.cell_type,u'code')
14 self.assertEquals(u'input' not in cc, True)
14 self.assertEqual(u'input' not in cc, True)
15 self.assertEquals(u'prompt_number' not in cc, True)
15 self.assertEqual(u'prompt_number' not in cc, True)
16 self.assertEquals(cc.outputs, [])
16 self.assertEqual(cc.outputs, [])
17 self.assertEquals(cc.collapsed, False)
17 self.assertEqual(cc.collapsed, False)
18
18
19 def test_code_cell(self):
19 def test_code_cell(self):
20 cc = new_code_cell(input='a=10', prompt_number=0, collapsed=True)
20 cc = new_code_cell(input='a=10', prompt_number=0, collapsed=True)
21 cc.outputs = [new_output(output_type=u'pyout',
21 cc.outputs = [new_output(output_type=u'pyout',
22 output_svg=u'foo',output_text=u'10',prompt_number=0)]
22 output_svg=u'foo',output_text=u'10',prompt_number=0)]
23 self.assertEquals(cc.input, u'a=10')
23 self.assertEqual(cc.input, u'a=10')
24 self.assertEquals(cc.prompt_number, 0)
24 self.assertEqual(cc.prompt_number, 0)
25 self.assertEquals(cc.language, u'python')
25 self.assertEqual(cc.language, u'python')
26 self.assertEquals(cc.outputs[0].svg, u'foo')
26 self.assertEqual(cc.outputs[0].svg, u'foo')
27 self.assertEquals(cc.outputs[0].text, u'10')
27 self.assertEqual(cc.outputs[0].text, u'10')
28 self.assertEquals(cc.outputs[0].prompt_number, 0)
28 self.assertEqual(cc.outputs[0].prompt_number, 0)
29 self.assertEquals(cc.collapsed, True)
29 self.assertEqual(cc.collapsed, True)
30
30
31 def test_pyerr(self):
31 def test_pyerr(self):
32 o = new_output(output_type=u'pyerr', etype=u'NameError',
32 o = new_output(output_type=u'pyerr', etype=u'NameError',
33 evalue=u'Name not found', traceback=[u'frame 0', u'frame 1', u'frame 2']
33 evalue=u'Name not found', traceback=[u'frame 0', u'frame 1', u'frame 2']
34 )
34 )
35 self.assertEquals(o.output_type, u'pyerr')
35 self.assertEqual(o.output_type, u'pyerr')
36 self.assertEquals(o.etype, u'NameError')
36 self.assertEqual(o.etype, u'NameError')
37 self.assertEquals(o.evalue, u'Name not found')
37 self.assertEqual(o.evalue, u'Name not found')
38 self.assertEquals(o.traceback, [u'frame 0', u'frame 1', u'frame 2'])
38 self.assertEqual(o.traceback, [u'frame 0', u'frame 1', u'frame 2'])
39
39
40 def test_empty_html_cell(self):
40 def test_empty_html_cell(self):
41 tc = new_text_cell(u'html')
41 tc = new_text_cell(u'html')
42 self.assertEquals(tc.cell_type, u'html')
42 self.assertEqual(tc.cell_type, u'html')
43 self.assertEquals(u'source' not in tc, True)
43 self.assertEqual(u'source' not in tc, True)
44 self.assertEquals(u'rendered' not in tc, True)
44 self.assertEqual(u'rendered' not in tc, True)
45
45
46 def test_html_cell(self):
46 def test_html_cell(self):
47 tc = new_text_cell(u'html', 'hi', 'hi')
47 tc = new_text_cell(u'html', 'hi', 'hi')
48 self.assertEquals(tc.source, u'hi')
48 self.assertEqual(tc.source, u'hi')
49 self.assertEquals(tc.rendered, u'hi')
49 self.assertEqual(tc.rendered, u'hi')
50
50
51 def test_empty_markdown_cell(self):
51 def test_empty_markdown_cell(self):
52 tc = new_text_cell(u'markdown')
52 tc = new_text_cell(u'markdown')
53 self.assertEquals(tc.cell_type, u'markdown')
53 self.assertEqual(tc.cell_type, u'markdown')
54 self.assertEquals(u'source' not in tc, True)
54 self.assertEqual(u'source' not in tc, True)
55 self.assertEquals(u'rendered' not in tc, True)
55 self.assertEqual(u'rendered' not in tc, True)
56
56
57 def test_markdown_cell(self):
57 def test_markdown_cell(self):
58 tc = new_text_cell(u'markdown', 'hi', 'hi')
58 tc = new_text_cell(u'markdown', 'hi', 'hi')
59 self.assertEquals(tc.source, u'hi')
59 self.assertEqual(tc.source, u'hi')
60 self.assertEquals(tc.rendered, u'hi')
60 self.assertEqual(tc.rendered, u'hi')
61
61
62 def test_empty_raw_cell(self):
62 def test_empty_raw_cell(self):
63 tc = new_text_cell(u'raw')
63 tc = new_text_cell(u'raw')
64 self.assertEquals(tc.cell_type, u'raw')
64 self.assertEqual(tc.cell_type, u'raw')
65 self.assertEquals(u'source' not in tc, True)
65 self.assertEqual(u'source' not in tc, True)
66 self.assertEquals(u'rendered' not in tc, True)
66 self.assertEqual(u'rendered' not in tc, True)
67
67
68 def test_raw_cell(self):
68 def test_raw_cell(self):
69 tc = new_text_cell(u'raw', 'hi', 'hi')
69 tc = new_text_cell(u'raw', 'hi', 'hi')
70 self.assertEquals(tc.source, u'hi')
70 self.assertEqual(tc.source, u'hi')
71 self.assertEquals(tc.rendered, u'hi')
71 self.assertEqual(tc.rendered, u'hi')
72
72
73 def test_empty_heading_cell(self):
73 def test_empty_heading_cell(self):
74 tc = new_heading_cell()
74 tc = new_heading_cell()
75 self.assertEquals(tc.cell_type, u'heading')
75 self.assertEqual(tc.cell_type, u'heading')
76 self.assertEquals(u'source' not in tc, True)
76 self.assertEqual(u'source' not in tc, True)
77 self.assertEquals(u'rendered' not in tc, True)
77 self.assertEqual(u'rendered' not in tc, True)
78
78
79 def test_heading_cell(self):
79 def test_heading_cell(self):
80 tc = new_heading_cell(u'hi', u'hi', level=2)
80 tc = new_heading_cell(u'hi', u'hi', level=2)
81 self.assertEquals(tc.source, u'hi')
81 self.assertEqual(tc.source, u'hi')
82 self.assertEquals(tc.rendered, u'hi')
82 self.assertEqual(tc.rendered, u'hi')
83 self.assertEquals(tc.level, 2)
83 self.assertEqual(tc.level, 2)
84
84
85
85
86 class TestWorksheet(TestCase):
86 class TestWorksheet(TestCase):
87
87
88 def test_empty_worksheet(self):
88 def test_empty_worksheet(self):
89 ws = new_worksheet()
89 ws = new_worksheet()
90 self.assertEquals(ws.cells,[])
90 self.assertEqual(ws.cells,[])
91 self.assertEquals(u'name' not in ws, True)
91 self.assertEqual(u'name' not in ws, True)
92
92
93 def test_worksheet(self):
93 def test_worksheet(self):
94 cells = [new_code_cell(), new_text_cell(u'html')]
94 cells = [new_code_cell(), new_text_cell(u'html')]
95 ws = new_worksheet(cells=cells,name=u'foo')
95 ws = new_worksheet(cells=cells,name=u'foo')
96 self.assertEquals(ws.cells,cells)
96 self.assertEqual(ws.cells,cells)
97 self.assertEquals(ws.name,u'foo')
97 self.assertEqual(ws.name,u'foo')
98
98
99 class TestNotebook(TestCase):
99 class TestNotebook(TestCase):
100
100
101 def test_empty_notebook(self):
101 def test_empty_notebook(self):
102 nb = new_notebook()
102 nb = new_notebook()
103 self.assertEquals(nb.worksheets, [])
103 self.assertEqual(nb.worksheets, [])
104 self.assertEquals(nb.metadata, NotebookNode())
104 self.assertEqual(nb.metadata, NotebookNode())
105 self.assertEquals(nb.nbformat,nbformat)
105 self.assertEqual(nb.nbformat,nbformat)
106
106
107 def test_notebook(self):
107 def test_notebook(self):
108 worksheets = [new_worksheet(),new_worksheet()]
108 worksheets = [new_worksheet(),new_worksheet()]
109 metadata = new_metadata(name=u'foo')
109 metadata = new_metadata(name=u'foo')
110 nb = new_notebook(metadata=metadata,worksheets=worksheets)
110 nb = new_notebook(metadata=metadata,worksheets=worksheets)
111 self.assertEquals(nb.metadata.name,u'foo')
111 self.assertEqual(nb.metadata.name,u'foo')
112 self.assertEquals(nb.worksheets,worksheets)
112 self.assertEqual(nb.worksheets,worksheets)
113 self.assertEquals(nb.nbformat,nbformat)
113 self.assertEqual(nb.nbformat,nbformat)
114
114
115 def test_notebook_name(self):
115 def test_notebook_name(self):
116 worksheets = [new_worksheet(),new_worksheet()]
116 worksheets = [new_worksheet(),new_worksheet()]
117 nb = new_notebook(name='foo',worksheets=worksheets)
117 nb = new_notebook(name='foo',worksheets=worksheets)
118 self.assertEquals(nb.metadata.name,u'foo')
118 self.assertEqual(nb.metadata.name,u'foo')
119 self.assertEquals(nb.worksheets,worksheets)
119 self.assertEqual(nb.worksheets,worksheets)
120 self.assertEquals(nb.nbformat,nbformat)
120 self.assertEqual(nb.nbformat,nbformat)
121
121
122 class TestMetadata(TestCase):
122 class TestMetadata(TestCase):
123
123
124 def test_empty_metadata(self):
124 def test_empty_metadata(self):
125 md = new_metadata()
125 md = new_metadata()
126 self.assertEquals(u'name' not in md, True)
126 self.assertEqual(u'name' not in md, True)
127 self.assertEquals(u'authors' not in md, True)
127 self.assertEqual(u'authors' not in md, True)
128 self.assertEquals(u'license' not in md, True)
128 self.assertEqual(u'license' not in md, True)
129 self.assertEquals(u'saved' not in md, True)
129 self.assertEqual(u'saved' not in md, True)
130 self.assertEquals(u'modified' not in md, True)
130 self.assertEqual(u'modified' not in md, True)
131 self.assertEquals(u'gistid' not in md, True)
131 self.assertEqual(u'gistid' not in md, True)
132
132
133 def test_metadata(self):
133 def test_metadata(self):
134 authors = [new_author(name='Bart Simpson',email='bsimpson@fox.com')]
134 authors = [new_author(name='Bart Simpson',email='bsimpson@fox.com')]
135 md = new_metadata(name=u'foo',license=u'BSD',created=u'today',
135 md = new_metadata(name=u'foo',license=u'BSD',created=u'today',
136 modified=u'now',gistid=u'21341231',authors=authors)
136 modified=u'now',gistid=u'21341231',authors=authors)
137 self.assertEquals(md.name, u'foo')
137 self.assertEqual(md.name, u'foo')
138 self.assertEquals(md.license, u'BSD')
138 self.assertEqual(md.license, u'BSD')
139 self.assertEquals(md.created, u'today')
139 self.assertEqual(md.created, u'today')
140 self.assertEquals(md.modified, u'now')
140 self.assertEqual(md.modified, u'now')
141 self.assertEquals(md.gistid, u'21341231')
141 self.assertEqual(md.gistid, u'21341231')
142 self.assertEquals(md.authors, authors)
142 self.assertEqual(md.authors, authors)
143
143
@@ -1,46 +1,46 b''
1 # -*- coding: utf8 -*-
1 # -*- coding: utf8 -*-
2
2
3 from unittest import TestCase
3 from unittest import TestCase
4
4
5 from . import formattest
5 from . import formattest
6
6
7 from .. import nbpy
7 from .. import nbpy
8 from .nbexamples import nb0, nb0_py
8 from .nbexamples import nb0, nb0_py
9
9
10
10
11 class TestPy(formattest.NBFormatTest, TestCase):
11 class TestPy(formattest.NBFormatTest, TestCase):
12
12
13 nb0_ref = nb0_py
13 nb0_ref = nb0_py
14 ext = 'py'
14 ext = 'py'
15 mod = nbpy
15 mod = nbpy
16 ignored_keys = ['collapsed', 'outputs', 'prompt_number', 'metadata']
16 ignored_keys = ['collapsed', 'outputs', 'prompt_number', 'metadata']
17
17
18 def assertSubset(self, da, db):
18 def assertSubset(self, da, db):
19 """assert that da is a subset of db, ignoring self.ignored_keys.
19 """assert that da is a subset of db, ignoring self.ignored_keys.
20
20
21 Called recursively on containers, ultimately comparing individual
21 Called recursively on containers, ultimately comparing individual
22 elements.
22 elements.
23 """
23 """
24 if isinstance(da, dict):
24 if isinstance(da, dict):
25 for k,v in da.iteritems():
25 for k,v in da.iteritems():
26 if k in self.ignored_keys:
26 if k in self.ignored_keys:
27 continue
27 continue
28 self.assertTrue(k in db)
28 self.assertTrue(k in db)
29 self.assertSubset(v, db[k])
29 self.assertSubset(v, db[k])
30 elif isinstance(da, list):
30 elif isinstance(da, list):
31 for a,b in zip(da, db):
31 for a,b in zip(da, db):
32 self.assertSubset(a,b)
32 self.assertSubset(a,b)
33 else:
33 else:
34 if isinstance(da, basestring) and isinstance(db, basestring):
34 if isinstance(da, basestring) and isinstance(db, basestring):
35 # pyfile is not sensitive to preserving leading/trailing
35 # pyfile is not sensitive to preserving leading/trailing
36 # newlines in blocks through roundtrip
36 # newlines in blocks through roundtrip
37 da = da.strip('\n')
37 da = da.strip('\n')
38 db = db.strip('\n')
38 db = db.strip('\n')
39 self.assertEquals(da, db)
39 self.assertEqual(da, db)
40 return True
40 return True
41
41
42 def assertNBEquals(self, nba, nbb):
42 def assertNBEquals(self, nba, nbb):
43 # since roundtrip is lossy, only compare keys that are preserved
43 # since roundtrip is lossy, only compare keys that are preserved
44 # assumes nba is read from my file format
44 # assumes nba is read from my file format
45 return self.assertSubset(nba, nbb)
45 return self.assertSubset(nba, nbb)
46
46
@@ -1,184 +1,184 b''
1 """base class for parallel client tests
1 """base class for parallel client tests
2
2
3 Authors:
3 Authors:
4
4
5 * Min RK
5 * Min RK
6 """
6 """
7
7
8 #-------------------------------------------------------------------------------
8 #-------------------------------------------------------------------------------
9 # Copyright (C) 2011 The IPython Development Team
9 # Copyright (C) 2011 The IPython Development Team
10 #
10 #
11 # Distributed under the terms of the BSD License. The full license is in
11 # Distributed under the terms of the BSD License. The full license is in
12 # the file COPYING, distributed as part of this software.
12 # the file COPYING, distributed as part of this software.
13 #-------------------------------------------------------------------------------
13 #-------------------------------------------------------------------------------
14 from __future__ import print_function
14 from __future__ import print_function
15
15
16 import sys
16 import sys
17 import tempfile
17 import tempfile
18 import time
18 import time
19 from StringIO import StringIO
19 from StringIO import StringIO
20
20
21 from nose import SkipTest
21 from nose import SkipTest
22
22
23 import zmq
23 import zmq
24 from zmq.tests import BaseZMQTestCase
24 from zmq.tests import BaseZMQTestCase
25
25
26 from IPython.external.decorator import decorator
26 from IPython.external.decorator import decorator
27
27
28 from IPython.parallel import error
28 from IPython.parallel import error
29 from IPython.parallel import Client
29 from IPython.parallel import Client
30
30
31 from IPython.parallel.tests import launchers, add_engines
31 from IPython.parallel.tests import launchers, add_engines
32
32
33 # simple tasks for use in apply tests
33 # simple tasks for use in apply tests
34
34
35 def segfault():
35 def segfault():
36 """this will segfault"""
36 """this will segfault"""
37 import ctypes
37 import ctypes
38 ctypes.memset(-1,0,1)
38 ctypes.memset(-1,0,1)
39
39
40 def crash():
40 def crash():
41 """from stdlib crashers in the test suite"""
41 """from stdlib crashers in the test suite"""
42 import types
42 import types
43 if sys.platform.startswith('win'):
43 if sys.platform.startswith('win'):
44 import ctypes
44 import ctypes
45 ctypes.windll.kernel32.SetErrorMode(0x0002);
45 ctypes.windll.kernel32.SetErrorMode(0x0002);
46 args = [ 0, 0, 0, 0, b'\x04\x71\x00\x00', (), (), (), '', '', 1, b'']
46 args = [ 0, 0, 0, 0, b'\x04\x71\x00\x00', (), (), (), '', '', 1, b'']
47 if sys.version_info[0] >= 3:
47 if sys.version_info[0] >= 3:
48 # Python3 adds 'kwonlyargcount' as the second argument to Code
48 # Python3 adds 'kwonlyargcount' as the second argument to Code
49 args.insert(1, 0)
49 args.insert(1, 0)
50
50
51 co = types.CodeType(*args)
51 co = types.CodeType(*args)
52 exec(co)
52 exec(co)
53
53
54 def wait(n):
54 def wait(n):
55 """sleep for a time"""
55 """sleep for a time"""
56 import time
56 import time
57 time.sleep(n)
57 time.sleep(n)
58 return n
58 return n
59
59
60 def raiser(eclass):
60 def raiser(eclass):
61 """raise an exception"""
61 """raise an exception"""
62 raise eclass()
62 raise eclass()
63
63
64 def generate_output():
64 def generate_output():
65 """function for testing output
65 """function for testing output
66
66
67 publishes two outputs of each type, and returns
67 publishes two outputs of each type, and returns
68 a rich displayable object.
68 a rich displayable object.
69 """
69 """
70
70
71 import sys
71 import sys
72 from IPython.core.display import display, HTML, Math
72 from IPython.core.display import display, HTML, Math
73
73
74 print("stdout")
74 print("stdout")
75 print("stderr", file=sys.stderr)
75 print("stderr", file=sys.stderr)
76
76
77 display(HTML("<b>HTML</b>"))
77 display(HTML("<b>HTML</b>"))
78
78
79 print("stdout2")
79 print("stdout2")
80 print("stderr2", file=sys.stderr)
80 print("stderr2", file=sys.stderr)
81
81
82 display(Math(r"\alpha=\beta"))
82 display(Math(r"\alpha=\beta"))
83
83
84 return Math("42")
84 return Math("42")
85
85
86 # test decorator for skipping tests when libraries are unavailable
86 # test decorator for skipping tests when libraries are unavailable
87 def skip_without(*names):
87 def skip_without(*names):
88 """skip a test if some names are not importable"""
88 """skip a test if some names are not importable"""
89 @decorator
89 @decorator
90 def skip_without_names(f, *args, **kwargs):
90 def skip_without_names(f, *args, **kwargs):
91 """decorator to skip tests in the absence of numpy."""
91 """decorator to skip tests in the absence of numpy."""
92 for name in names:
92 for name in names:
93 try:
93 try:
94 __import__(name)
94 __import__(name)
95 except ImportError:
95 except ImportError:
96 raise SkipTest
96 raise SkipTest
97 return f(*args, **kwargs)
97 return f(*args, **kwargs)
98 return skip_without_names
98 return skip_without_names
99
99
100 #-------------------------------------------------------------------------------
100 #-------------------------------------------------------------------------------
101 # Classes
101 # Classes
102 #-------------------------------------------------------------------------------
102 #-------------------------------------------------------------------------------
103
103
104
104
105 class ClusterTestCase(BaseZMQTestCase):
105 class ClusterTestCase(BaseZMQTestCase):
106
106
107 def add_engines(self, n=1, block=True):
107 def add_engines(self, n=1, block=True):
108 """add multiple engines to our cluster"""
108 """add multiple engines to our cluster"""
109 self.engines.extend(add_engines(n))
109 self.engines.extend(add_engines(n))
110 if block:
110 if block:
111 self.wait_on_engines()
111 self.wait_on_engines()
112
112
113 def minimum_engines(self, n=1, block=True):
113 def minimum_engines(self, n=1, block=True):
114 """add engines until there are at least n connected"""
114 """add engines until there are at least n connected"""
115 self.engines.extend(add_engines(n, total=True))
115 self.engines.extend(add_engines(n, total=True))
116 if block:
116 if block:
117 self.wait_on_engines()
117 self.wait_on_engines()
118
118
119
119
120 def wait_on_engines(self, timeout=5):
120 def wait_on_engines(self, timeout=5):
121 """wait for our engines to connect."""
121 """wait for our engines to connect."""
122 n = len(self.engines)+self.base_engine_count
122 n = len(self.engines)+self.base_engine_count
123 tic = time.time()
123 tic = time.time()
124 while time.time()-tic < timeout and len(self.client.ids) < n:
124 while time.time()-tic < timeout and len(self.client.ids) < n:
125 time.sleep(0.1)
125 time.sleep(0.1)
126
126
127 assert not len(self.client.ids) < n, "waiting for engines timed out"
127 assert not len(self.client.ids) < n, "waiting for engines timed out"
128
128
129 def connect_client(self):
129 def connect_client(self):
130 """connect a client with my Context, and track its sockets for cleanup"""
130 """connect a client with my Context, and track its sockets for cleanup"""
131 c = Client(profile='iptest', context=self.context)
131 c = Client(profile='iptest', context=self.context)
132 for name in filter(lambda n:n.endswith('socket'), dir(c)):
132 for name in filter(lambda n:n.endswith('socket'), dir(c)):
133 s = getattr(c, name)
133 s = getattr(c, name)
134 s.setsockopt(zmq.LINGER, 0)
134 s.setsockopt(zmq.LINGER, 0)
135 self.sockets.append(s)
135 self.sockets.append(s)
136 return c
136 return c
137
137
138 def assertRaisesRemote(self, etype, f, *args, **kwargs):
138 def assertRaisesRemote(self, etype, f, *args, **kwargs):
139 try:
139 try:
140 try:
140 try:
141 f(*args, **kwargs)
141 f(*args, **kwargs)
142 except error.CompositeError as e:
142 except error.CompositeError as e:
143 e.raise_exception()
143 e.raise_exception()
144 except error.RemoteError as e:
144 except error.RemoteError as e:
145 self.assertEquals(etype.__name__, e.ename, "Should have raised %r, but raised %r"%(etype.__name__, e.ename))
145 self.assertEqual(etype.__name__, e.ename, "Should have raised %r, but raised %r"%(etype.__name__, e.ename))
146 else:
146 else:
147 self.fail("should have raised a RemoteError")
147 self.fail("should have raised a RemoteError")
148
148
149 def _wait_for(self, f, timeout=10):
149 def _wait_for(self, f, timeout=10):
150 """wait for a condition"""
150 """wait for a condition"""
151 tic = time.time()
151 tic = time.time()
152 while time.time() <= tic + timeout:
152 while time.time() <= tic + timeout:
153 if f():
153 if f():
154 return
154 return
155 time.sleep(0.1)
155 time.sleep(0.1)
156 self.client.spin()
156 self.client.spin()
157 if not f():
157 if not f():
158 print("Warning: Awaited condition never arrived")
158 print("Warning: Awaited condition never arrived")
159
159
160 def setUp(self):
160 def setUp(self):
161 BaseZMQTestCase.setUp(self)
161 BaseZMQTestCase.setUp(self)
162 self.client = self.connect_client()
162 self.client = self.connect_client()
163 # start every test with clean engine namespaces:
163 # start every test with clean engine namespaces:
164 self.client.clear(block=True)
164 self.client.clear(block=True)
165 self.base_engine_count=len(self.client.ids)
165 self.base_engine_count=len(self.client.ids)
166 self.engines=[]
166 self.engines=[]
167
167
168 def tearDown(self):
168 def tearDown(self):
169 # self.client.clear(block=True)
169 # self.client.clear(block=True)
170 # close fds:
170 # close fds:
171 for e in filter(lambda e: e.poll() is not None, launchers):
171 for e in filter(lambda e: e.poll() is not None, launchers):
172 launchers.remove(e)
172 launchers.remove(e)
173
173
174 # allow flushing of incoming messages to prevent crash on socket close
174 # allow flushing of incoming messages to prevent crash on socket close
175 self.client.wait(timeout=2)
175 self.client.wait(timeout=2)
176 # time.sleep(2)
176 # time.sleep(2)
177 self.client.spin()
177 self.client.spin()
178 self.client.close()
178 self.client.close()
179 BaseZMQTestCase.tearDown(self)
179 BaseZMQTestCase.tearDown(self)
180 # this will be redundant when pyzmq merges PR #88
180 # this will be redundant when pyzmq merges PR #88
181 # self.context.term()
181 # self.context.term()
182 # print tempfile.TemporaryFile().fileno(),
182 # print tempfile.TemporaryFile().fileno(),
183 # sys.stdout.flush()
183 # sys.stdout.flush()
184
184
@@ -1,267 +1,267 b''
1 """Tests for asyncresult.py
1 """Tests for asyncresult.py
2
2
3 Authors:
3 Authors:
4
4
5 * Min RK
5 * Min RK
6 """
6 """
7
7
8 #-------------------------------------------------------------------------------
8 #-------------------------------------------------------------------------------
9 # Copyright (C) 2011 The IPython Development Team
9 # Copyright (C) 2011 The IPython Development Team
10 #
10 #
11 # Distributed under the terms of the BSD License. The full license is in
11 # Distributed under the terms of the BSD License. The full license is in
12 # the file COPYING, distributed as part of this software.
12 # the file COPYING, distributed as part of this software.
13 #-------------------------------------------------------------------------------
13 #-------------------------------------------------------------------------------
14
14
15 #-------------------------------------------------------------------------------
15 #-------------------------------------------------------------------------------
16 # Imports
16 # Imports
17 #-------------------------------------------------------------------------------
17 #-------------------------------------------------------------------------------
18
18
19 import time
19 import time
20
20
21 from IPython.utils.io import capture_output
21 from IPython.utils.io import capture_output
22
22
23 from IPython.parallel.error import TimeoutError
23 from IPython.parallel.error import TimeoutError
24 from IPython.parallel import error, Client
24 from IPython.parallel import error, Client
25 from IPython.parallel.tests import add_engines
25 from IPython.parallel.tests import add_engines
26 from .clienttest import ClusterTestCase
26 from .clienttest import ClusterTestCase
27
27
28 def setup():
28 def setup():
29 add_engines(2, total=True)
29 add_engines(2, total=True)
30
30
31 def wait(n):
31 def wait(n):
32 import time
32 import time
33 time.sleep(n)
33 time.sleep(n)
34 return n
34 return n
35
35
36 class AsyncResultTest(ClusterTestCase):
36 class AsyncResultTest(ClusterTestCase):
37
37
38 def test_single_result_view(self):
38 def test_single_result_view(self):
39 """various one-target views get the right value for single_result"""
39 """various one-target views get the right value for single_result"""
40 eid = self.client.ids[-1]
40 eid = self.client.ids[-1]
41 ar = self.client[eid].apply_async(lambda : 42)
41 ar = self.client[eid].apply_async(lambda : 42)
42 self.assertEquals(ar.get(), 42)
42 self.assertEqual(ar.get(), 42)
43 ar = self.client[[eid]].apply_async(lambda : 42)
43 ar = self.client[[eid]].apply_async(lambda : 42)
44 self.assertEquals(ar.get(), [42])
44 self.assertEqual(ar.get(), [42])
45 ar = self.client[-1:].apply_async(lambda : 42)
45 ar = self.client[-1:].apply_async(lambda : 42)
46 self.assertEquals(ar.get(), [42])
46 self.assertEqual(ar.get(), [42])
47
47
48 def test_get_after_done(self):
48 def test_get_after_done(self):
49 ar = self.client[-1].apply_async(lambda : 42)
49 ar = self.client[-1].apply_async(lambda : 42)
50 ar.wait()
50 ar.wait()
51 self.assertTrue(ar.ready())
51 self.assertTrue(ar.ready())
52 self.assertEquals(ar.get(), 42)
52 self.assertEqual(ar.get(), 42)
53 self.assertEquals(ar.get(), 42)
53 self.assertEqual(ar.get(), 42)
54
54
55 def test_get_before_done(self):
55 def test_get_before_done(self):
56 ar = self.client[-1].apply_async(wait, 0.1)
56 ar = self.client[-1].apply_async(wait, 0.1)
57 self.assertRaises(TimeoutError, ar.get, 0)
57 self.assertRaises(TimeoutError, ar.get, 0)
58 ar.wait(0)
58 ar.wait(0)
59 self.assertFalse(ar.ready())
59 self.assertFalse(ar.ready())
60 self.assertEquals(ar.get(), 0.1)
60 self.assertEqual(ar.get(), 0.1)
61
61
62 def test_get_after_error(self):
62 def test_get_after_error(self):
63 ar = self.client[-1].apply_async(lambda : 1/0)
63 ar = self.client[-1].apply_async(lambda : 1/0)
64 ar.wait(10)
64 ar.wait(10)
65 self.assertRaisesRemote(ZeroDivisionError, ar.get)
65 self.assertRaisesRemote(ZeroDivisionError, ar.get)
66 self.assertRaisesRemote(ZeroDivisionError, ar.get)
66 self.assertRaisesRemote(ZeroDivisionError, ar.get)
67 self.assertRaisesRemote(ZeroDivisionError, ar.get_dict)
67 self.assertRaisesRemote(ZeroDivisionError, ar.get_dict)
68
68
69 def test_get_dict(self):
69 def test_get_dict(self):
70 n = len(self.client)
70 n = len(self.client)
71 ar = self.client[:].apply_async(lambda : 5)
71 ar = self.client[:].apply_async(lambda : 5)
72 self.assertEquals(ar.get(), [5]*n)
72 self.assertEqual(ar.get(), [5]*n)
73 d = ar.get_dict()
73 d = ar.get_dict()
74 self.assertEquals(sorted(d.keys()), sorted(self.client.ids))
74 self.assertEqual(sorted(d.keys()), sorted(self.client.ids))
75 for eid,r in d.iteritems():
75 for eid,r in d.iteritems():
76 self.assertEquals(r, 5)
76 self.assertEqual(r, 5)
77
77
78 def test_list_amr(self):
78 def test_list_amr(self):
79 ar = self.client.load_balanced_view().map_async(wait, [0.1]*5)
79 ar = self.client.load_balanced_view().map_async(wait, [0.1]*5)
80 rlist = list(ar)
80 rlist = list(ar)
81
81
82 def test_getattr(self):
82 def test_getattr(self):
83 ar = self.client[:].apply_async(wait, 0.5)
83 ar = self.client[:].apply_async(wait, 0.5)
84 self.assertRaises(AttributeError, lambda : ar._foo)
84 self.assertRaises(AttributeError, lambda : ar._foo)
85 self.assertRaises(AttributeError, lambda : ar.__length_hint__())
85 self.assertRaises(AttributeError, lambda : ar.__length_hint__())
86 self.assertRaises(AttributeError, lambda : ar.foo)
86 self.assertRaises(AttributeError, lambda : ar.foo)
87 self.assertRaises(AttributeError, lambda : ar.engine_id)
87 self.assertRaises(AttributeError, lambda : ar.engine_id)
88 self.assertFalse(hasattr(ar, '__length_hint__'))
88 self.assertFalse(hasattr(ar, '__length_hint__'))
89 self.assertFalse(hasattr(ar, 'foo'))
89 self.assertFalse(hasattr(ar, 'foo'))
90 self.assertFalse(hasattr(ar, 'engine_id'))
90 self.assertFalse(hasattr(ar, 'engine_id'))
91 ar.get(5)
91 ar.get(5)
92 self.assertRaises(AttributeError, lambda : ar._foo)
92 self.assertRaises(AttributeError, lambda : ar._foo)
93 self.assertRaises(AttributeError, lambda : ar.__length_hint__())
93 self.assertRaises(AttributeError, lambda : ar.__length_hint__())
94 self.assertRaises(AttributeError, lambda : ar.foo)
94 self.assertRaises(AttributeError, lambda : ar.foo)
95 self.assertTrue(isinstance(ar.engine_id, list))
95 self.assertTrue(isinstance(ar.engine_id, list))
96 self.assertEquals(ar.engine_id, ar['engine_id'])
96 self.assertEqual(ar.engine_id, ar['engine_id'])
97 self.assertFalse(hasattr(ar, '__length_hint__'))
97 self.assertFalse(hasattr(ar, '__length_hint__'))
98 self.assertFalse(hasattr(ar, 'foo'))
98 self.assertFalse(hasattr(ar, 'foo'))
99 self.assertTrue(hasattr(ar, 'engine_id'))
99 self.assertTrue(hasattr(ar, 'engine_id'))
100
100
101 def test_getitem(self):
101 def test_getitem(self):
102 ar = self.client[:].apply_async(wait, 0.5)
102 ar = self.client[:].apply_async(wait, 0.5)
103 self.assertRaises(TimeoutError, lambda : ar['foo'])
103 self.assertRaises(TimeoutError, lambda : ar['foo'])
104 self.assertRaises(TimeoutError, lambda : ar['engine_id'])
104 self.assertRaises(TimeoutError, lambda : ar['engine_id'])
105 ar.get(5)
105 ar.get(5)
106 self.assertRaises(KeyError, lambda : ar['foo'])
106 self.assertRaises(KeyError, lambda : ar['foo'])
107 self.assertTrue(isinstance(ar['engine_id'], list))
107 self.assertTrue(isinstance(ar['engine_id'], list))
108 self.assertEquals(ar.engine_id, ar['engine_id'])
108 self.assertEqual(ar.engine_id, ar['engine_id'])
109
109
110 def test_single_result(self):
110 def test_single_result(self):
111 ar = self.client[-1].apply_async(wait, 0.5)
111 ar = self.client[-1].apply_async(wait, 0.5)
112 self.assertRaises(TimeoutError, lambda : ar['foo'])
112 self.assertRaises(TimeoutError, lambda : ar['foo'])
113 self.assertRaises(TimeoutError, lambda : ar['engine_id'])
113 self.assertRaises(TimeoutError, lambda : ar['engine_id'])
114 self.assertTrue(ar.get(5) == 0.5)
114 self.assertTrue(ar.get(5) == 0.5)
115 self.assertTrue(isinstance(ar['engine_id'], int))
115 self.assertTrue(isinstance(ar['engine_id'], int))
116 self.assertTrue(isinstance(ar.engine_id, int))
116 self.assertTrue(isinstance(ar.engine_id, int))
117 self.assertEquals(ar.engine_id, ar['engine_id'])
117 self.assertEqual(ar.engine_id, ar['engine_id'])
118
118
119 def test_abort(self):
119 def test_abort(self):
120 e = self.client[-1]
120 e = self.client[-1]
121 ar = e.execute('import time; time.sleep(1)', block=False)
121 ar = e.execute('import time; time.sleep(1)', block=False)
122 ar2 = e.apply_async(lambda : 2)
122 ar2 = e.apply_async(lambda : 2)
123 ar2.abort()
123 ar2.abort()
124 self.assertRaises(error.TaskAborted, ar2.get)
124 self.assertRaises(error.TaskAborted, ar2.get)
125 ar.get()
125 ar.get()
126
126
127 def test_len(self):
127 def test_len(self):
128 v = self.client.load_balanced_view()
128 v = self.client.load_balanced_view()
129 ar = v.map_async(lambda x: x, range(10))
129 ar = v.map_async(lambda x: x, range(10))
130 self.assertEquals(len(ar), 10)
130 self.assertEqual(len(ar), 10)
131 ar = v.apply_async(lambda x: x, range(10))
131 ar = v.apply_async(lambda x: x, range(10))
132 self.assertEquals(len(ar), 1)
132 self.assertEqual(len(ar), 1)
133 ar = self.client[:].apply_async(lambda x: x, range(10))
133 ar = self.client[:].apply_async(lambda x: x, range(10))
134 self.assertEquals(len(ar), len(self.client.ids))
134 self.assertEqual(len(ar), len(self.client.ids))
135
135
136 def test_wall_time_single(self):
136 def test_wall_time_single(self):
137 v = self.client.load_balanced_view()
137 v = self.client.load_balanced_view()
138 ar = v.apply_async(time.sleep, 0.25)
138 ar = v.apply_async(time.sleep, 0.25)
139 self.assertRaises(TimeoutError, getattr, ar, 'wall_time')
139 self.assertRaises(TimeoutError, getattr, ar, 'wall_time')
140 ar.get(2)
140 ar.get(2)
141 self.assertTrue(ar.wall_time < 1.)
141 self.assertTrue(ar.wall_time < 1.)
142 self.assertTrue(ar.wall_time > 0.2)
142 self.assertTrue(ar.wall_time > 0.2)
143
143
144 def test_wall_time_multi(self):
144 def test_wall_time_multi(self):
145 self.minimum_engines(4)
145 self.minimum_engines(4)
146 v = self.client[:]
146 v = self.client[:]
147 ar = v.apply_async(time.sleep, 0.25)
147 ar = v.apply_async(time.sleep, 0.25)
148 self.assertRaises(TimeoutError, getattr, ar, 'wall_time')
148 self.assertRaises(TimeoutError, getattr, ar, 'wall_time')
149 ar.get(2)
149 ar.get(2)
150 self.assertTrue(ar.wall_time < 1.)
150 self.assertTrue(ar.wall_time < 1.)
151 self.assertTrue(ar.wall_time > 0.2)
151 self.assertTrue(ar.wall_time > 0.2)
152
152
153 def test_serial_time_single(self):
153 def test_serial_time_single(self):
154 v = self.client.load_balanced_view()
154 v = self.client.load_balanced_view()
155 ar = v.apply_async(time.sleep, 0.25)
155 ar = v.apply_async(time.sleep, 0.25)
156 self.assertRaises(TimeoutError, getattr, ar, 'serial_time')
156 self.assertRaises(TimeoutError, getattr, ar, 'serial_time')
157 ar.get(2)
157 ar.get(2)
158 self.assertTrue(ar.serial_time < 1.)
158 self.assertTrue(ar.serial_time < 1.)
159 self.assertTrue(ar.serial_time > 0.2)
159 self.assertTrue(ar.serial_time > 0.2)
160
160
161 def test_serial_time_multi(self):
161 def test_serial_time_multi(self):
162 self.minimum_engines(4)
162 self.minimum_engines(4)
163 v = self.client[:]
163 v = self.client[:]
164 ar = v.apply_async(time.sleep, 0.25)
164 ar = v.apply_async(time.sleep, 0.25)
165 self.assertRaises(TimeoutError, getattr, ar, 'serial_time')
165 self.assertRaises(TimeoutError, getattr, ar, 'serial_time')
166 ar.get(2)
166 ar.get(2)
167 self.assertTrue(ar.serial_time < 2.)
167 self.assertTrue(ar.serial_time < 2.)
168 self.assertTrue(ar.serial_time > 0.8)
168 self.assertTrue(ar.serial_time > 0.8)
169
169
170 def test_elapsed_single(self):
170 def test_elapsed_single(self):
171 v = self.client.load_balanced_view()
171 v = self.client.load_balanced_view()
172 ar = v.apply_async(time.sleep, 0.25)
172 ar = v.apply_async(time.sleep, 0.25)
173 while not ar.ready():
173 while not ar.ready():
174 time.sleep(0.01)
174 time.sleep(0.01)
175 self.assertTrue(ar.elapsed < 1)
175 self.assertTrue(ar.elapsed < 1)
176 self.assertTrue(ar.elapsed < 1)
176 self.assertTrue(ar.elapsed < 1)
177 ar.get(2)
177 ar.get(2)
178
178
179 def test_elapsed_multi(self):
179 def test_elapsed_multi(self):
180 v = self.client[:]
180 v = self.client[:]
181 ar = v.apply_async(time.sleep, 0.25)
181 ar = v.apply_async(time.sleep, 0.25)
182 while not ar.ready():
182 while not ar.ready():
183 time.sleep(0.01)
183 time.sleep(0.01)
184 self.assertTrue(ar.elapsed < 1)
184 self.assertTrue(ar.elapsed < 1)
185 self.assertTrue(ar.elapsed < 1)
185 self.assertTrue(ar.elapsed < 1)
186 ar.get(2)
186 ar.get(2)
187
187
188 def test_hubresult_timestamps(self):
188 def test_hubresult_timestamps(self):
189 self.minimum_engines(4)
189 self.minimum_engines(4)
190 v = self.client[:]
190 v = self.client[:]
191 ar = v.apply_async(time.sleep, 0.25)
191 ar = v.apply_async(time.sleep, 0.25)
192 ar.get(2)
192 ar.get(2)
193 rc2 = Client(profile='iptest')
193 rc2 = Client(profile='iptest')
194 # must have try/finally to close second Client, otherwise
194 # must have try/finally to close second Client, otherwise
195 # will have dangling sockets causing problems
195 # will have dangling sockets causing problems
196 try:
196 try:
197 time.sleep(0.25)
197 time.sleep(0.25)
198 hr = rc2.get_result(ar.msg_ids)
198 hr = rc2.get_result(ar.msg_ids)
199 self.assertTrue(hr.elapsed > 0., "got bad elapsed: %s" % hr.elapsed)
199 self.assertTrue(hr.elapsed > 0., "got bad elapsed: %s" % hr.elapsed)
200 hr.get(1)
200 hr.get(1)
201 self.assertTrue(hr.wall_time < ar.wall_time + 0.2, "got bad wall_time: %s > %s" % (hr.wall_time, ar.wall_time))
201 self.assertTrue(hr.wall_time < ar.wall_time + 0.2, "got bad wall_time: %s > %s" % (hr.wall_time, ar.wall_time))
202 self.assertEquals(hr.serial_time, ar.serial_time)
202 self.assertEqual(hr.serial_time, ar.serial_time)
203 finally:
203 finally:
204 rc2.close()
204 rc2.close()
205
205
206 def test_display_empty_streams_single(self):
206 def test_display_empty_streams_single(self):
207 """empty stdout/err are not displayed (single result)"""
207 """empty stdout/err are not displayed (single result)"""
208 self.minimum_engines(1)
208 self.minimum_engines(1)
209
209
210 v = self.client[-1]
210 v = self.client[-1]
211 ar = v.execute("print (5555)")
211 ar = v.execute("print (5555)")
212 ar.get(5)
212 ar.get(5)
213 with capture_output() as io:
213 with capture_output() as io:
214 ar.display_outputs()
214 ar.display_outputs()
215 self.assertEquals(io.stderr, '')
215 self.assertEqual(io.stderr, '')
216 self.assertEquals('5555\n', io.stdout)
216 self.assertEqual('5555\n', io.stdout)
217
217
218 ar = v.execute("a=5")
218 ar = v.execute("a=5")
219 ar.get(5)
219 ar.get(5)
220 with capture_output() as io:
220 with capture_output() as io:
221 ar.display_outputs()
221 ar.display_outputs()
222 self.assertEquals(io.stderr, '')
222 self.assertEqual(io.stderr, '')
223 self.assertEquals(io.stdout, '')
223 self.assertEqual(io.stdout, '')
224
224
225 def test_display_empty_streams_type(self):
225 def test_display_empty_streams_type(self):
226 """empty stdout/err are not displayed (groupby type)"""
226 """empty stdout/err are not displayed (groupby type)"""
227 self.minimum_engines(1)
227 self.minimum_engines(1)
228
228
229 v = self.client[:]
229 v = self.client[:]
230 ar = v.execute("print (5555)")
230 ar = v.execute("print (5555)")
231 ar.get(5)
231 ar.get(5)
232 with capture_output() as io:
232 with capture_output() as io:
233 ar.display_outputs()
233 ar.display_outputs()
234 self.assertEquals(io.stderr, '')
234 self.assertEqual(io.stderr, '')
235 self.assertEquals(io.stdout.count('5555'), len(v), io.stdout)
235 self.assertEqual(io.stdout.count('5555'), len(v), io.stdout)
236 self.assertFalse('\n\n' in io.stdout, io.stdout)
236 self.assertFalse('\n\n' in io.stdout, io.stdout)
237 self.assertEquals(io.stdout.count('[stdout:'), len(v), io.stdout)
237 self.assertEqual(io.stdout.count('[stdout:'), len(v), io.stdout)
238
238
239 ar = v.execute("a=5")
239 ar = v.execute("a=5")
240 ar.get(5)
240 ar.get(5)
241 with capture_output() as io:
241 with capture_output() as io:
242 ar.display_outputs()
242 ar.display_outputs()
243 self.assertEquals(io.stderr, '')
243 self.assertEqual(io.stderr, '')
244 self.assertEquals(io.stdout, '')
244 self.assertEqual(io.stdout, '')
245
245
246 def test_display_empty_streams_engine(self):
246 def test_display_empty_streams_engine(self):
247 """empty stdout/err are not displayed (groupby engine)"""
247 """empty stdout/err are not displayed (groupby engine)"""
248 self.minimum_engines(1)
248 self.minimum_engines(1)
249
249
250 v = self.client[:]
250 v = self.client[:]
251 ar = v.execute("print (5555)")
251 ar = v.execute("print (5555)")
252 ar.get(5)
252 ar.get(5)
253 with capture_output() as io:
253 with capture_output() as io:
254 ar.display_outputs('engine')
254 ar.display_outputs('engine')
255 self.assertEquals(io.stderr, '')
255 self.assertEqual(io.stderr, '')
256 self.assertEquals(io.stdout.count('5555'), len(v), io.stdout)
256 self.assertEqual(io.stdout.count('5555'), len(v), io.stdout)
257 self.assertFalse('\n\n' in io.stdout, io.stdout)
257 self.assertFalse('\n\n' in io.stdout, io.stdout)
258 self.assertEquals(io.stdout.count('[stdout:'), len(v), io.stdout)
258 self.assertEqual(io.stdout.count('[stdout:'), len(v), io.stdout)
259
259
260 ar = v.execute("a=5")
260 ar = v.execute("a=5")
261 ar.get(5)
261 ar.get(5)
262 with capture_output() as io:
262 with capture_output() as io:
263 ar.display_outputs('engine')
263 ar.display_outputs('engine')
264 self.assertEquals(io.stderr, '')
264 self.assertEqual(io.stderr, '')
265 self.assertEquals(io.stdout, '')
265 self.assertEqual(io.stdout, '')
266
266
267
267
@@ -1,455 +1,455 b''
1 """Tests for parallel client.py
1 """Tests for parallel client.py
2
2
3 Authors:
3 Authors:
4
4
5 * Min RK
5 * Min RK
6 """
6 """
7
7
8 #-------------------------------------------------------------------------------
8 #-------------------------------------------------------------------------------
9 # Copyright (C) 2011 The IPython Development Team
9 # Copyright (C) 2011 The IPython Development Team
10 #
10 #
11 # Distributed under the terms of the BSD License. The full license is in
11 # Distributed under the terms of the BSD License. The full license is in
12 # the file COPYING, distributed as part of this software.
12 # the file COPYING, distributed as part of this software.
13 #-------------------------------------------------------------------------------
13 #-------------------------------------------------------------------------------
14
14
15 #-------------------------------------------------------------------------------
15 #-------------------------------------------------------------------------------
16 # Imports
16 # Imports
17 #-------------------------------------------------------------------------------
17 #-------------------------------------------------------------------------------
18
18
19 from __future__ import division
19 from __future__ import division
20
20
21 import time
21 import time
22 from datetime import datetime
22 from datetime import datetime
23 from tempfile import mktemp
23 from tempfile import mktemp
24
24
25 import zmq
25 import zmq
26
26
27 from IPython import parallel
27 from IPython import parallel
28 from IPython.parallel.client import client as clientmod
28 from IPython.parallel.client import client as clientmod
29 from IPython.parallel import error
29 from IPython.parallel import error
30 from IPython.parallel import AsyncResult, AsyncHubResult
30 from IPython.parallel import AsyncResult, AsyncHubResult
31 from IPython.parallel import LoadBalancedView, DirectView
31 from IPython.parallel import LoadBalancedView, DirectView
32
32
33 from clienttest import ClusterTestCase, segfault, wait, add_engines
33 from clienttest import ClusterTestCase, segfault, wait, add_engines
34
34
35 def setup():
35 def setup():
36 add_engines(4, total=True)
36 add_engines(4, total=True)
37
37
38 class TestClient(ClusterTestCase):
38 class TestClient(ClusterTestCase):
39
39
40 def test_ids(self):
40 def test_ids(self):
41 n = len(self.client.ids)
41 n = len(self.client.ids)
42 self.add_engines(2)
42 self.add_engines(2)
43 self.assertEquals(len(self.client.ids), n+2)
43 self.assertEqual(len(self.client.ids), n+2)
44
44
45 def test_view_indexing(self):
45 def test_view_indexing(self):
46 """test index access for views"""
46 """test index access for views"""
47 self.minimum_engines(4)
47 self.minimum_engines(4)
48 targets = self.client._build_targets('all')[-1]
48 targets = self.client._build_targets('all')[-1]
49 v = self.client[:]
49 v = self.client[:]
50 self.assertEquals(v.targets, targets)
50 self.assertEqual(v.targets, targets)
51 t = self.client.ids[2]
51 t = self.client.ids[2]
52 v = self.client[t]
52 v = self.client[t]
53 self.assert_(isinstance(v, DirectView))
53 self.assert_(isinstance(v, DirectView))
54 self.assertEquals(v.targets, t)
54 self.assertEqual(v.targets, t)
55 t = self.client.ids[2:4]
55 t = self.client.ids[2:4]
56 v = self.client[t]
56 v = self.client[t]
57 self.assert_(isinstance(v, DirectView))
57 self.assert_(isinstance(v, DirectView))
58 self.assertEquals(v.targets, t)
58 self.assertEqual(v.targets, t)
59 v = self.client[::2]
59 v = self.client[::2]
60 self.assert_(isinstance(v, DirectView))
60 self.assert_(isinstance(v, DirectView))
61 self.assertEquals(v.targets, targets[::2])
61 self.assertEqual(v.targets, targets[::2])
62 v = self.client[1::3]
62 v = self.client[1::3]
63 self.assert_(isinstance(v, DirectView))
63 self.assert_(isinstance(v, DirectView))
64 self.assertEquals(v.targets, targets[1::3])
64 self.assertEqual(v.targets, targets[1::3])
65 v = self.client[:-3]
65 v = self.client[:-3]
66 self.assert_(isinstance(v, DirectView))
66 self.assert_(isinstance(v, DirectView))
67 self.assertEquals(v.targets, targets[:-3])
67 self.assertEqual(v.targets, targets[:-3])
68 v = self.client[-1]
68 v = self.client[-1]
69 self.assert_(isinstance(v, DirectView))
69 self.assert_(isinstance(v, DirectView))
70 self.assertEquals(v.targets, targets[-1])
70 self.assertEqual(v.targets, targets[-1])
71 self.assertRaises(TypeError, lambda : self.client[None])
71 self.assertRaises(TypeError, lambda : self.client[None])
72
72
73 def test_lbview_targets(self):
73 def test_lbview_targets(self):
74 """test load_balanced_view targets"""
74 """test load_balanced_view targets"""
75 v = self.client.load_balanced_view()
75 v = self.client.load_balanced_view()
76 self.assertEquals(v.targets, None)
76 self.assertEqual(v.targets, None)
77 v = self.client.load_balanced_view(-1)
77 v = self.client.load_balanced_view(-1)
78 self.assertEquals(v.targets, [self.client.ids[-1]])
78 self.assertEqual(v.targets, [self.client.ids[-1]])
79 v = self.client.load_balanced_view('all')
79 v = self.client.load_balanced_view('all')
80 self.assertEquals(v.targets, None)
80 self.assertEqual(v.targets, None)
81
81
82 def test_dview_targets(self):
82 def test_dview_targets(self):
83 """test direct_view targets"""
83 """test direct_view targets"""
84 v = self.client.direct_view()
84 v = self.client.direct_view()
85 self.assertEquals(v.targets, 'all')
85 self.assertEqual(v.targets, 'all')
86 v = self.client.direct_view('all')
86 v = self.client.direct_view('all')
87 self.assertEquals(v.targets, 'all')
87 self.assertEqual(v.targets, 'all')
88 v = self.client.direct_view(-1)
88 v = self.client.direct_view(-1)
89 self.assertEquals(v.targets, self.client.ids[-1])
89 self.assertEqual(v.targets, self.client.ids[-1])
90
90
91 def test_lazy_all_targets(self):
91 def test_lazy_all_targets(self):
92 """test lazy evaluation of rc.direct_view('all')"""
92 """test lazy evaluation of rc.direct_view('all')"""
93 v = self.client.direct_view()
93 v = self.client.direct_view()
94 self.assertEquals(v.targets, 'all')
94 self.assertEqual(v.targets, 'all')
95
95
96 def double(x):
96 def double(x):
97 return x*2
97 return x*2
98 seq = range(100)
98 seq = range(100)
99 ref = [ double(x) for x in seq ]
99 ref = [ double(x) for x in seq ]
100
100
101 # add some engines, which should be used
101 # add some engines, which should be used
102 self.add_engines(1)
102 self.add_engines(1)
103 n1 = len(self.client.ids)
103 n1 = len(self.client.ids)
104
104
105 # simple apply
105 # simple apply
106 r = v.apply_sync(lambda : 1)
106 r = v.apply_sync(lambda : 1)
107 self.assertEquals(r, [1] * n1)
107 self.assertEqual(r, [1] * n1)
108
108
109 # map goes through remotefunction
109 # map goes through remotefunction
110 r = v.map_sync(double, seq)
110 r = v.map_sync(double, seq)
111 self.assertEquals(r, ref)
111 self.assertEqual(r, ref)
112
112
113 # add a couple more engines, and try again
113 # add a couple more engines, and try again
114 self.add_engines(2)
114 self.add_engines(2)
115 n2 = len(self.client.ids)
115 n2 = len(self.client.ids)
116 self.assertNotEquals(n2, n1)
116 self.assertNotEquals(n2, n1)
117
117
118 # apply
118 # apply
119 r = v.apply_sync(lambda : 1)
119 r = v.apply_sync(lambda : 1)
120 self.assertEquals(r, [1] * n2)
120 self.assertEqual(r, [1] * n2)
121
121
122 # map
122 # map
123 r = v.map_sync(double, seq)
123 r = v.map_sync(double, seq)
124 self.assertEquals(r, ref)
124 self.assertEqual(r, ref)
125
125
126 def test_targets(self):
126 def test_targets(self):
127 """test various valid targets arguments"""
127 """test various valid targets arguments"""
128 build = self.client._build_targets
128 build = self.client._build_targets
129 ids = self.client.ids
129 ids = self.client.ids
130 idents,targets = build(None)
130 idents,targets = build(None)
131 self.assertEquals(ids, targets)
131 self.assertEqual(ids, targets)
132
132
133 def test_clear(self):
133 def test_clear(self):
134 """test clear behavior"""
134 """test clear behavior"""
135 self.minimum_engines(2)
135 self.minimum_engines(2)
136 v = self.client[:]
136 v = self.client[:]
137 v.block=True
137 v.block=True
138 v.push(dict(a=5))
138 v.push(dict(a=5))
139 v.pull('a')
139 v.pull('a')
140 id0 = self.client.ids[-1]
140 id0 = self.client.ids[-1]
141 self.client.clear(targets=id0, block=True)
141 self.client.clear(targets=id0, block=True)
142 a = self.client[:-1].get('a')
142 a = self.client[:-1].get('a')
143 self.assertRaisesRemote(NameError, self.client[id0].get, 'a')
143 self.assertRaisesRemote(NameError, self.client[id0].get, 'a')
144 self.client.clear(block=True)
144 self.client.clear(block=True)
145 for i in self.client.ids:
145 for i in self.client.ids:
146 self.assertRaisesRemote(NameError, self.client[i].get, 'a')
146 self.assertRaisesRemote(NameError, self.client[i].get, 'a')
147
147
148 def test_get_result(self):
148 def test_get_result(self):
149 """test getting results from the Hub."""
149 """test getting results from the Hub."""
150 c = clientmod.Client(profile='iptest')
150 c = clientmod.Client(profile='iptest')
151 t = c.ids[-1]
151 t = c.ids[-1]
152 ar = c[t].apply_async(wait, 1)
152 ar = c[t].apply_async(wait, 1)
153 # give the monitor time to notice the message
153 # give the monitor time to notice the message
154 time.sleep(.25)
154 time.sleep(.25)
155 ahr = self.client.get_result(ar.msg_ids)
155 ahr = self.client.get_result(ar.msg_ids)
156 self.assertTrue(isinstance(ahr, AsyncHubResult))
156 self.assertTrue(isinstance(ahr, AsyncHubResult))
157 self.assertEquals(ahr.get(), ar.get())
157 self.assertEqual(ahr.get(), ar.get())
158 ar2 = self.client.get_result(ar.msg_ids)
158 ar2 = self.client.get_result(ar.msg_ids)
159 self.assertFalse(isinstance(ar2, AsyncHubResult))
159 self.assertFalse(isinstance(ar2, AsyncHubResult))
160 c.close()
160 c.close()
161
161
162 def test_get_execute_result(self):
162 def test_get_execute_result(self):
163 """test getting execute results from the Hub."""
163 """test getting execute results from the Hub."""
164 c = clientmod.Client(profile='iptest')
164 c = clientmod.Client(profile='iptest')
165 t = c.ids[-1]
165 t = c.ids[-1]
166 cell = '\n'.join([
166 cell = '\n'.join([
167 'import time',
167 'import time',
168 'time.sleep(0.25)',
168 'time.sleep(0.25)',
169 '5'
169 '5'
170 ])
170 ])
171 ar = c[t].execute("import time; time.sleep(1)", silent=False)
171 ar = c[t].execute("import time; time.sleep(1)", silent=False)
172 # give the monitor time to notice the message
172 # give the monitor time to notice the message
173 time.sleep(.25)
173 time.sleep(.25)
174 ahr = self.client.get_result(ar.msg_ids)
174 ahr = self.client.get_result(ar.msg_ids)
175 self.assertTrue(isinstance(ahr, AsyncHubResult))
175 self.assertTrue(isinstance(ahr, AsyncHubResult))
176 self.assertEquals(ahr.get().pyout, ar.get().pyout)
176 self.assertEqual(ahr.get().pyout, ar.get().pyout)
177 ar2 = self.client.get_result(ar.msg_ids)
177 ar2 = self.client.get_result(ar.msg_ids)
178 self.assertFalse(isinstance(ar2, AsyncHubResult))
178 self.assertFalse(isinstance(ar2, AsyncHubResult))
179 c.close()
179 c.close()
180
180
181 def test_ids_list(self):
181 def test_ids_list(self):
182 """test client.ids"""
182 """test client.ids"""
183 ids = self.client.ids
183 ids = self.client.ids
184 self.assertEquals(ids, self.client._ids)
184 self.assertEqual(ids, self.client._ids)
185 self.assertFalse(ids is self.client._ids)
185 self.assertFalse(ids is self.client._ids)
186 ids.remove(ids[-1])
186 ids.remove(ids[-1])
187 self.assertNotEquals(ids, self.client._ids)
187 self.assertNotEquals(ids, self.client._ids)
188
188
189 def test_queue_status(self):
189 def test_queue_status(self):
190 ids = self.client.ids
190 ids = self.client.ids
191 id0 = ids[0]
191 id0 = ids[0]
192 qs = self.client.queue_status(targets=id0)
192 qs = self.client.queue_status(targets=id0)
193 self.assertTrue(isinstance(qs, dict))
193 self.assertTrue(isinstance(qs, dict))
194 self.assertEquals(sorted(qs.keys()), ['completed', 'queue', 'tasks'])
194 self.assertEqual(sorted(qs.keys()), ['completed', 'queue', 'tasks'])
195 allqs = self.client.queue_status()
195 allqs = self.client.queue_status()
196 self.assertTrue(isinstance(allqs, dict))
196 self.assertTrue(isinstance(allqs, dict))
197 intkeys = list(allqs.keys())
197 intkeys = list(allqs.keys())
198 intkeys.remove('unassigned')
198 intkeys.remove('unassigned')
199 self.assertEquals(sorted(intkeys), sorted(self.client.ids))
199 self.assertEqual(sorted(intkeys), sorted(self.client.ids))
200 unassigned = allqs.pop('unassigned')
200 unassigned = allqs.pop('unassigned')
201 for eid,qs in allqs.items():
201 for eid,qs in allqs.items():
202 self.assertTrue(isinstance(qs, dict))
202 self.assertTrue(isinstance(qs, dict))
203 self.assertEquals(sorted(qs.keys()), ['completed', 'queue', 'tasks'])
203 self.assertEqual(sorted(qs.keys()), ['completed', 'queue', 'tasks'])
204
204
205 def test_shutdown(self):
205 def test_shutdown(self):
206 ids = self.client.ids
206 ids = self.client.ids
207 id0 = ids[0]
207 id0 = ids[0]
208 self.client.shutdown(id0, block=True)
208 self.client.shutdown(id0, block=True)
209 while id0 in self.client.ids:
209 while id0 in self.client.ids:
210 time.sleep(0.1)
210 time.sleep(0.1)
211 self.client.spin()
211 self.client.spin()
212
212
213 self.assertRaises(IndexError, lambda : self.client[id0])
213 self.assertRaises(IndexError, lambda : self.client[id0])
214
214
215 def test_result_status(self):
215 def test_result_status(self):
216 pass
216 pass
217 # to be written
217 # to be written
218
218
219 def test_db_query_dt(self):
219 def test_db_query_dt(self):
220 """test db query by date"""
220 """test db query by date"""
221 hist = self.client.hub_history()
221 hist = self.client.hub_history()
222 middle = self.client.db_query({'msg_id' : hist[len(hist)//2]})[0]
222 middle = self.client.db_query({'msg_id' : hist[len(hist)//2]})[0]
223 tic = middle['submitted']
223 tic = middle['submitted']
224 before = self.client.db_query({'submitted' : {'$lt' : tic}})
224 before = self.client.db_query({'submitted' : {'$lt' : tic}})
225 after = self.client.db_query({'submitted' : {'$gte' : tic}})
225 after = self.client.db_query({'submitted' : {'$gte' : tic}})
226 self.assertEquals(len(before)+len(after),len(hist))
226 self.assertEqual(len(before)+len(after),len(hist))
227 for b in before:
227 for b in before:
228 self.assertTrue(b['submitted'] < tic)
228 self.assertTrue(b['submitted'] < tic)
229 for a in after:
229 for a in after:
230 self.assertTrue(a['submitted'] >= tic)
230 self.assertTrue(a['submitted'] >= tic)
231 same = self.client.db_query({'submitted' : tic})
231 same = self.client.db_query({'submitted' : tic})
232 for s in same:
232 for s in same:
233 self.assertTrue(s['submitted'] == tic)
233 self.assertTrue(s['submitted'] == tic)
234
234
235 def test_db_query_keys(self):
235 def test_db_query_keys(self):
236 """test extracting subset of record keys"""
236 """test extracting subset of record keys"""
237 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
237 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
238 for rec in found:
238 for rec in found:
239 self.assertEquals(set(rec.keys()), set(['msg_id', 'submitted', 'completed']))
239 self.assertEqual(set(rec.keys()), set(['msg_id', 'submitted', 'completed']))
240
240
241 def test_db_query_default_keys(self):
241 def test_db_query_default_keys(self):
242 """default db_query excludes buffers"""
242 """default db_query excludes buffers"""
243 found = self.client.db_query({'msg_id': {'$ne' : ''}})
243 found = self.client.db_query({'msg_id': {'$ne' : ''}})
244 for rec in found:
244 for rec in found:
245 keys = set(rec.keys())
245 keys = set(rec.keys())
246 self.assertFalse('buffers' in keys, "'buffers' should not be in: %s" % keys)
246 self.assertFalse('buffers' in keys, "'buffers' should not be in: %s" % keys)
247 self.assertFalse('result_buffers' in keys, "'result_buffers' should not be in: %s" % keys)
247 self.assertFalse('result_buffers' in keys, "'result_buffers' should not be in: %s" % keys)
248
248
249 def test_db_query_msg_id(self):
249 def test_db_query_msg_id(self):
250 """ensure msg_id is always in db queries"""
250 """ensure msg_id is always in db queries"""
251 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
251 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
252 for rec in found:
252 for rec in found:
253 self.assertTrue('msg_id' in rec.keys())
253 self.assertTrue('msg_id' in rec.keys())
254 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted'])
254 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted'])
255 for rec in found:
255 for rec in found:
256 self.assertTrue('msg_id' in rec.keys())
256 self.assertTrue('msg_id' in rec.keys())
257 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['msg_id'])
257 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['msg_id'])
258 for rec in found:
258 for rec in found:
259 self.assertTrue('msg_id' in rec.keys())
259 self.assertTrue('msg_id' in rec.keys())
260
260
261 def test_db_query_get_result(self):
261 def test_db_query_get_result(self):
262 """pop in db_query shouldn't pop from result itself"""
262 """pop in db_query shouldn't pop from result itself"""
263 self.client[:].apply_sync(lambda : 1)
263 self.client[:].apply_sync(lambda : 1)
264 found = self.client.db_query({'msg_id': {'$ne' : ''}})
264 found = self.client.db_query({'msg_id': {'$ne' : ''}})
265 rc2 = clientmod.Client(profile='iptest')
265 rc2 = clientmod.Client(profile='iptest')
266 # If this bug is not fixed, this call will hang:
266 # If this bug is not fixed, this call will hang:
267 ar = rc2.get_result(self.client.history[-1])
267 ar = rc2.get_result(self.client.history[-1])
268 ar.wait(2)
268 ar.wait(2)
269 self.assertTrue(ar.ready())
269 self.assertTrue(ar.ready())
270 ar.get()
270 ar.get()
271 rc2.close()
271 rc2.close()
272
272
273 def test_db_query_in(self):
273 def test_db_query_in(self):
274 """test db query with '$in','$nin' operators"""
274 """test db query with '$in','$nin' operators"""
275 hist = self.client.hub_history()
275 hist = self.client.hub_history()
276 even = hist[::2]
276 even = hist[::2]
277 odd = hist[1::2]
277 odd = hist[1::2]
278 recs = self.client.db_query({ 'msg_id' : {'$in' : even}})
278 recs = self.client.db_query({ 'msg_id' : {'$in' : even}})
279 found = [ r['msg_id'] for r in recs ]
279 found = [ r['msg_id'] for r in recs ]
280 self.assertEquals(set(even), set(found))
280 self.assertEqual(set(even), set(found))
281 recs = self.client.db_query({ 'msg_id' : {'$nin' : even}})
281 recs = self.client.db_query({ 'msg_id' : {'$nin' : even}})
282 found = [ r['msg_id'] for r in recs ]
282 found = [ r['msg_id'] for r in recs ]
283 self.assertEquals(set(odd), set(found))
283 self.assertEqual(set(odd), set(found))
284
284
285 def test_hub_history(self):
285 def test_hub_history(self):
286 hist = self.client.hub_history()
286 hist = self.client.hub_history()
287 recs = self.client.db_query({ 'msg_id' : {"$ne":''}})
287 recs = self.client.db_query({ 'msg_id' : {"$ne":''}})
288 recdict = {}
288 recdict = {}
289 for rec in recs:
289 for rec in recs:
290 recdict[rec['msg_id']] = rec
290 recdict[rec['msg_id']] = rec
291
291
292 latest = datetime(1984,1,1)
292 latest = datetime(1984,1,1)
293 for msg_id in hist:
293 for msg_id in hist:
294 rec = recdict[msg_id]
294 rec = recdict[msg_id]
295 newt = rec['submitted']
295 newt = rec['submitted']
296 self.assertTrue(newt >= latest)
296 self.assertTrue(newt >= latest)
297 latest = newt
297 latest = newt
298 ar = self.client[-1].apply_async(lambda : 1)
298 ar = self.client[-1].apply_async(lambda : 1)
299 ar.get()
299 ar.get()
300 time.sleep(0.25)
300 time.sleep(0.25)
301 self.assertEquals(self.client.hub_history()[-1:],ar.msg_ids)
301 self.assertEqual(self.client.hub_history()[-1:],ar.msg_ids)
302
302
303 def _wait_for_idle(self):
303 def _wait_for_idle(self):
304 """wait for an engine to become idle, according to the Hub"""
304 """wait for an engine to become idle, according to the Hub"""
305 rc = self.client
305 rc = self.client
306
306
307 # timeout 5s, polling every 100ms
307 # timeout 5s, polling every 100ms
308 qs = rc.queue_status()
308 qs = rc.queue_status()
309 for i in range(50):
309 for i in range(50):
310 if qs['unassigned'] or any(qs[eid]['tasks'] for eid in rc.ids):
310 if qs['unassigned'] or any(qs[eid]['tasks'] for eid in rc.ids):
311 time.sleep(0.1)
311 time.sleep(0.1)
312 qs = rc.queue_status()
312 qs = rc.queue_status()
313 else:
313 else:
314 break
314 break
315
315
316 # ensure Hub up to date:
316 # ensure Hub up to date:
317 self.assertEquals(qs['unassigned'], 0)
317 self.assertEqual(qs['unassigned'], 0)
318 for eid in rc.ids:
318 for eid in rc.ids:
319 self.assertEquals(qs[eid]['tasks'], 0)
319 self.assertEqual(qs[eid]['tasks'], 0)
320
320
321
321
322 def test_resubmit(self):
322 def test_resubmit(self):
323 def f():
323 def f():
324 import random
324 import random
325 return random.random()
325 return random.random()
326 v = self.client.load_balanced_view()
326 v = self.client.load_balanced_view()
327 ar = v.apply_async(f)
327 ar = v.apply_async(f)
328 r1 = ar.get(1)
328 r1 = ar.get(1)
329 # give the Hub a chance to notice:
329 # give the Hub a chance to notice:
330 self._wait_for_idle()
330 self._wait_for_idle()
331 ahr = self.client.resubmit(ar.msg_ids)
331 ahr = self.client.resubmit(ar.msg_ids)
332 r2 = ahr.get(1)
332 r2 = ahr.get(1)
333 self.assertFalse(r1 == r2)
333 self.assertFalse(r1 == r2)
334
334
335 def test_resubmit_chain(self):
335 def test_resubmit_chain(self):
336 """resubmit resubmitted tasks"""
336 """resubmit resubmitted tasks"""
337 v = self.client.load_balanced_view()
337 v = self.client.load_balanced_view()
338 ar = v.apply_async(lambda x: x, 'x'*1024)
338 ar = v.apply_async(lambda x: x, 'x'*1024)
339 ar.get()
339 ar.get()
340 self._wait_for_idle()
340 self._wait_for_idle()
341 ars = [ar]
341 ars = [ar]
342
342
343 for i in range(10):
343 for i in range(10):
344 ar = ars[-1]
344 ar = ars[-1]
345 ar2 = self.client.resubmit(ar.msg_ids)
345 ar2 = self.client.resubmit(ar.msg_ids)
346
346
347 [ ar.get() for ar in ars ]
347 [ ar.get() for ar in ars ]
348
348
349 def test_resubmit_header(self):
349 def test_resubmit_header(self):
350 """resubmit shouldn't clobber the whole header"""
350 """resubmit shouldn't clobber the whole header"""
351 def f():
351 def f():
352 import random
352 import random
353 return random.random()
353 return random.random()
354 v = self.client.load_balanced_view()
354 v = self.client.load_balanced_view()
355 v.retries = 1
355 v.retries = 1
356 ar = v.apply_async(f)
356 ar = v.apply_async(f)
357 r1 = ar.get(1)
357 r1 = ar.get(1)
358 # give the Hub a chance to notice:
358 # give the Hub a chance to notice:
359 self._wait_for_idle()
359 self._wait_for_idle()
360 ahr = self.client.resubmit(ar.msg_ids)
360 ahr = self.client.resubmit(ar.msg_ids)
361 ahr.get(1)
361 ahr.get(1)
362 time.sleep(0.5)
362 time.sleep(0.5)
363 records = self.client.db_query({'msg_id': {'$in': ar.msg_ids + ahr.msg_ids}}, keys='header')
363 records = self.client.db_query({'msg_id': {'$in': ar.msg_ids + ahr.msg_ids}}, keys='header')
364 h1,h2 = [ r['header'] for r in records ]
364 h1,h2 = [ r['header'] for r in records ]
365 for key in set(h1.keys()).union(set(h2.keys())):
365 for key in set(h1.keys()).union(set(h2.keys())):
366 if key in ('msg_id', 'date'):
366 if key in ('msg_id', 'date'):
367 self.assertNotEquals(h1[key], h2[key])
367 self.assertNotEquals(h1[key], h2[key])
368 else:
368 else:
369 self.assertEquals(h1[key], h2[key])
369 self.assertEqual(h1[key], h2[key])
370
370
371 def test_resubmit_aborted(self):
371 def test_resubmit_aborted(self):
372 def f():
372 def f():
373 import random
373 import random
374 return random.random()
374 return random.random()
375 v = self.client.load_balanced_view()
375 v = self.client.load_balanced_view()
376 # restrict to one engine, so we can put a sleep
376 # restrict to one engine, so we can put a sleep
377 # ahead of the task, so it will get aborted
377 # ahead of the task, so it will get aborted
378 eid = self.client.ids[-1]
378 eid = self.client.ids[-1]
379 v.targets = [eid]
379 v.targets = [eid]
380 sleep = v.apply_async(time.sleep, 0.5)
380 sleep = v.apply_async(time.sleep, 0.5)
381 ar = v.apply_async(f)
381 ar = v.apply_async(f)
382 ar.abort()
382 ar.abort()
383 self.assertRaises(error.TaskAborted, ar.get)
383 self.assertRaises(error.TaskAborted, ar.get)
384 # Give the Hub a chance to get up to date:
384 # Give the Hub a chance to get up to date:
385 self._wait_for_idle()
385 self._wait_for_idle()
386 ahr = self.client.resubmit(ar.msg_ids)
386 ahr = self.client.resubmit(ar.msg_ids)
387 r2 = ahr.get(1)
387 r2 = ahr.get(1)
388
388
389 def test_resubmit_inflight(self):
389 def test_resubmit_inflight(self):
390 """resubmit of inflight task"""
390 """resubmit of inflight task"""
391 v = self.client.load_balanced_view()
391 v = self.client.load_balanced_view()
392 ar = v.apply_async(time.sleep,1)
392 ar = v.apply_async(time.sleep,1)
393 # give the message a chance to arrive
393 # give the message a chance to arrive
394 time.sleep(0.2)
394 time.sleep(0.2)
395 ahr = self.client.resubmit(ar.msg_ids)
395 ahr = self.client.resubmit(ar.msg_ids)
396 ar.get(2)
396 ar.get(2)
397 ahr.get(2)
397 ahr.get(2)
398
398
399 def test_resubmit_badkey(self):
399 def test_resubmit_badkey(self):
400 """ensure KeyError on resubmit of nonexistant task"""
400 """ensure KeyError on resubmit of nonexistant task"""
401 self.assertRaisesRemote(KeyError, self.client.resubmit, ['invalid'])
401 self.assertRaisesRemote(KeyError, self.client.resubmit, ['invalid'])
402
402
403 def test_purge_results(self):
403 def test_purge_results(self):
404 # ensure there are some tasks
404 # ensure there are some tasks
405 for i in range(5):
405 for i in range(5):
406 self.client[:].apply_sync(lambda : 1)
406 self.client[:].apply_sync(lambda : 1)
407 # Wait for the Hub to realise the result is done:
407 # Wait for the Hub to realise the result is done:
408 # This prevents a race condition, where we
408 # This prevents a race condition, where we
409 # might purge a result the Hub still thinks is pending.
409 # might purge a result the Hub still thinks is pending.
410 time.sleep(0.1)
410 time.sleep(0.1)
411 rc2 = clientmod.Client(profile='iptest')
411 rc2 = clientmod.Client(profile='iptest')
412 hist = self.client.hub_history()
412 hist = self.client.hub_history()
413 ahr = rc2.get_result([hist[-1]])
413 ahr = rc2.get_result([hist[-1]])
414 ahr.wait(10)
414 ahr.wait(10)
415 self.client.purge_results(hist[-1])
415 self.client.purge_results(hist[-1])
416 newhist = self.client.hub_history()
416 newhist = self.client.hub_history()
417 self.assertEquals(len(newhist)+1,len(hist))
417 self.assertEqual(len(newhist)+1,len(hist))
418 rc2.spin()
418 rc2.spin()
419 rc2.close()
419 rc2.close()
420
420
421 def test_purge_all_results(self):
421 def test_purge_all_results(self):
422 self.client.purge_results('all')
422 self.client.purge_results('all')
423 hist = self.client.hub_history()
423 hist = self.client.hub_history()
424 self.assertEquals(len(hist), 0)
424 self.assertEqual(len(hist), 0)
425
425
426 def test_spin_thread(self):
426 def test_spin_thread(self):
427 self.client.spin_thread(0.01)
427 self.client.spin_thread(0.01)
428 ar = self.client[-1].apply_async(lambda : 1)
428 ar = self.client[-1].apply_async(lambda : 1)
429 time.sleep(0.1)
429 time.sleep(0.1)
430 self.assertTrue(ar.wall_time < 0.1,
430 self.assertTrue(ar.wall_time < 0.1,
431 "spin should have kept wall_time < 0.1, but got %f" % ar.wall_time
431 "spin should have kept wall_time < 0.1, but got %f" % ar.wall_time
432 )
432 )
433
433
434 def test_stop_spin_thread(self):
434 def test_stop_spin_thread(self):
435 self.client.spin_thread(0.01)
435 self.client.spin_thread(0.01)
436 self.client.stop_spin_thread()
436 self.client.stop_spin_thread()
437 ar = self.client[-1].apply_async(lambda : 1)
437 ar = self.client[-1].apply_async(lambda : 1)
438 time.sleep(0.15)
438 time.sleep(0.15)
439 self.assertTrue(ar.wall_time > 0.1,
439 self.assertTrue(ar.wall_time > 0.1,
440 "Shouldn't be spinning, but got wall_time=%f" % ar.wall_time
440 "Shouldn't be spinning, but got wall_time=%f" % ar.wall_time
441 )
441 )
442
442
443 def test_activate(self):
443 def test_activate(self):
444 ip = get_ipython()
444 ip = get_ipython()
445 magics = ip.magics_manager.magics
445 magics = ip.magics_manager.magics
446 self.assertTrue('px' in magics['line'])
446 self.assertTrue('px' in magics['line'])
447 self.assertTrue('px' in magics['cell'])
447 self.assertTrue('px' in magics['cell'])
448 v0 = self.client.activate(-1, '0')
448 v0 = self.client.activate(-1, '0')
449 self.assertTrue('px0' in magics['line'])
449 self.assertTrue('px0' in magics['line'])
450 self.assertTrue('px0' in magics['cell'])
450 self.assertTrue('px0' in magics['cell'])
451 self.assertEquals(v0.targets, self.client.ids[-1])
451 self.assertEqual(v0.targets, self.client.ids[-1])
452 v0 = self.client.activate('all', 'all')
452 v0 = self.client.activate('all', 'all')
453 self.assertTrue('pxall' in magics['line'])
453 self.assertTrue('pxall' in magics['line'])
454 self.assertTrue('pxall' in magics['cell'])
454 self.assertTrue('pxall' in magics['cell'])
455 self.assertEquals(v0.targets, 'all')
455 self.assertEqual(v0.targets, 'all')
@@ -1,249 +1,249 b''
1 """Tests for db backends
1 """Tests for db backends
2
2
3 Authors:
3 Authors:
4
4
5 * Min RK
5 * Min RK
6 """
6 """
7
7
8 #-------------------------------------------------------------------------------
8 #-------------------------------------------------------------------------------
9 # Copyright (C) 2011 The IPython Development Team
9 # Copyright (C) 2011 The IPython Development Team
10 #
10 #
11 # Distributed under the terms of the BSD License. The full license is in
11 # Distributed under the terms of the BSD License. The full license is in
12 # the file COPYING, distributed as part of this software.
12 # the file COPYING, distributed as part of this software.
13 #-------------------------------------------------------------------------------
13 #-------------------------------------------------------------------------------
14
14
15 #-------------------------------------------------------------------------------
15 #-------------------------------------------------------------------------------
16 # Imports
16 # Imports
17 #-------------------------------------------------------------------------------
17 #-------------------------------------------------------------------------------
18
18
19 from __future__ import division
19 from __future__ import division
20
20
21 import logging
21 import logging
22 import os
22 import os
23 import tempfile
23 import tempfile
24 import time
24 import time
25
25
26 from datetime import datetime, timedelta
26 from datetime import datetime, timedelta
27 from unittest import TestCase
27 from unittest import TestCase
28
28
29 from IPython.parallel import error
29 from IPython.parallel import error
30 from IPython.parallel.controller.dictdb import DictDB
30 from IPython.parallel.controller.dictdb import DictDB
31 from IPython.parallel.controller.sqlitedb import SQLiteDB
31 from IPython.parallel.controller.sqlitedb import SQLiteDB
32 from IPython.parallel.controller.hub import init_record, empty_record
32 from IPython.parallel.controller.hub import init_record, empty_record
33
33
34 from IPython.testing import decorators as dec
34 from IPython.testing import decorators as dec
35 from IPython.zmq.session import Session
35 from IPython.zmq.session import Session
36
36
37
37
38 #-------------------------------------------------------------------------------
38 #-------------------------------------------------------------------------------
39 # TestCases
39 # TestCases
40 #-------------------------------------------------------------------------------
40 #-------------------------------------------------------------------------------
41
41
42
42
43 def setup():
43 def setup():
44 global temp_db
44 global temp_db
45 temp_db = tempfile.NamedTemporaryFile(suffix='.db').name
45 temp_db = tempfile.NamedTemporaryFile(suffix='.db').name
46
46
47
47
48 class TestDictBackend(TestCase):
48 class TestDictBackend(TestCase):
49 def setUp(self):
49 def setUp(self):
50 self.session = Session()
50 self.session = Session()
51 self.db = self.create_db()
51 self.db = self.create_db()
52 self.load_records(16)
52 self.load_records(16)
53
53
54 def create_db(self):
54 def create_db(self):
55 return DictDB()
55 return DictDB()
56
56
57 def load_records(self, n=1):
57 def load_records(self, n=1):
58 """load n records for testing"""
58 """load n records for testing"""
59 #sleep 1/10 s, to ensure timestamp is different to previous calls
59 #sleep 1/10 s, to ensure timestamp is different to previous calls
60 time.sleep(0.1)
60 time.sleep(0.1)
61 msg_ids = []
61 msg_ids = []
62 for i in range(n):
62 for i in range(n):
63 msg = self.session.msg('apply_request', content=dict(a=5))
63 msg = self.session.msg('apply_request', content=dict(a=5))
64 msg['buffers'] = []
64 msg['buffers'] = []
65 rec = init_record(msg)
65 rec = init_record(msg)
66 msg_id = msg['header']['msg_id']
66 msg_id = msg['header']['msg_id']
67 msg_ids.append(msg_id)
67 msg_ids.append(msg_id)
68 self.db.add_record(msg_id, rec)
68 self.db.add_record(msg_id, rec)
69 return msg_ids
69 return msg_ids
70
70
71 def test_add_record(self):
71 def test_add_record(self):
72 before = self.db.get_history()
72 before = self.db.get_history()
73 self.load_records(5)
73 self.load_records(5)
74 after = self.db.get_history()
74 after = self.db.get_history()
75 self.assertEquals(len(after), len(before)+5)
75 self.assertEqual(len(after), len(before)+5)
76 self.assertEquals(after[:-5],before)
76 self.assertEqual(after[:-5],before)
77
77
78 def test_drop_record(self):
78 def test_drop_record(self):
79 msg_id = self.load_records()[-1]
79 msg_id = self.load_records()[-1]
80 rec = self.db.get_record(msg_id)
80 rec = self.db.get_record(msg_id)
81 self.db.drop_record(msg_id)
81 self.db.drop_record(msg_id)
82 self.assertRaises(KeyError,self.db.get_record, msg_id)
82 self.assertRaises(KeyError,self.db.get_record, msg_id)
83
83
84 def _round_to_millisecond(self, dt):
84 def _round_to_millisecond(self, dt):
85 """necessary because mongodb rounds microseconds"""
85 """necessary because mongodb rounds microseconds"""
86 micro = dt.microsecond
86 micro = dt.microsecond
87 extra = int(str(micro)[-3:])
87 extra = int(str(micro)[-3:])
88 return dt - timedelta(microseconds=extra)
88 return dt - timedelta(microseconds=extra)
89
89
90 def test_update_record(self):
90 def test_update_record(self):
91 now = self._round_to_millisecond(datetime.now())
91 now = self._round_to_millisecond(datetime.now())
92 #
92 #
93 msg_id = self.db.get_history()[-1]
93 msg_id = self.db.get_history()[-1]
94 rec1 = self.db.get_record(msg_id)
94 rec1 = self.db.get_record(msg_id)
95 data = {'stdout': 'hello there', 'completed' : now}
95 data = {'stdout': 'hello there', 'completed' : now}
96 self.db.update_record(msg_id, data)
96 self.db.update_record(msg_id, data)
97 rec2 = self.db.get_record(msg_id)
97 rec2 = self.db.get_record(msg_id)
98 self.assertEquals(rec2['stdout'], 'hello there')
98 self.assertEqual(rec2['stdout'], 'hello there')
99 self.assertEquals(rec2['completed'], now)
99 self.assertEqual(rec2['completed'], now)
100 rec1.update(data)
100 rec1.update(data)
101 self.assertEquals(rec1, rec2)
101 self.assertEqual(rec1, rec2)
102
102
103 # def test_update_record_bad(self):
103 # def test_update_record_bad(self):
104 # """test updating nonexistant records"""
104 # """test updating nonexistant records"""
105 # msg_id = str(uuid.uuid4())
105 # msg_id = str(uuid.uuid4())
106 # data = {'stdout': 'hello there'}
106 # data = {'stdout': 'hello there'}
107 # self.assertRaises(KeyError, self.db.update_record, msg_id, data)
107 # self.assertRaises(KeyError, self.db.update_record, msg_id, data)
108
108
109 def test_find_records_dt(self):
109 def test_find_records_dt(self):
110 """test finding records by date"""
110 """test finding records by date"""
111 hist = self.db.get_history()
111 hist = self.db.get_history()
112 middle = self.db.get_record(hist[len(hist)//2])
112 middle = self.db.get_record(hist[len(hist)//2])
113 tic = middle['submitted']
113 tic = middle['submitted']
114 before = self.db.find_records({'submitted' : {'$lt' : tic}})
114 before = self.db.find_records({'submitted' : {'$lt' : tic}})
115 after = self.db.find_records({'submitted' : {'$gte' : tic}})
115 after = self.db.find_records({'submitted' : {'$gte' : tic}})
116 self.assertEquals(len(before)+len(after),len(hist))
116 self.assertEqual(len(before)+len(after),len(hist))
117 for b in before:
117 for b in before:
118 self.assertTrue(b['submitted'] < tic)
118 self.assertTrue(b['submitted'] < tic)
119 for a in after:
119 for a in after:
120 self.assertTrue(a['submitted'] >= tic)
120 self.assertTrue(a['submitted'] >= tic)
121 same = self.db.find_records({'submitted' : tic})
121 same = self.db.find_records({'submitted' : tic})
122 for s in same:
122 for s in same:
123 self.assertTrue(s['submitted'] == tic)
123 self.assertTrue(s['submitted'] == tic)
124
124
125 def test_find_records_keys(self):
125 def test_find_records_keys(self):
126 """test extracting subset of record keys"""
126 """test extracting subset of record keys"""
127 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
127 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
128 for rec in found:
128 for rec in found:
129 self.assertEquals(set(rec.keys()), set(['msg_id', 'submitted', 'completed']))
129 self.assertEqual(set(rec.keys()), set(['msg_id', 'submitted', 'completed']))
130
130
131 def test_find_records_msg_id(self):
131 def test_find_records_msg_id(self):
132 """ensure msg_id is always in found records"""
132 """ensure msg_id is always in found records"""
133 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
133 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
134 for rec in found:
134 for rec in found:
135 self.assertTrue('msg_id' in rec.keys())
135 self.assertTrue('msg_id' in rec.keys())
136 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['submitted'])
136 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['submitted'])
137 for rec in found:
137 for rec in found:
138 self.assertTrue('msg_id' in rec.keys())
138 self.assertTrue('msg_id' in rec.keys())
139 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['msg_id'])
139 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['msg_id'])
140 for rec in found:
140 for rec in found:
141 self.assertTrue('msg_id' in rec.keys())
141 self.assertTrue('msg_id' in rec.keys())
142
142
143 def test_find_records_in(self):
143 def test_find_records_in(self):
144 """test finding records with '$in','$nin' operators"""
144 """test finding records with '$in','$nin' operators"""
145 hist = self.db.get_history()
145 hist = self.db.get_history()
146 even = hist[::2]
146 even = hist[::2]
147 odd = hist[1::2]
147 odd = hist[1::2]
148 recs = self.db.find_records({ 'msg_id' : {'$in' : even}})
148 recs = self.db.find_records({ 'msg_id' : {'$in' : even}})
149 found = [ r['msg_id'] for r in recs ]
149 found = [ r['msg_id'] for r in recs ]
150 self.assertEquals(set(even), set(found))
150 self.assertEqual(set(even), set(found))
151 recs = self.db.find_records({ 'msg_id' : {'$nin' : even}})
151 recs = self.db.find_records({ 'msg_id' : {'$nin' : even}})
152 found = [ r['msg_id'] for r in recs ]
152 found = [ r['msg_id'] for r in recs ]
153 self.assertEquals(set(odd), set(found))
153 self.assertEqual(set(odd), set(found))
154
154
155 def test_get_history(self):
155 def test_get_history(self):
156 msg_ids = self.db.get_history()
156 msg_ids = self.db.get_history()
157 latest = datetime(1984,1,1)
157 latest = datetime(1984,1,1)
158 for msg_id in msg_ids:
158 for msg_id in msg_ids:
159 rec = self.db.get_record(msg_id)
159 rec = self.db.get_record(msg_id)
160 newt = rec['submitted']
160 newt = rec['submitted']
161 self.assertTrue(newt >= latest)
161 self.assertTrue(newt >= latest)
162 latest = newt
162 latest = newt
163 msg_id = self.load_records(1)[-1]
163 msg_id = self.load_records(1)[-1]
164 self.assertEquals(self.db.get_history()[-1],msg_id)
164 self.assertEqual(self.db.get_history()[-1],msg_id)
165
165
166 def test_datetime(self):
166 def test_datetime(self):
167 """get/set timestamps with datetime objects"""
167 """get/set timestamps with datetime objects"""
168 msg_id = self.db.get_history()[-1]
168 msg_id = self.db.get_history()[-1]
169 rec = self.db.get_record(msg_id)
169 rec = self.db.get_record(msg_id)
170 self.assertTrue(isinstance(rec['submitted'], datetime))
170 self.assertTrue(isinstance(rec['submitted'], datetime))
171 self.db.update_record(msg_id, dict(completed=datetime.now()))
171 self.db.update_record(msg_id, dict(completed=datetime.now()))
172 rec = self.db.get_record(msg_id)
172 rec = self.db.get_record(msg_id)
173 self.assertTrue(isinstance(rec['completed'], datetime))
173 self.assertTrue(isinstance(rec['completed'], datetime))
174
174
175 def test_drop_matching(self):
175 def test_drop_matching(self):
176 msg_ids = self.load_records(10)
176 msg_ids = self.load_records(10)
177 query = {'msg_id' : {'$in':msg_ids}}
177 query = {'msg_id' : {'$in':msg_ids}}
178 self.db.drop_matching_records(query)
178 self.db.drop_matching_records(query)
179 recs = self.db.find_records(query)
179 recs = self.db.find_records(query)
180 self.assertEquals(len(recs), 0)
180 self.assertEqual(len(recs), 0)
181
181
182 def test_null(self):
182 def test_null(self):
183 """test None comparison queries"""
183 """test None comparison queries"""
184 msg_ids = self.load_records(10)
184 msg_ids = self.load_records(10)
185
185
186 query = {'msg_id' : None}
186 query = {'msg_id' : None}
187 recs = self.db.find_records(query)
187 recs = self.db.find_records(query)
188 self.assertEquals(len(recs), 0)
188 self.assertEqual(len(recs), 0)
189
189
190 query = {'msg_id' : {'$ne' : None}}
190 query = {'msg_id' : {'$ne' : None}}
191 recs = self.db.find_records(query)
191 recs = self.db.find_records(query)
192 self.assertTrue(len(recs) >= 10)
192 self.assertTrue(len(recs) >= 10)
193
193
194 def test_pop_safe_get(self):
194 def test_pop_safe_get(self):
195 """editing query results shouldn't affect record [get]"""
195 """editing query results shouldn't affect record [get]"""
196 msg_id = self.db.get_history()[-1]
196 msg_id = self.db.get_history()[-1]
197 rec = self.db.get_record(msg_id)
197 rec = self.db.get_record(msg_id)
198 rec.pop('buffers')
198 rec.pop('buffers')
199 rec['garbage'] = 'hello'
199 rec['garbage'] = 'hello'
200 rec['header']['msg_id'] = 'fubar'
200 rec['header']['msg_id'] = 'fubar'
201 rec2 = self.db.get_record(msg_id)
201 rec2 = self.db.get_record(msg_id)
202 self.assertTrue('buffers' in rec2)
202 self.assertTrue('buffers' in rec2)
203 self.assertFalse('garbage' in rec2)
203 self.assertFalse('garbage' in rec2)
204 self.assertEquals(rec2['header']['msg_id'], msg_id)
204 self.assertEqual(rec2['header']['msg_id'], msg_id)
205
205
206 def test_pop_safe_find(self):
206 def test_pop_safe_find(self):
207 """editing query results shouldn't affect record [find]"""
207 """editing query results shouldn't affect record [find]"""
208 msg_id = self.db.get_history()[-1]
208 msg_id = self.db.get_history()[-1]
209 rec = self.db.find_records({'msg_id' : msg_id})[0]
209 rec = self.db.find_records({'msg_id' : msg_id})[0]
210 rec.pop('buffers')
210 rec.pop('buffers')
211 rec['garbage'] = 'hello'
211 rec['garbage'] = 'hello'
212 rec['header']['msg_id'] = 'fubar'
212 rec['header']['msg_id'] = 'fubar'
213 rec2 = self.db.find_records({'msg_id' : msg_id})[0]
213 rec2 = self.db.find_records({'msg_id' : msg_id})[0]
214 self.assertTrue('buffers' in rec2)
214 self.assertTrue('buffers' in rec2)
215 self.assertFalse('garbage' in rec2)
215 self.assertFalse('garbage' in rec2)
216 self.assertEquals(rec2['header']['msg_id'], msg_id)
216 self.assertEqual(rec2['header']['msg_id'], msg_id)
217
217
218 def test_pop_safe_find_keys(self):
218 def test_pop_safe_find_keys(self):
219 """editing query results shouldn't affect record [find+keys]"""
219 """editing query results shouldn't affect record [find+keys]"""
220 msg_id = self.db.get_history()[-1]
220 msg_id = self.db.get_history()[-1]
221 rec = self.db.find_records({'msg_id' : msg_id}, keys=['buffers', 'header'])[0]
221 rec = self.db.find_records({'msg_id' : msg_id}, keys=['buffers', 'header'])[0]
222 rec.pop('buffers')
222 rec.pop('buffers')
223 rec['garbage'] = 'hello'
223 rec['garbage'] = 'hello'
224 rec['header']['msg_id'] = 'fubar'
224 rec['header']['msg_id'] = 'fubar'
225 rec2 = self.db.find_records({'msg_id' : msg_id})[0]
225 rec2 = self.db.find_records({'msg_id' : msg_id})[0]
226 self.assertTrue('buffers' in rec2)
226 self.assertTrue('buffers' in rec2)
227 self.assertFalse('garbage' in rec2)
227 self.assertFalse('garbage' in rec2)
228 self.assertEquals(rec2['header']['msg_id'], msg_id)
228 self.assertEqual(rec2['header']['msg_id'], msg_id)
229
229
230
230
231 class TestSQLiteBackend(TestDictBackend):
231 class TestSQLiteBackend(TestDictBackend):
232
232
233 @dec.skip_without('sqlite3')
233 @dec.skip_without('sqlite3')
234 def create_db(self):
234 def create_db(self):
235 location, fname = os.path.split(temp_db)
235 location, fname = os.path.split(temp_db)
236 log = logging.getLogger('test')
236 log = logging.getLogger('test')
237 log.setLevel(logging.CRITICAL)
237 log.setLevel(logging.CRITICAL)
238 return SQLiteDB(location=location, fname=fname, log=log)
238 return SQLiteDB(location=location, fname=fname, log=log)
239
239
240 def tearDown(self):
240 def tearDown(self):
241 self.db._db.close()
241 self.db._db.close()
242
242
243
243
244 def teardown():
244 def teardown():
245 """cleanup task db file after all tests have run"""
245 """cleanup task db file after all tests have run"""
246 try:
246 try:
247 os.remove(temp_db)
247 os.remove(temp_db)
248 except:
248 except:
249 pass
249 pass
@@ -1,106 +1,106 b''
1 """Tests for dependency.py
1 """Tests for dependency.py
2
2
3 Authors:
3 Authors:
4
4
5 * Min RK
5 * Min RK
6 """
6 """
7
7
8 __docformat__ = "restructuredtext en"
8 __docformat__ = "restructuredtext en"
9
9
10 #-------------------------------------------------------------------------------
10 #-------------------------------------------------------------------------------
11 # Copyright (C) 2011 The IPython Development Team
11 # Copyright (C) 2011 The IPython Development Team
12 #
12 #
13 # Distributed under the terms of the BSD License. The full license is in
13 # Distributed under the terms of the BSD License. The full license is in
14 # the file COPYING, distributed as part of this software.
14 # the file COPYING, distributed as part of this software.
15 #-------------------------------------------------------------------------------
15 #-------------------------------------------------------------------------------
16
16
17 #-------------------------------------------------------------------------------
17 #-------------------------------------------------------------------------------
18 # Imports
18 # Imports
19 #-------------------------------------------------------------------------------
19 #-------------------------------------------------------------------------------
20
20
21 # import
21 # import
22 import os
22 import os
23
23
24 from IPython.utils.pickleutil import can, uncan
24 from IPython.utils.pickleutil import can, uncan
25
25
26 import IPython.parallel as pmod
26 import IPython.parallel as pmod
27 from IPython.parallel.util import interactive
27 from IPython.parallel.util import interactive
28
28
29 from IPython.parallel.tests import add_engines
29 from IPython.parallel.tests import add_engines
30 from .clienttest import ClusterTestCase
30 from .clienttest import ClusterTestCase
31
31
32 def setup():
32 def setup():
33 add_engines(1, total=True)
33 add_engines(1, total=True)
34
34
35 @pmod.require('time')
35 @pmod.require('time')
36 def wait(n):
36 def wait(n):
37 time.sleep(n)
37 time.sleep(n)
38 return n
38 return n
39
39
40 mixed = map(str, range(10))
40 mixed = map(str, range(10))
41 completed = map(str, range(0,10,2))
41 completed = map(str, range(0,10,2))
42 failed = map(str, range(1,10,2))
42 failed = map(str, range(1,10,2))
43
43
44 class DependencyTest(ClusterTestCase):
44 class DependencyTest(ClusterTestCase):
45
45
46 def setUp(self):
46 def setUp(self):
47 ClusterTestCase.setUp(self)
47 ClusterTestCase.setUp(self)
48 self.user_ns = {'__builtins__' : __builtins__}
48 self.user_ns = {'__builtins__' : __builtins__}
49 self.view = self.client.load_balanced_view()
49 self.view = self.client.load_balanced_view()
50 self.dview = self.client[-1]
50 self.dview = self.client[-1]
51 self.succeeded = set(map(str, range(0,25,2)))
51 self.succeeded = set(map(str, range(0,25,2)))
52 self.failed = set(map(str, range(1,25,2)))
52 self.failed = set(map(str, range(1,25,2)))
53
53
54 def assertMet(self, dep):
54 def assertMet(self, dep):
55 self.assertTrue(dep.check(self.succeeded, self.failed), "Dependency should be met")
55 self.assertTrue(dep.check(self.succeeded, self.failed), "Dependency should be met")
56
56
57 def assertUnmet(self, dep):
57 def assertUnmet(self, dep):
58 self.assertFalse(dep.check(self.succeeded, self.failed), "Dependency should not be met")
58 self.assertFalse(dep.check(self.succeeded, self.failed), "Dependency should not be met")
59
59
60 def assertUnreachable(self, dep):
60 def assertUnreachable(self, dep):
61 self.assertTrue(dep.unreachable(self.succeeded, self.failed), "Dependency should be unreachable")
61 self.assertTrue(dep.unreachable(self.succeeded, self.failed), "Dependency should be unreachable")
62
62
63 def assertReachable(self, dep):
63 def assertReachable(self, dep):
64 self.assertFalse(dep.unreachable(self.succeeded, self.failed), "Dependency should be reachable")
64 self.assertFalse(dep.unreachable(self.succeeded, self.failed), "Dependency should be reachable")
65
65
66 def cancan(self, f):
66 def cancan(self, f):
67 """decorator to pass through canning into self.user_ns"""
67 """decorator to pass through canning into self.user_ns"""
68 return uncan(can(f), self.user_ns)
68 return uncan(can(f), self.user_ns)
69
69
70 def test_require_imports(self):
70 def test_require_imports(self):
71 """test that @require imports names"""
71 """test that @require imports names"""
72 @self.cancan
72 @self.cancan
73 @pmod.require('urllib')
73 @pmod.require('urllib')
74 @interactive
74 @interactive
75 def encode(dikt):
75 def encode(dikt):
76 return urllib.urlencode(dikt)
76 return urllib.urlencode(dikt)
77 # must pass through canning to properly connect namespaces
77 # must pass through canning to properly connect namespaces
78 self.assertEquals(encode(dict(a=5)), 'a=5')
78 self.assertEqual(encode(dict(a=5)), 'a=5')
79
79
80 def test_success_only(self):
80 def test_success_only(self):
81 dep = pmod.Dependency(mixed, success=True, failure=False)
81 dep = pmod.Dependency(mixed, success=True, failure=False)
82 self.assertUnmet(dep)
82 self.assertUnmet(dep)
83 self.assertUnreachable(dep)
83 self.assertUnreachable(dep)
84 dep.all=False
84 dep.all=False
85 self.assertMet(dep)
85 self.assertMet(dep)
86 self.assertReachable(dep)
86 self.assertReachable(dep)
87 dep = pmod.Dependency(completed, success=True, failure=False)
87 dep = pmod.Dependency(completed, success=True, failure=False)
88 self.assertMet(dep)
88 self.assertMet(dep)
89 self.assertReachable(dep)
89 self.assertReachable(dep)
90 dep.all=False
90 dep.all=False
91 self.assertMet(dep)
91 self.assertMet(dep)
92 self.assertReachable(dep)
92 self.assertReachable(dep)
93
93
94 def test_failure_only(self):
94 def test_failure_only(self):
95 dep = pmod.Dependency(mixed, success=False, failure=True)
95 dep = pmod.Dependency(mixed, success=False, failure=True)
96 self.assertUnmet(dep)
96 self.assertUnmet(dep)
97 self.assertUnreachable(dep)
97 self.assertUnreachable(dep)
98 dep.all=False
98 dep.all=False
99 self.assertMet(dep)
99 self.assertMet(dep)
100 self.assertReachable(dep)
100 self.assertReachable(dep)
101 dep = pmod.Dependency(completed, success=False, failure=True)
101 dep = pmod.Dependency(completed, success=False, failure=True)
102 self.assertUnmet(dep)
102 self.assertUnmet(dep)
103 self.assertUnreachable(dep)
103 self.assertUnreachable(dep)
104 dep.all=False
104 dep.all=False
105 self.assertUnmet(dep)
105 self.assertUnmet(dep)
106 self.assertUnreachable(dep)
106 self.assertUnreachable(dep)
@@ -1,176 +1,176 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2 """test LoadBalancedView objects
2 """test LoadBalancedView objects
3
3
4 Authors:
4 Authors:
5
5
6 * Min RK
6 * Min RK
7 """
7 """
8 #-------------------------------------------------------------------------------
8 #-------------------------------------------------------------------------------
9 # Copyright (C) 2011 The IPython Development Team
9 # Copyright (C) 2011 The IPython Development Team
10 #
10 #
11 # Distributed under the terms of the BSD License. The full license is in
11 # Distributed under the terms of the BSD License. The full license is in
12 # the file COPYING, distributed as part of this software.
12 # the file COPYING, distributed as part of this software.
13 #-------------------------------------------------------------------------------
13 #-------------------------------------------------------------------------------
14
14
15 #-------------------------------------------------------------------------------
15 #-------------------------------------------------------------------------------
16 # Imports
16 # Imports
17 #-------------------------------------------------------------------------------
17 #-------------------------------------------------------------------------------
18
18
19 import sys
19 import sys
20 import time
20 import time
21
21
22 import zmq
22 import zmq
23 from nose import SkipTest
23 from nose import SkipTest
24
24
25 from IPython import parallel as pmod
25 from IPython import parallel as pmod
26 from IPython.parallel import error
26 from IPython.parallel import error
27
27
28 from IPython.parallel.tests import add_engines
28 from IPython.parallel.tests import add_engines
29
29
30 from .clienttest import ClusterTestCase, crash, wait, skip_without
30 from .clienttest import ClusterTestCase, crash, wait, skip_without
31
31
32 def setup():
32 def setup():
33 add_engines(3, total=True)
33 add_engines(3, total=True)
34
34
35 class TestLoadBalancedView(ClusterTestCase):
35 class TestLoadBalancedView(ClusterTestCase):
36
36
37 def setUp(self):
37 def setUp(self):
38 ClusterTestCase.setUp(self)
38 ClusterTestCase.setUp(self)
39 self.view = self.client.load_balanced_view()
39 self.view = self.client.load_balanced_view()
40
40
41 def test_z_crash_task(self):
41 def test_z_crash_task(self):
42 """test graceful handling of engine death (balanced)"""
42 """test graceful handling of engine death (balanced)"""
43 raise SkipTest("crash tests disabled, due to undesirable crash reports")
43 raise SkipTest("crash tests disabled, due to undesirable crash reports")
44 # self.add_engines(1)
44 # self.add_engines(1)
45 ar = self.view.apply_async(crash)
45 ar = self.view.apply_async(crash)
46 self.assertRaisesRemote(error.EngineError, ar.get, 10)
46 self.assertRaisesRemote(error.EngineError, ar.get, 10)
47 eid = ar.engine_id
47 eid = ar.engine_id
48 tic = time.time()
48 tic = time.time()
49 while eid in self.client.ids and time.time()-tic < 5:
49 while eid in self.client.ids and time.time()-tic < 5:
50 time.sleep(.01)
50 time.sleep(.01)
51 self.client.spin()
51 self.client.spin()
52 self.assertFalse(eid in self.client.ids, "Engine should have died")
52 self.assertFalse(eid in self.client.ids, "Engine should have died")
53
53
54 def test_map(self):
54 def test_map(self):
55 def f(x):
55 def f(x):
56 return x**2
56 return x**2
57 data = range(16)
57 data = range(16)
58 r = self.view.map_sync(f, data)
58 r = self.view.map_sync(f, data)
59 self.assertEquals(r, map(f, data))
59 self.assertEqual(r, map(f, data))
60
60
61 def test_map_unordered(self):
61 def test_map_unordered(self):
62 def f(x):
62 def f(x):
63 return x**2
63 return x**2
64 def slow_f(x):
64 def slow_f(x):
65 import time
65 import time
66 time.sleep(0.05*x)
66 time.sleep(0.05*x)
67 return x**2
67 return x**2
68 data = range(16,0,-1)
68 data = range(16,0,-1)
69 reference = map(f, data)
69 reference = map(f, data)
70
70
71 amr = self.view.map_async(slow_f, data, ordered=False)
71 amr = self.view.map_async(slow_f, data, ordered=False)
72 self.assertTrue(isinstance(amr, pmod.AsyncMapResult))
72 self.assertTrue(isinstance(amr, pmod.AsyncMapResult))
73 # check individual elements, retrieved as they come
73 # check individual elements, retrieved as they come
74 # list comprehension uses __iter__
74 # list comprehension uses __iter__
75 astheycame = [ r for r in amr ]
75 astheycame = [ r for r in amr ]
76 # Ensure that at least one result came out of order:
76 # Ensure that at least one result came out of order:
77 self.assertNotEquals(astheycame, reference, "should not have preserved order")
77 self.assertNotEquals(astheycame, reference, "should not have preserved order")
78 self.assertEquals(sorted(astheycame, reverse=True), reference, "result corrupted")
78 self.assertEqual(sorted(astheycame, reverse=True), reference, "result corrupted")
79
79
80 def test_map_ordered(self):
80 def test_map_ordered(self):
81 def f(x):
81 def f(x):
82 return x**2
82 return x**2
83 def slow_f(x):
83 def slow_f(x):
84 import time
84 import time
85 time.sleep(0.05*x)
85 time.sleep(0.05*x)
86 return x**2
86 return x**2
87 data = range(16,0,-1)
87 data = range(16,0,-1)
88 reference = map(f, data)
88 reference = map(f, data)
89
89
90 amr = self.view.map_async(slow_f, data)
90 amr = self.view.map_async(slow_f, data)
91 self.assertTrue(isinstance(amr, pmod.AsyncMapResult))
91 self.assertTrue(isinstance(amr, pmod.AsyncMapResult))
92 # check individual elements, retrieved as they come
92 # check individual elements, retrieved as they come
93 # list(amr) uses __iter__
93 # list(amr) uses __iter__
94 astheycame = list(amr)
94 astheycame = list(amr)
95 # Ensure that results came in order
95 # Ensure that results came in order
96 self.assertEquals(astheycame, reference)
96 self.assertEqual(astheycame, reference)
97 self.assertEquals(amr.result, reference)
97 self.assertEqual(amr.result, reference)
98
98
99 def test_map_iterable(self):
99 def test_map_iterable(self):
100 """test map on iterables (balanced)"""
100 """test map on iterables (balanced)"""
101 view = self.view
101 view = self.view
102 # 101 is prime, so it won't be evenly distributed
102 # 101 is prime, so it won't be evenly distributed
103 arr = range(101)
103 arr = range(101)
104 # so that it will be an iterator, even in Python 3
104 # so that it will be an iterator, even in Python 3
105 it = iter(arr)
105 it = iter(arr)
106 r = view.map_sync(lambda x:x, arr)
106 r = view.map_sync(lambda x:x, arr)
107 self.assertEquals(r, list(arr))
107 self.assertEqual(r, list(arr))
108
108
109
109
110 def test_abort(self):
110 def test_abort(self):
111 view = self.view
111 view = self.view
112 ar = self.client[:].apply_async(time.sleep, .5)
112 ar = self.client[:].apply_async(time.sleep, .5)
113 ar = self.client[:].apply_async(time.sleep, .5)
113 ar = self.client[:].apply_async(time.sleep, .5)
114 time.sleep(0.2)
114 time.sleep(0.2)
115 ar2 = view.apply_async(lambda : 2)
115 ar2 = view.apply_async(lambda : 2)
116 ar3 = view.apply_async(lambda : 3)
116 ar3 = view.apply_async(lambda : 3)
117 view.abort(ar2)
117 view.abort(ar2)
118 view.abort(ar3.msg_ids)
118 view.abort(ar3.msg_ids)
119 self.assertRaises(error.TaskAborted, ar2.get)
119 self.assertRaises(error.TaskAborted, ar2.get)
120 self.assertRaises(error.TaskAborted, ar3.get)
120 self.assertRaises(error.TaskAborted, ar3.get)
121
121
122 def test_retries(self):
122 def test_retries(self):
123 view = self.view
123 view = self.view
124 view.timeout = 1 # prevent hang if this doesn't behave
124 view.timeout = 1 # prevent hang if this doesn't behave
125 def fail():
125 def fail():
126 assert False
126 assert False
127 for r in range(len(self.client)-1):
127 for r in range(len(self.client)-1):
128 with view.temp_flags(retries=r):
128 with view.temp_flags(retries=r):
129 self.assertRaisesRemote(AssertionError, view.apply_sync, fail)
129 self.assertRaisesRemote(AssertionError, view.apply_sync, fail)
130
130
131 with view.temp_flags(retries=len(self.client), timeout=0.25):
131 with view.temp_flags(retries=len(self.client), timeout=0.25):
132 self.assertRaisesRemote(error.TaskTimeout, view.apply_sync, fail)
132 self.assertRaisesRemote(error.TaskTimeout, view.apply_sync, fail)
133
133
134 def test_invalid_dependency(self):
134 def test_invalid_dependency(self):
135 view = self.view
135 view = self.view
136 with view.temp_flags(after='12345'):
136 with view.temp_flags(after='12345'):
137 self.assertRaisesRemote(error.InvalidDependency, view.apply_sync, lambda : 1)
137 self.assertRaisesRemote(error.InvalidDependency, view.apply_sync, lambda : 1)
138
138
139 def test_impossible_dependency(self):
139 def test_impossible_dependency(self):
140 self.minimum_engines(2)
140 self.minimum_engines(2)
141 view = self.client.load_balanced_view()
141 view = self.client.load_balanced_view()
142 ar1 = view.apply_async(lambda : 1)
142 ar1 = view.apply_async(lambda : 1)
143 ar1.get()
143 ar1.get()
144 e1 = ar1.engine_id
144 e1 = ar1.engine_id
145 e2 = e1
145 e2 = e1
146 while e2 == e1:
146 while e2 == e1:
147 ar2 = view.apply_async(lambda : 1)
147 ar2 = view.apply_async(lambda : 1)
148 ar2.get()
148 ar2.get()
149 e2 = ar2.engine_id
149 e2 = ar2.engine_id
150
150
151 with view.temp_flags(follow=[ar1, ar2]):
151 with view.temp_flags(follow=[ar1, ar2]):
152 self.assertRaisesRemote(error.ImpossibleDependency, view.apply_sync, lambda : 1)
152 self.assertRaisesRemote(error.ImpossibleDependency, view.apply_sync, lambda : 1)
153
153
154
154
155 def test_follow(self):
155 def test_follow(self):
156 ar = self.view.apply_async(lambda : 1)
156 ar = self.view.apply_async(lambda : 1)
157 ar.get()
157 ar.get()
158 ars = []
158 ars = []
159 first_id = ar.engine_id
159 first_id = ar.engine_id
160
160
161 self.view.follow = ar
161 self.view.follow = ar
162 for i in range(5):
162 for i in range(5):
163 ars.append(self.view.apply_async(lambda : 1))
163 ars.append(self.view.apply_async(lambda : 1))
164 self.view.wait(ars)
164 self.view.wait(ars)
165 for ar in ars:
165 for ar in ars:
166 self.assertEquals(ar.engine_id, first_id)
166 self.assertEqual(ar.engine_id, first_id)
167
167
168 def test_after(self):
168 def test_after(self):
169 view = self.view
169 view = self.view
170 ar = view.apply_async(time.sleep, 0.5)
170 ar = view.apply_async(time.sleep, 0.5)
171 with view.temp_flags(after=ar):
171 with view.temp_flags(after=ar):
172 ar2 = view.apply_async(lambda : 1)
172 ar2 = view.apply_async(lambda : 1)
173
173
174 ar.wait()
174 ar.wait()
175 ar2.wait()
175 ar2.wait()
176 self.assertTrue(ar2.started >= ar.completed, "%s not >= %s"%(ar.started, ar.completed))
176 self.assertTrue(ar2.started >= ar.completed, "%s not >= %s"%(ar.started, ar.completed))
@@ -1,386 +1,386 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2 """Test Parallel magics
2 """Test Parallel magics
3
3
4 Authors:
4 Authors:
5
5
6 * Min RK
6 * Min RK
7 """
7 """
8 #-------------------------------------------------------------------------------
8 #-------------------------------------------------------------------------------
9 # Copyright (C) 2011 The IPython Development Team
9 # Copyright (C) 2011 The IPython Development Team
10 #
10 #
11 # Distributed under the terms of the BSD License. The full license is in
11 # Distributed under the terms of the BSD License. The full license is in
12 # the file COPYING, distributed as part of this software.
12 # the file COPYING, distributed as part of this software.
13 #-------------------------------------------------------------------------------
13 #-------------------------------------------------------------------------------
14
14
15 #-------------------------------------------------------------------------------
15 #-------------------------------------------------------------------------------
16 # Imports
16 # Imports
17 #-------------------------------------------------------------------------------
17 #-------------------------------------------------------------------------------
18
18
19 import re
19 import re
20 import sys
20 import sys
21 import time
21 import time
22
22
23 import zmq
23 import zmq
24 from nose import SkipTest
24 from nose import SkipTest
25
25
26 from IPython.testing import decorators as dec
26 from IPython.testing import decorators as dec
27 from IPython.testing.ipunittest import ParametricTestCase
27 from IPython.testing.ipunittest import ParametricTestCase
28 from IPython.utils.io import capture_output
28 from IPython.utils.io import capture_output
29
29
30 from IPython import parallel as pmod
30 from IPython import parallel as pmod
31 from IPython.parallel import error
31 from IPython.parallel import error
32 from IPython.parallel import AsyncResult
32 from IPython.parallel import AsyncResult
33 from IPython.parallel.util import interactive
33 from IPython.parallel.util import interactive
34
34
35 from IPython.parallel.tests import add_engines
35 from IPython.parallel.tests import add_engines
36
36
37 from .clienttest import ClusterTestCase, generate_output
37 from .clienttest import ClusterTestCase, generate_output
38
38
39 def setup():
39 def setup():
40 add_engines(3, total=True)
40 add_engines(3, total=True)
41
41
42 class TestParallelMagics(ClusterTestCase, ParametricTestCase):
42 class TestParallelMagics(ClusterTestCase, ParametricTestCase):
43
43
44 def test_px_blocking(self):
44 def test_px_blocking(self):
45 ip = get_ipython()
45 ip = get_ipython()
46 v = self.client[-1:]
46 v = self.client[-1:]
47 v.activate()
47 v.activate()
48 v.block=True
48 v.block=True
49
49
50 ip.magic('px a=5')
50 ip.magic('px a=5')
51 self.assertEquals(v['a'], [5])
51 self.assertEqual(v['a'], [5])
52 ip.magic('px a=10')
52 ip.magic('px a=10')
53 self.assertEquals(v['a'], [10])
53 self.assertEqual(v['a'], [10])
54 # just 'print a' works ~99% of the time, but this ensures that
54 # just 'print a' works ~99% of the time, but this ensures that
55 # the stdout message has arrived when the result is finished:
55 # the stdout message has arrived when the result is finished:
56 with capture_output() as io:
56 with capture_output() as io:
57 ip.magic(
57 ip.magic(
58 'px import sys,time;print(a);sys.stdout.flush();time.sleep(0.2)'
58 'px import sys,time;print(a);sys.stdout.flush();time.sleep(0.2)'
59 )
59 )
60 out = io.stdout
60 out = io.stdout
61 self.assertTrue('[stdout:' in out, out)
61 self.assertTrue('[stdout:' in out, out)
62 self.assertFalse('\n\n' in out)
62 self.assertFalse('\n\n' in out)
63 self.assertTrue(out.rstrip().endswith('10'))
63 self.assertTrue(out.rstrip().endswith('10'))
64 self.assertRaisesRemote(ZeroDivisionError, ip.magic, 'px 1/0')
64 self.assertRaisesRemote(ZeroDivisionError, ip.magic, 'px 1/0')
65
65
66 def _check_generated_stderr(self, stderr, n):
66 def _check_generated_stderr(self, stderr, n):
67 expected = [
67 expected = [
68 r'\[stderr:\d+\]',
68 r'\[stderr:\d+\]',
69 '^stderr$',
69 '^stderr$',
70 '^stderr2$',
70 '^stderr2$',
71 ] * n
71 ] * n
72
72
73 self.assertFalse('\n\n' in stderr, stderr)
73 self.assertFalse('\n\n' in stderr, stderr)
74 lines = stderr.splitlines()
74 lines = stderr.splitlines()
75 self.assertEquals(len(lines), len(expected), stderr)
75 self.assertEqual(len(lines), len(expected), stderr)
76 for line,expect in zip(lines, expected):
76 for line,expect in zip(lines, expected):
77 if isinstance(expect, str):
77 if isinstance(expect, str):
78 expect = [expect]
78 expect = [expect]
79 for ex in expect:
79 for ex in expect:
80 self.assertTrue(re.search(ex, line) is not None, "Expected %r in %r" % (ex, line))
80 self.assertTrue(re.search(ex, line) is not None, "Expected %r in %r" % (ex, line))
81
81
82 def test_cellpx_block_args(self):
82 def test_cellpx_block_args(self):
83 """%%px --[no]block flags work"""
83 """%%px --[no]block flags work"""
84 ip = get_ipython()
84 ip = get_ipython()
85 v = self.client[-1:]
85 v = self.client[-1:]
86 v.activate()
86 v.activate()
87 v.block=False
87 v.block=False
88
88
89 for block in (True, False):
89 for block in (True, False):
90 v.block = block
90 v.block = block
91 ip.magic("pxconfig --verbose")
91 ip.magic("pxconfig --verbose")
92 with capture_output() as io:
92 with capture_output() as io:
93 ip.run_cell_magic("px", "", "1")
93 ip.run_cell_magic("px", "", "1")
94 if block:
94 if block:
95 self.assertTrue(io.stdout.startswith("Parallel"), io.stdout)
95 self.assertTrue(io.stdout.startswith("Parallel"), io.stdout)
96 else:
96 else:
97 self.assertTrue(io.stdout.startswith("Async"), io.stdout)
97 self.assertTrue(io.stdout.startswith("Async"), io.stdout)
98
98
99 with capture_output() as io:
99 with capture_output() as io:
100 ip.run_cell_magic("px", "--block", "1")
100 ip.run_cell_magic("px", "--block", "1")
101 self.assertTrue(io.stdout.startswith("Parallel"), io.stdout)
101 self.assertTrue(io.stdout.startswith("Parallel"), io.stdout)
102
102
103 with capture_output() as io:
103 with capture_output() as io:
104 ip.run_cell_magic("px", "--noblock", "1")
104 ip.run_cell_magic("px", "--noblock", "1")
105 self.assertTrue(io.stdout.startswith("Async"), io.stdout)
105 self.assertTrue(io.stdout.startswith("Async"), io.stdout)
106
106
107 def test_cellpx_groupby_engine(self):
107 def test_cellpx_groupby_engine(self):
108 """%%px --group-outputs=engine"""
108 """%%px --group-outputs=engine"""
109 ip = get_ipython()
109 ip = get_ipython()
110 v = self.client[:]
110 v = self.client[:]
111 v.block = True
111 v.block = True
112 v.activate()
112 v.activate()
113
113
114 v['generate_output'] = generate_output
114 v['generate_output'] = generate_output
115
115
116 with capture_output() as io:
116 with capture_output() as io:
117 ip.run_cell_magic('px', '--group-outputs=engine', 'generate_output()')
117 ip.run_cell_magic('px', '--group-outputs=engine', 'generate_output()')
118
118
119 self.assertFalse('\n\n' in io.stdout)
119 self.assertFalse('\n\n' in io.stdout)
120 lines = io.stdout.splitlines()
120 lines = io.stdout.splitlines()
121 expected = [
121 expected = [
122 r'\[stdout:\d+\]',
122 r'\[stdout:\d+\]',
123 'stdout',
123 'stdout',
124 'stdout2',
124 'stdout2',
125 r'\[output:\d+\]',
125 r'\[output:\d+\]',
126 r'IPython\.core\.display\.HTML',
126 r'IPython\.core\.display\.HTML',
127 r'IPython\.core\.display\.Math',
127 r'IPython\.core\.display\.Math',
128 r'Out\[\d+:\d+\]:.*IPython\.core\.display\.Math',
128 r'Out\[\d+:\d+\]:.*IPython\.core\.display\.Math',
129 ] * len(v)
129 ] * len(v)
130
130
131 self.assertEquals(len(lines), len(expected), io.stdout)
131 self.assertEqual(len(lines), len(expected), io.stdout)
132 for line,expect in zip(lines, expected):
132 for line,expect in zip(lines, expected):
133 if isinstance(expect, str):
133 if isinstance(expect, str):
134 expect = [expect]
134 expect = [expect]
135 for ex in expect:
135 for ex in expect:
136 self.assertTrue(re.search(ex, line) is not None, "Expected %r in %r" % (ex, line))
136 self.assertTrue(re.search(ex, line) is not None, "Expected %r in %r" % (ex, line))
137
137
138 self._check_generated_stderr(io.stderr, len(v))
138 self._check_generated_stderr(io.stderr, len(v))
139
139
140
140
141 def test_cellpx_groupby_order(self):
141 def test_cellpx_groupby_order(self):
142 """%%px --group-outputs=order"""
142 """%%px --group-outputs=order"""
143 ip = get_ipython()
143 ip = get_ipython()
144 v = self.client[:]
144 v = self.client[:]
145 v.block = True
145 v.block = True
146 v.activate()
146 v.activate()
147
147
148 v['generate_output'] = generate_output
148 v['generate_output'] = generate_output
149
149
150 with capture_output() as io:
150 with capture_output() as io:
151 ip.run_cell_magic('px', '--group-outputs=order', 'generate_output()')
151 ip.run_cell_magic('px', '--group-outputs=order', 'generate_output()')
152
152
153 self.assertFalse('\n\n' in io.stdout)
153 self.assertFalse('\n\n' in io.stdout)
154 lines = io.stdout.splitlines()
154 lines = io.stdout.splitlines()
155 expected = []
155 expected = []
156 expected.extend([
156 expected.extend([
157 r'\[stdout:\d+\]',
157 r'\[stdout:\d+\]',
158 'stdout',
158 'stdout',
159 'stdout2',
159 'stdout2',
160 ] * len(v))
160 ] * len(v))
161 expected.extend([
161 expected.extend([
162 r'\[output:\d+\]',
162 r'\[output:\d+\]',
163 'IPython.core.display.HTML',
163 'IPython.core.display.HTML',
164 ] * len(v))
164 ] * len(v))
165 expected.extend([
165 expected.extend([
166 r'\[output:\d+\]',
166 r'\[output:\d+\]',
167 'IPython.core.display.Math',
167 'IPython.core.display.Math',
168 ] * len(v))
168 ] * len(v))
169 expected.extend([
169 expected.extend([
170 r'Out\[\d+:\d+\]:.*IPython\.core\.display\.Math'
170 r'Out\[\d+:\d+\]:.*IPython\.core\.display\.Math'
171 ] * len(v))
171 ] * len(v))
172
172
173 self.assertEquals(len(lines), len(expected), io.stdout)
173 self.assertEqual(len(lines), len(expected), io.stdout)
174 for line,expect in zip(lines, expected):
174 for line,expect in zip(lines, expected):
175 if isinstance(expect, str):
175 if isinstance(expect, str):
176 expect = [expect]
176 expect = [expect]
177 for ex in expect:
177 for ex in expect:
178 self.assertTrue(re.search(ex, line) is not None, "Expected %r in %r" % (ex, line))
178 self.assertTrue(re.search(ex, line) is not None, "Expected %r in %r" % (ex, line))
179
179
180 self._check_generated_stderr(io.stderr, len(v))
180 self._check_generated_stderr(io.stderr, len(v))
181
181
182 def test_cellpx_groupby_type(self):
182 def test_cellpx_groupby_type(self):
183 """%%px --group-outputs=type"""
183 """%%px --group-outputs=type"""
184 ip = get_ipython()
184 ip = get_ipython()
185 v = self.client[:]
185 v = self.client[:]
186 v.block = True
186 v.block = True
187 v.activate()
187 v.activate()
188
188
189 v['generate_output'] = generate_output
189 v['generate_output'] = generate_output
190
190
191 with capture_output() as io:
191 with capture_output() as io:
192 ip.run_cell_magic('px', '--group-outputs=type', 'generate_output()')
192 ip.run_cell_magic('px', '--group-outputs=type', 'generate_output()')
193
193
194 self.assertFalse('\n\n' in io.stdout)
194 self.assertFalse('\n\n' in io.stdout)
195 lines = io.stdout.splitlines()
195 lines = io.stdout.splitlines()
196
196
197 expected = []
197 expected = []
198 expected.extend([
198 expected.extend([
199 r'\[stdout:\d+\]',
199 r'\[stdout:\d+\]',
200 'stdout',
200 'stdout',
201 'stdout2',
201 'stdout2',
202 ] * len(v))
202 ] * len(v))
203 expected.extend([
203 expected.extend([
204 r'\[output:\d+\]',
204 r'\[output:\d+\]',
205 r'IPython\.core\.display\.HTML',
205 r'IPython\.core\.display\.HTML',
206 r'IPython\.core\.display\.Math',
206 r'IPython\.core\.display\.Math',
207 ] * len(v))
207 ] * len(v))
208 expected.extend([
208 expected.extend([
209 (r'Out\[\d+:\d+\]', r'IPython\.core\.display\.Math')
209 (r'Out\[\d+:\d+\]', r'IPython\.core\.display\.Math')
210 ] * len(v))
210 ] * len(v))
211
211
212 self.assertEquals(len(lines), len(expected), io.stdout)
212 self.assertEqual(len(lines), len(expected), io.stdout)
213 for line,expect in zip(lines, expected):
213 for line,expect in zip(lines, expected):
214 if isinstance(expect, str):
214 if isinstance(expect, str):
215 expect = [expect]
215 expect = [expect]
216 for ex in expect:
216 for ex in expect:
217 self.assertTrue(re.search(ex, line) is not None, "Expected %r in %r" % (ex, line))
217 self.assertTrue(re.search(ex, line) is not None, "Expected %r in %r" % (ex, line))
218
218
219 self._check_generated_stderr(io.stderr, len(v))
219 self._check_generated_stderr(io.stderr, len(v))
220
220
221
221
222 def test_px_nonblocking(self):
222 def test_px_nonblocking(self):
223 ip = get_ipython()
223 ip = get_ipython()
224 v = self.client[-1:]
224 v = self.client[-1:]
225 v.activate()
225 v.activate()
226 v.block=False
226 v.block=False
227
227
228 ip.magic('px a=5')
228 ip.magic('px a=5')
229 self.assertEquals(v['a'], [5])
229 self.assertEqual(v['a'], [5])
230 ip.magic('px a=10')
230 ip.magic('px a=10')
231 self.assertEquals(v['a'], [10])
231 self.assertEqual(v['a'], [10])
232 ip.magic('pxconfig --verbose')
232 ip.magic('pxconfig --verbose')
233 with capture_output() as io:
233 with capture_output() as io:
234 ar = ip.magic('px print (a)')
234 ar = ip.magic('px print (a)')
235 self.assertTrue(isinstance(ar, AsyncResult))
235 self.assertTrue(isinstance(ar, AsyncResult))
236 self.assertTrue('Async' in io.stdout)
236 self.assertTrue('Async' in io.stdout)
237 self.assertFalse('[stdout:' in io.stdout)
237 self.assertFalse('[stdout:' in io.stdout)
238 self.assertFalse('\n\n' in io.stdout)
238 self.assertFalse('\n\n' in io.stdout)
239
239
240 ar = ip.magic('px 1/0')
240 ar = ip.magic('px 1/0')
241 self.assertRaisesRemote(ZeroDivisionError, ar.get)
241 self.assertRaisesRemote(ZeroDivisionError, ar.get)
242
242
243 def test_autopx_blocking(self):
243 def test_autopx_blocking(self):
244 ip = get_ipython()
244 ip = get_ipython()
245 v = self.client[-1]
245 v = self.client[-1]
246 v.activate()
246 v.activate()
247 v.block=True
247 v.block=True
248
248
249 with capture_output() as io:
249 with capture_output() as io:
250 ip.magic('autopx')
250 ip.magic('autopx')
251 ip.run_cell('\n'.join(('a=5','b=12345','c=0')))
251 ip.run_cell('\n'.join(('a=5','b=12345','c=0')))
252 ip.run_cell('b*=2')
252 ip.run_cell('b*=2')
253 ip.run_cell('print (b)')
253 ip.run_cell('print (b)')
254 ip.run_cell('b')
254 ip.run_cell('b')
255 ip.run_cell("b/c")
255 ip.run_cell("b/c")
256 ip.magic('autopx')
256 ip.magic('autopx')
257
257
258 output = io.stdout
258 output = io.stdout
259
259
260 self.assertTrue(output.startswith('%autopx enabled'), output)
260 self.assertTrue(output.startswith('%autopx enabled'), output)
261 self.assertTrue(output.rstrip().endswith('%autopx disabled'), output)
261 self.assertTrue(output.rstrip().endswith('%autopx disabled'), output)
262 self.assertTrue('ZeroDivisionError' in output, output)
262 self.assertTrue('ZeroDivisionError' in output, output)
263 self.assertTrue('\nOut[' in output, output)
263 self.assertTrue('\nOut[' in output, output)
264 self.assertTrue(': 24690' in output, output)
264 self.assertTrue(': 24690' in output, output)
265 ar = v.get_result(-1)
265 ar = v.get_result(-1)
266 self.assertEquals(v['a'], 5)
266 self.assertEqual(v['a'], 5)
267 self.assertEquals(v['b'], 24690)
267 self.assertEqual(v['b'], 24690)
268 self.assertRaisesRemote(ZeroDivisionError, ar.get)
268 self.assertRaisesRemote(ZeroDivisionError, ar.get)
269
269
270 def test_autopx_nonblocking(self):
270 def test_autopx_nonblocking(self):
271 ip = get_ipython()
271 ip = get_ipython()
272 v = self.client[-1]
272 v = self.client[-1]
273 v.activate()
273 v.activate()
274 v.block=False
274 v.block=False
275
275
276 with capture_output() as io:
276 with capture_output() as io:
277 ip.magic('autopx')
277 ip.magic('autopx')
278 ip.run_cell('\n'.join(('a=5','b=10','c=0')))
278 ip.run_cell('\n'.join(('a=5','b=10','c=0')))
279 ip.run_cell('print (b)')
279 ip.run_cell('print (b)')
280 ip.run_cell('import time; time.sleep(0.1)')
280 ip.run_cell('import time; time.sleep(0.1)')
281 ip.run_cell("b/c")
281 ip.run_cell("b/c")
282 ip.run_cell('b*=2')
282 ip.run_cell('b*=2')
283 ip.magic('autopx')
283 ip.magic('autopx')
284
284
285 output = io.stdout.rstrip()
285 output = io.stdout.rstrip()
286
286
287 self.assertTrue(output.startswith('%autopx enabled'))
287 self.assertTrue(output.startswith('%autopx enabled'))
288 self.assertTrue(output.endswith('%autopx disabled'))
288 self.assertTrue(output.endswith('%autopx disabled'))
289 self.assertFalse('ZeroDivisionError' in output)
289 self.assertFalse('ZeroDivisionError' in output)
290 ar = v.get_result(-2)
290 ar = v.get_result(-2)
291 self.assertRaisesRemote(ZeroDivisionError, ar.get)
291 self.assertRaisesRemote(ZeroDivisionError, ar.get)
292 # prevent TaskAborted on pulls, due to ZeroDivisionError
292 # prevent TaskAborted on pulls, due to ZeroDivisionError
293 time.sleep(0.5)
293 time.sleep(0.5)
294 self.assertEquals(v['a'], 5)
294 self.assertEqual(v['a'], 5)
295 # b*=2 will not fire, due to abort
295 # b*=2 will not fire, due to abort
296 self.assertEquals(v['b'], 10)
296 self.assertEqual(v['b'], 10)
297
297
298 def test_result(self):
298 def test_result(self):
299 ip = get_ipython()
299 ip = get_ipython()
300 v = self.client[-1]
300 v = self.client[-1]
301 v.activate()
301 v.activate()
302 data = dict(a=111,b=222)
302 data = dict(a=111,b=222)
303 v.push(data, block=True)
303 v.push(data, block=True)
304
304
305 for name in ('a', 'b'):
305 for name in ('a', 'b'):
306 ip.magic('px ' + name)
306 ip.magic('px ' + name)
307 with capture_output() as io:
307 with capture_output() as io:
308 ip.magic('pxresult')
308 ip.magic('pxresult')
309 output = io.stdout
309 output = io.stdout
310 msg = "expected %s output to include %s, but got: %s" % \
310 msg = "expected %s output to include %s, but got: %s" % \
311 ('%pxresult', str(data[name]), output)
311 ('%pxresult', str(data[name]), output)
312 self.assertTrue(str(data[name]) in output, msg)
312 self.assertTrue(str(data[name]) in output, msg)
313
313
314 @dec.skipif_not_matplotlib
314 @dec.skipif_not_matplotlib
315 def test_px_pylab(self):
315 def test_px_pylab(self):
316 """%pylab works on engines"""
316 """%pylab works on engines"""
317 ip = get_ipython()
317 ip = get_ipython()
318 v = self.client[-1]
318 v = self.client[-1]
319 v.block = True
319 v.block = True
320 v.activate()
320 v.activate()
321
321
322 with capture_output() as io:
322 with capture_output() as io:
323 ip.magic("px %pylab inline")
323 ip.magic("px %pylab inline")
324
324
325 self.assertTrue("Welcome to pylab" in io.stdout, io.stdout)
325 self.assertTrue("Welcome to pylab" in io.stdout, io.stdout)
326 self.assertTrue("backend_inline" in io.stdout, io.stdout)
326 self.assertTrue("backend_inline" in io.stdout, io.stdout)
327
327
328 with capture_output() as io:
328 with capture_output() as io:
329 ip.magic("px plot(rand(100))")
329 ip.magic("px plot(rand(100))")
330
330
331 self.assertTrue('Out[' in io.stdout, io.stdout)
331 self.assertTrue('Out[' in io.stdout, io.stdout)
332 self.assertTrue('matplotlib.lines' in io.stdout, io.stdout)
332 self.assertTrue('matplotlib.lines' in io.stdout, io.stdout)
333
333
334 def test_pxconfig(self):
334 def test_pxconfig(self):
335 ip = get_ipython()
335 ip = get_ipython()
336 rc = self.client
336 rc = self.client
337 v = rc.activate(-1, '_tst')
337 v = rc.activate(-1, '_tst')
338 self.assertEquals(v.targets, rc.ids[-1])
338 self.assertEqual(v.targets, rc.ids[-1])
339 ip.magic("%pxconfig_tst -t :")
339 ip.magic("%pxconfig_tst -t :")
340 self.assertEquals(v.targets, rc.ids)
340 self.assertEqual(v.targets, rc.ids)
341 ip.magic("%pxconfig_tst -t ::2")
341 ip.magic("%pxconfig_tst -t ::2")
342 self.assertEquals(v.targets, rc.ids[::2])
342 self.assertEqual(v.targets, rc.ids[::2])
343 ip.magic("%pxconfig_tst -t 1::2")
343 ip.magic("%pxconfig_tst -t 1::2")
344 self.assertEquals(v.targets, rc.ids[1::2])
344 self.assertEqual(v.targets, rc.ids[1::2])
345 ip.magic("%pxconfig_tst -t 1")
345 ip.magic("%pxconfig_tst -t 1")
346 self.assertEquals(v.targets, 1)
346 self.assertEqual(v.targets, 1)
347 ip.magic("%pxconfig_tst --block")
347 ip.magic("%pxconfig_tst --block")
348 self.assertEquals(v.block, True)
348 self.assertEqual(v.block, True)
349 ip.magic("%pxconfig_tst --noblock")
349 ip.magic("%pxconfig_tst --noblock")
350 self.assertEquals(v.block, False)
350 self.assertEqual(v.block, False)
351
351
352 def test_cellpx_targets(self):
352 def test_cellpx_targets(self):
353 """%%px --targets doesn't change defaults"""
353 """%%px --targets doesn't change defaults"""
354 ip = get_ipython()
354 ip = get_ipython()
355 rc = self.client
355 rc = self.client
356 view = rc.activate(rc.ids)
356 view = rc.activate(rc.ids)
357 self.assertEquals(view.targets, rc.ids)
357 self.assertEqual(view.targets, rc.ids)
358 ip.magic('pxconfig --verbose')
358 ip.magic('pxconfig --verbose')
359 for cell in ("pass", "1/0"):
359 for cell in ("pass", "1/0"):
360 with capture_output() as io:
360 with capture_output() as io:
361 try:
361 try:
362 ip.run_cell_magic("px", "--targets all", cell)
362 ip.run_cell_magic("px", "--targets all", cell)
363 except pmod.RemoteError:
363 except pmod.RemoteError:
364 pass
364 pass
365 self.assertTrue('engine(s): all' in io.stdout)
365 self.assertTrue('engine(s): all' in io.stdout)
366 self.assertEquals(view.targets, rc.ids)
366 self.assertEqual(view.targets, rc.ids)
367
367
368
368
369 def test_cellpx_block(self):
369 def test_cellpx_block(self):
370 """%%px --block doesn't change default"""
370 """%%px --block doesn't change default"""
371 ip = get_ipython()
371 ip = get_ipython()
372 rc = self.client
372 rc = self.client
373 view = rc.activate(rc.ids)
373 view = rc.activate(rc.ids)
374 view.block = False
374 view.block = False
375 self.assertEquals(view.targets, rc.ids)
375 self.assertEqual(view.targets, rc.ids)
376 ip.magic('pxconfig --verbose')
376 ip.magic('pxconfig --verbose')
377 for cell in ("pass", "1/0"):
377 for cell in ("pass", "1/0"):
378 with capture_output() as io:
378 with capture_output() as io:
379 try:
379 try:
380 ip.run_cell_magic("px", "--block", cell)
380 ip.run_cell_magic("px", "--block", cell)
381 except pmod.RemoteError:
381 except pmod.RemoteError:
382 pass
382 pass
383 self.assertFalse('Async' in io.stdout)
383 self.assertFalse('Async' in io.stdout)
384 self.assertFalse(view.block)
384 self.assertFalse(view.block)
385
385
386
386
@@ -1,117 +1,117 b''
1 """test serialization with newserialized
1 """test serialization with newserialized
2
2
3 Authors:
3 Authors:
4
4
5 * Min RK
5 * Min RK
6 """
6 """
7
7
8 #-------------------------------------------------------------------------------
8 #-------------------------------------------------------------------------------
9 # Copyright (C) 2011 The IPython Development Team
9 # Copyright (C) 2011 The IPython Development Team
10 #
10 #
11 # Distributed under the terms of the BSD License. The full license is in
11 # Distributed under the terms of the BSD License. The full license is in
12 # the file COPYING, distributed as part of this software.
12 # the file COPYING, distributed as part of this software.
13 #-------------------------------------------------------------------------------
13 #-------------------------------------------------------------------------------
14
14
15 #-------------------------------------------------------------------------------
15 #-------------------------------------------------------------------------------
16 # Imports
16 # Imports
17 #-------------------------------------------------------------------------------
17 #-------------------------------------------------------------------------------
18
18
19 import sys
19 import sys
20
20
21 from unittest import TestCase
21 from unittest import TestCase
22
22
23 from IPython.testing.decorators import parametric
23 from IPython.testing.decorators import parametric
24 from IPython.utils import newserialized as ns
24 from IPython.utils import newserialized as ns
25 from IPython.utils.pickleutil import can, uncan, CannedObject, CannedFunction
25 from IPython.utils.pickleutil import can, uncan, CannedObject, CannedFunction
26 from IPython.parallel.tests.clienttest import skip_without
26 from IPython.parallel.tests.clienttest import skip_without
27
27
28 if sys.version_info[0] >= 3:
28 if sys.version_info[0] >= 3:
29 buffer = memoryview
29 buffer = memoryview
30
30
31 class CanningTestCase(TestCase):
31 class CanningTestCase(TestCase):
32 def test_canning(self):
32 def test_canning(self):
33 d = dict(a=5,b=6)
33 d = dict(a=5,b=6)
34 cd = can(d)
34 cd = can(d)
35 self.assertTrue(isinstance(cd, dict))
35 self.assertTrue(isinstance(cd, dict))
36
36
37 def test_canned_function(self):
37 def test_canned_function(self):
38 f = lambda : 7
38 f = lambda : 7
39 cf = can(f)
39 cf = can(f)
40 self.assertTrue(isinstance(cf, CannedFunction))
40 self.assertTrue(isinstance(cf, CannedFunction))
41
41
42 @parametric
42 @parametric
43 def test_can_roundtrip(cls):
43 def test_can_roundtrip(cls):
44 objs = [
44 objs = [
45 dict(),
45 dict(),
46 set(),
46 set(),
47 list(),
47 list(),
48 ['a',1,['a',1],u'e'],
48 ['a',1,['a',1],u'e'],
49 ]
49 ]
50 return map(cls.run_roundtrip, objs)
50 return map(cls.run_roundtrip, objs)
51
51
52 @classmethod
52 @classmethod
53 def run_roundtrip(self, obj):
53 def run_roundtrip(self, obj):
54 o = uncan(can(obj))
54 o = uncan(can(obj))
55 assert o == obj, "failed assertion: %r == %r"%(o,obj)
55 assert o == obj, "failed assertion: %r == %r"%(o,obj)
56
56
57 def test_serialized_interfaces(self):
57 def test_serialized_interfaces(self):
58
58
59 us = {'a':10, 'b':range(10)}
59 us = {'a':10, 'b':range(10)}
60 s = ns.serialize(us)
60 s = ns.serialize(us)
61 uus = ns.unserialize(s)
61 uus = ns.unserialize(s)
62 self.assertTrue(isinstance(s, ns.SerializeIt))
62 self.assertTrue(isinstance(s, ns.SerializeIt))
63 self.assertEquals(uus, us)
63 self.assertEqual(uus, us)
64
64
65 def test_pickle_serialized(self):
65 def test_pickle_serialized(self):
66 obj = {'a':1.45345, 'b':'asdfsdf', 'c':10000L}
66 obj = {'a':1.45345, 'b':'asdfsdf', 'c':10000L}
67 original = ns.UnSerialized(obj)
67 original = ns.UnSerialized(obj)
68 originalSer = ns.SerializeIt(original)
68 originalSer = ns.SerializeIt(original)
69 firstData = originalSer.getData()
69 firstData = originalSer.getData()
70 firstTD = originalSer.getTypeDescriptor()
70 firstTD = originalSer.getTypeDescriptor()
71 firstMD = originalSer.getMetadata()
71 firstMD = originalSer.getMetadata()
72 self.assertEquals(firstTD, 'pickle')
72 self.assertEqual(firstTD, 'pickle')
73 self.assertEquals(firstMD, {})
73 self.assertEqual(firstMD, {})
74 unSerialized = ns.UnSerializeIt(originalSer)
74 unSerialized = ns.UnSerializeIt(originalSer)
75 secondObj = unSerialized.getObject()
75 secondObj = unSerialized.getObject()
76 for k, v in secondObj.iteritems():
76 for k, v in secondObj.iteritems():
77 self.assertEquals(obj[k], v)
77 self.assertEqual(obj[k], v)
78 secondSer = ns.SerializeIt(ns.UnSerialized(secondObj))
78 secondSer = ns.SerializeIt(ns.UnSerialized(secondObj))
79 self.assertEquals(firstData, secondSer.getData())
79 self.assertEqual(firstData, secondSer.getData())
80 self.assertEquals(firstTD, secondSer.getTypeDescriptor() )
80 self.assertEqual(firstTD, secondSer.getTypeDescriptor() )
81 self.assertEquals(firstMD, secondSer.getMetadata())
81 self.assertEqual(firstMD, secondSer.getMetadata())
82
82
83 @skip_without('numpy')
83 @skip_without('numpy')
84 def test_ndarray_serialized(self):
84 def test_ndarray_serialized(self):
85 import numpy
85 import numpy
86 a = numpy.linspace(0.0, 1.0, 1000)
86 a = numpy.linspace(0.0, 1.0, 1000)
87 unSer1 = ns.UnSerialized(a)
87 unSer1 = ns.UnSerialized(a)
88 ser1 = ns.SerializeIt(unSer1)
88 ser1 = ns.SerializeIt(unSer1)
89 td = ser1.getTypeDescriptor()
89 td = ser1.getTypeDescriptor()
90 self.assertEquals(td, 'ndarray')
90 self.assertEqual(td, 'ndarray')
91 md = ser1.getMetadata()
91 md = ser1.getMetadata()
92 self.assertEquals(md['shape'], a.shape)
92 self.assertEqual(md['shape'], a.shape)
93 self.assertEquals(md['dtype'], a.dtype)
93 self.assertEqual(md['dtype'], a.dtype)
94 buff = ser1.getData()
94 buff = ser1.getData()
95 self.assertEquals(buff, buffer(a))
95 self.assertEqual(buff, buffer(a))
96 s = ns.Serialized(buff, td, md)
96 s = ns.Serialized(buff, td, md)
97 final = ns.unserialize(s)
97 final = ns.unserialize(s)
98 self.assertEquals(buffer(a), buffer(final))
98 self.assertEqual(buffer(a), buffer(final))
99 self.assertTrue((a==final).all())
99 self.assertTrue((a==final).all())
100 self.assertEquals(a.dtype, final.dtype)
100 self.assertEqual(a.dtype, final.dtype)
101 self.assertEquals(a.shape, final.shape)
101 self.assertEqual(a.shape, final.shape)
102 # test non-copying:
102 # test non-copying:
103 a[2] = 1e9
103 a[2] = 1e9
104 self.assertTrue((a==final).all())
104 self.assertTrue((a==final).all())
105
105
106 def test_uncan_function_globals(self):
106 def test_uncan_function_globals(self):
107 """test that uncanning a module function restores it into its module"""
107 """test that uncanning a module function restores it into its module"""
108 from re import search
108 from re import search
109 cf = can(search)
109 cf = can(search)
110 csearch = uncan(cf)
110 csearch = uncan(cf)
111 self.assertEqual(csearch.__module__, search.__module__)
111 self.assertEqual(csearch.__module__, search.__module__)
112 self.assertNotEqual(csearch('asd', 'asdf'), None)
112 self.assertNotEqual(csearch('asd', 'asdf'), None)
113 csearch = uncan(cf, dict(a=5))
113 csearch = uncan(cf, dict(a=5))
114 self.assertEqual(csearch.__module__, search.__module__)
114 self.assertEqual(csearch.__module__, search.__module__)
115 self.assertNotEqual(csearch('asd', 'asdf'), None)
115 self.assertNotEqual(csearch('asd', 'asdf'), None)
116
116
117 No newline at end of file
117
@@ -1,597 +1,597 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2 """test View objects
2 """test View objects
3
3
4 Authors:
4 Authors:
5
5
6 * Min RK
6 * Min RK
7 """
7 """
8 #-------------------------------------------------------------------------------
8 #-------------------------------------------------------------------------------
9 # Copyright (C) 2011 The IPython Development Team
9 # Copyright (C) 2011 The IPython Development Team
10 #
10 #
11 # Distributed under the terms of the BSD License. The full license is in
11 # Distributed under the terms of the BSD License. The full license is in
12 # the file COPYING, distributed as part of this software.
12 # the file COPYING, distributed as part of this software.
13 #-------------------------------------------------------------------------------
13 #-------------------------------------------------------------------------------
14
14
15 #-------------------------------------------------------------------------------
15 #-------------------------------------------------------------------------------
16 # Imports
16 # Imports
17 #-------------------------------------------------------------------------------
17 #-------------------------------------------------------------------------------
18
18
19 import sys
19 import sys
20 import platform
20 import platform
21 import time
21 import time
22 from tempfile import mktemp
22 from tempfile import mktemp
23 from StringIO import StringIO
23 from StringIO import StringIO
24
24
25 import zmq
25 import zmq
26 from nose import SkipTest
26 from nose import SkipTest
27
27
28 from IPython.testing import decorators as dec
28 from IPython.testing import decorators as dec
29 from IPython.testing.ipunittest import ParametricTestCase
29 from IPython.testing.ipunittest import ParametricTestCase
30
30
31 from IPython import parallel as pmod
31 from IPython import parallel as pmod
32 from IPython.parallel import error
32 from IPython.parallel import error
33 from IPython.parallel import AsyncResult, AsyncHubResult, AsyncMapResult
33 from IPython.parallel import AsyncResult, AsyncHubResult, AsyncMapResult
34 from IPython.parallel import DirectView
34 from IPython.parallel import DirectView
35 from IPython.parallel.util import interactive
35 from IPython.parallel.util import interactive
36
36
37 from IPython.parallel.tests import add_engines
37 from IPython.parallel.tests import add_engines
38
38
39 from .clienttest import ClusterTestCase, crash, wait, skip_without
39 from .clienttest import ClusterTestCase, crash, wait, skip_without
40
40
41 def setup():
41 def setup():
42 add_engines(3, total=True)
42 add_engines(3, total=True)
43
43
44 class TestView(ClusterTestCase, ParametricTestCase):
44 class TestView(ClusterTestCase, ParametricTestCase):
45
45
46 def setUp(self):
46 def setUp(self):
47 # On Win XP, wait for resource cleanup, else parallel test group fails
47 # On Win XP, wait for resource cleanup, else parallel test group fails
48 if platform.system() == "Windows" and platform.win32_ver()[0] == "XP":
48 if platform.system() == "Windows" and platform.win32_ver()[0] == "XP":
49 # 1 sec fails. 1.5 sec seems ok. Using 2 sec for margin of safety
49 # 1 sec fails. 1.5 sec seems ok. Using 2 sec for margin of safety
50 time.sleep(2)
50 time.sleep(2)
51 super(TestView, self).setUp()
51 super(TestView, self).setUp()
52
52
53 def test_z_crash_mux(self):
53 def test_z_crash_mux(self):
54 """test graceful handling of engine death (direct)"""
54 """test graceful handling of engine death (direct)"""
55 raise SkipTest("crash tests disabled, due to undesirable crash reports")
55 raise SkipTest("crash tests disabled, due to undesirable crash reports")
56 # self.add_engines(1)
56 # self.add_engines(1)
57 eid = self.client.ids[-1]
57 eid = self.client.ids[-1]
58 ar = self.client[eid].apply_async(crash)
58 ar = self.client[eid].apply_async(crash)
59 self.assertRaisesRemote(error.EngineError, ar.get, 10)
59 self.assertRaisesRemote(error.EngineError, ar.get, 10)
60 eid = ar.engine_id
60 eid = ar.engine_id
61 tic = time.time()
61 tic = time.time()
62 while eid in self.client.ids and time.time()-tic < 5:
62 while eid in self.client.ids and time.time()-tic < 5:
63 time.sleep(.01)
63 time.sleep(.01)
64 self.client.spin()
64 self.client.spin()
65 self.assertFalse(eid in self.client.ids, "Engine should have died")
65 self.assertFalse(eid in self.client.ids, "Engine should have died")
66
66
67 def test_push_pull(self):
67 def test_push_pull(self):
68 """test pushing and pulling"""
68 """test pushing and pulling"""
69 data = dict(a=10, b=1.05, c=range(10), d={'e':(1,2),'f':'hi'})
69 data = dict(a=10, b=1.05, c=range(10), d={'e':(1,2),'f':'hi'})
70 t = self.client.ids[-1]
70 t = self.client.ids[-1]
71 v = self.client[t]
71 v = self.client[t]
72 push = v.push
72 push = v.push
73 pull = v.pull
73 pull = v.pull
74 v.block=True
74 v.block=True
75 nengines = len(self.client)
75 nengines = len(self.client)
76 push({'data':data})
76 push({'data':data})
77 d = pull('data')
77 d = pull('data')
78 self.assertEquals(d, data)
78 self.assertEqual(d, data)
79 self.client[:].push({'data':data})
79 self.client[:].push({'data':data})
80 d = self.client[:].pull('data', block=True)
80 d = self.client[:].pull('data', block=True)
81 self.assertEquals(d, nengines*[data])
81 self.assertEqual(d, nengines*[data])
82 ar = push({'data':data}, block=False)
82 ar = push({'data':data}, block=False)
83 self.assertTrue(isinstance(ar, AsyncResult))
83 self.assertTrue(isinstance(ar, AsyncResult))
84 r = ar.get()
84 r = ar.get()
85 ar = self.client[:].pull('data', block=False)
85 ar = self.client[:].pull('data', block=False)
86 self.assertTrue(isinstance(ar, AsyncResult))
86 self.assertTrue(isinstance(ar, AsyncResult))
87 r = ar.get()
87 r = ar.get()
88 self.assertEquals(r, nengines*[data])
88 self.assertEqual(r, nengines*[data])
89 self.client[:].push(dict(a=10,b=20))
89 self.client[:].push(dict(a=10,b=20))
90 r = self.client[:].pull(('a','b'), block=True)
90 r = self.client[:].pull(('a','b'), block=True)
91 self.assertEquals(r, nengines*[[10,20]])
91 self.assertEqual(r, nengines*[[10,20]])
92
92
93 def test_push_pull_function(self):
93 def test_push_pull_function(self):
94 "test pushing and pulling functions"
94 "test pushing and pulling functions"
95 def testf(x):
95 def testf(x):
96 return 2.0*x
96 return 2.0*x
97
97
98 t = self.client.ids[-1]
98 t = self.client.ids[-1]
99 v = self.client[t]
99 v = self.client[t]
100 v.block=True
100 v.block=True
101 push = v.push
101 push = v.push
102 pull = v.pull
102 pull = v.pull
103 execute = v.execute
103 execute = v.execute
104 push({'testf':testf})
104 push({'testf':testf})
105 r = pull('testf')
105 r = pull('testf')
106 self.assertEqual(r(1.0), testf(1.0))
106 self.assertEqual(r(1.0), testf(1.0))
107 execute('r = testf(10)')
107 execute('r = testf(10)')
108 r = pull('r')
108 r = pull('r')
109 self.assertEquals(r, testf(10))
109 self.assertEqual(r, testf(10))
110 ar = self.client[:].push({'testf':testf}, block=False)
110 ar = self.client[:].push({'testf':testf}, block=False)
111 ar.get()
111 ar.get()
112 ar = self.client[:].pull('testf', block=False)
112 ar = self.client[:].pull('testf', block=False)
113 rlist = ar.get()
113 rlist = ar.get()
114 for r in rlist:
114 for r in rlist:
115 self.assertEqual(r(1.0), testf(1.0))
115 self.assertEqual(r(1.0), testf(1.0))
116 execute("def g(x): return x*x")
116 execute("def g(x): return x*x")
117 r = pull(('testf','g'))
117 r = pull(('testf','g'))
118 self.assertEquals((r[0](10),r[1](10)), (testf(10), 100))
118 self.assertEqual((r[0](10),r[1](10)), (testf(10), 100))
119
119
120 def test_push_function_globals(self):
120 def test_push_function_globals(self):
121 """test that pushed functions have access to globals"""
121 """test that pushed functions have access to globals"""
122 @interactive
122 @interactive
123 def geta():
123 def geta():
124 return a
124 return a
125 # self.add_engines(1)
125 # self.add_engines(1)
126 v = self.client[-1]
126 v = self.client[-1]
127 v.block=True
127 v.block=True
128 v['f'] = geta
128 v['f'] = geta
129 self.assertRaisesRemote(NameError, v.execute, 'b=f()')
129 self.assertRaisesRemote(NameError, v.execute, 'b=f()')
130 v.execute('a=5')
130 v.execute('a=5')
131 v.execute('b=f()')
131 v.execute('b=f()')
132 self.assertEquals(v['b'], 5)
132 self.assertEqual(v['b'], 5)
133
133
134 def test_push_function_defaults(self):
134 def test_push_function_defaults(self):
135 """test that pushed functions preserve default args"""
135 """test that pushed functions preserve default args"""
136 def echo(a=10):
136 def echo(a=10):
137 return a
137 return a
138 v = self.client[-1]
138 v = self.client[-1]
139 v.block=True
139 v.block=True
140 v['f'] = echo
140 v['f'] = echo
141 v.execute('b=f()')
141 v.execute('b=f()')
142 self.assertEquals(v['b'], 10)
142 self.assertEqual(v['b'], 10)
143
143
144 def test_get_result(self):
144 def test_get_result(self):
145 """test getting results from the Hub."""
145 """test getting results from the Hub."""
146 c = pmod.Client(profile='iptest')
146 c = pmod.Client(profile='iptest')
147 # self.add_engines(1)
147 # self.add_engines(1)
148 t = c.ids[-1]
148 t = c.ids[-1]
149 v = c[t]
149 v = c[t]
150 v2 = self.client[t]
150 v2 = self.client[t]
151 ar = v.apply_async(wait, 1)
151 ar = v.apply_async(wait, 1)
152 # give the monitor time to notice the message
152 # give the monitor time to notice the message
153 time.sleep(.25)
153 time.sleep(.25)
154 ahr = v2.get_result(ar.msg_ids)
154 ahr = v2.get_result(ar.msg_ids)
155 self.assertTrue(isinstance(ahr, AsyncHubResult))
155 self.assertTrue(isinstance(ahr, AsyncHubResult))
156 self.assertEquals(ahr.get(), ar.get())
156 self.assertEqual(ahr.get(), ar.get())
157 ar2 = v2.get_result(ar.msg_ids)
157 ar2 = v2.get_result(ar.msg_ids)
158 self.assertFalse(isinstance(ar2, AsyncHubResult))
158 self.assertFalse(isinstance(ar2, AsyncHubResult))
159 c.spin()
159 c.spin()
160 c.close()
160 c.close()
161
161
162 def test_run_newline(self):
162 def test_run_newline(self):
163 """test that run appends newline to files"""
163 """test that run appends newline to files"""
164 tmpfile = mktemp()
164 tmpfile = mktemp()
165 with open(tmpfile, 'w') as f:
165 with open(tmpfile, 'w') as f:
166 f.write("""def g():
166 f.write("""def g():
167 return 5
167 return 5
168 """)
168 """)
169 v = self.client[-1]
169 v = self.client[-1]
170 v.run(tmpfile, block=True)
170 v.run(tmpfile, block=True)
171 self.assertEquals(v.apply_sync(lambda f: f(), pmod.Reference('g')), 5)
171 self.assertEqual(v.apply_sync(lambda f: f(), pmod.Reference('g')), 5)
172
172
173 def test_apply_tracked(self):
173 def test_apply_tracked(self):
174 """test tracking for apply"""
174 """test tracking for apply"""
175 # self.add_engines(1)
175 # self.add_engines(1)
176 t = self.client.ids[-1]
176 t = self.client.ids[-1]
177 v = self.client[t]
177 v = self.client[t]
178 v.block=False
178 v.block=False
179 def echo(n=1024*1024, **kwargs):
179 def echo(n=1024*1024, **kwargs):
180 with v.temp_flags(**kwargs):
180 with v.temp_flags(**kwargs):
181 return v.apply(lambda x: x, 'x'*n)
181 return v.apply(lambda x: x, 'x'*n)
182 ar = echo(1, track=False)
182 ar = echo(1, track=False)
183 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
183 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
184 self.assertTrue(ar.sent)
184 self.assertTrue(ar.sent)
185 ar = echo(track=True)
185 ar = echo(track=True)
186 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
186 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
187 self.assertEquals(ar.sent, ar._tracker.done)
187 self.assertEqual(ar.sent, ar._tracker.done)
188 ar._tracker.wait()
188 ar._tracker.wait()
189 self.assertTrue(ar.sent)
189 self.assertTrue(ar.sent)
190
190
191 def test_push_tracked(self):
191 def test_push_tracked(self):
192 t = self.client.ids[-1]
192 t = self.client.ids[-1]
193 ns = dict(x='x'*1024*1024)
193 ns = dict(x='x'*1024*1024)
194 v = self.client[t]
194 v = self.client[t]
195 ar = v.push(ns, block=False, track=False)
195 ar = v.push(ns, block=False, track=False)
196 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
196 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
197 self.assertTrue(ar.sent)
197 self.assertTrue(ar.sent)
198
198
199 ar = v.push(ns, block=False, track=True)
199 ar = v.push(ns, block=False, track=True)
200 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
200 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
201 ar._tracker.wait()
201 ar._tracker.wait()
202 self.assertEquals(ar.sent, ar._tracker.done)
202 self.assertEqual(ar.sent, ar._tracker.done)
203 self.assertTrue(ar.sent)
203 self.assertTrue(ar.sent)
204 ar.get()
204 ar.get()
205
205
206 def test_scatter_tracked(self):
206 def test_scatter_tracked(self):
207 t = self.client.ids
207 t = self.client.ids
208 x='x'*1024*1024
208 x='x'*1024*1024
209 ar = self.client[t].scatter('x', x, block=False, track=False)
209 ar = self.client[t].scatter('x', x, block=False, track=False)
210 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
210 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
211 self.assertTrue(ar.sent)
211 self.assertTrue(ar.sent)
212
212
213 ar = self.client[t].scatter('x', x, block=False, track=True)
213 ar = self.client[t].scatter('x', x, block=False, track=True)
214 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
214 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
215 self.assertEquals(ar.sent, ar._tracker.done)
215 self.assertEqual(ar.sent, ar._tracker.done)
216 ar._tracker.wait()
216 ar._tracker.wait()
217 self.assertTrue(ar.sent)
217 self.assertTrue(ar.sent)
218 ar.get()
218 ar.get()
219
219
220 def test_remote_reference(self):
220 def test_remote_reference(self):
221 v = self.client[-1]
221 v = self.client[-1]
222 v['a'] = 123
222 v['a'] = 123
223 ra = pmod.Reference('a')
223 ra = pmod.Reference('a')
224 b = v.apply_sync(lambda x: x, ra)
224 b = v.apply_sync(lambda x: x, ra)
225 self.assertEquals(b, 123)
225 self.assertEqual(b, 123)
226
226
227
227
228 def test_scatter_gather(self):
228 def test_scatter_gather(self):
229 view = self.client[:]
229 view = self.client[:]
230 seq1 = range(16)
230 seq1 = range(16)
231 view.scatter('a', seq1)
231 view.scatter('a', seq1)
232 seq2 = view.gather('a', block=True)
232 seq2 = view.gather('a', block=True)
233 self.assertEquals(seq2, seq1)
233 self.assertEqual(seq2, seq1)
234 self.assertRaisesRemote(NameError, view.gather, 'asdf', block=True)
234 self.assertRaisesRemote(NameError, view.gather, 'asdf', block=True)
235
235
236 @skip_without('numpy')
236 @skip_without('numpy')
237 def test_scatter_gather_numpy(self):
237 def test_scatter_gather_numpy(self):
238 import numpy
238 import numpy
239 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
239 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
240 view = self.client[:]
240 view = self.client[:]
241 a = numpy.arange(64)
241 a = numpy.arange(64)
242 view.scatter('a', a)
242 view.scatter('a', a)
243 b = view.gather('a', block=True)
243 b = view.gather('a', block=True)
244 assert_array_equal(b, a)
244 assert_array_equal(b, a)
245
245
246 def test_scatter_gather_lazy(self):
246 def test_scatter_gather_lazy(self):
247 """scatter/gather with targets='all'"""
247 """scatter/gather with targets='all'"""
248 view = self.client.direct_view(targets='all')
248 view = self.client.direct_view(targets='all')
249 x = range(64)
249 x = range(64)
250 view.scatter('x', x)
250 view.scatter('x', x)
251 gathered = view.gather('x', block=True)
251 gathered = view.gather('x', block=True)
252 self.assertEquals(gathered, x)
252 self.assertEqual(gathered, x)
253
253
254
254
255 @dec.known_failure_py3
255 @dec.known_failure_py3
256 @skip_without('numpy')
256 @skip_without('numpy')
257 def test_push_numpy_nocopy(self):
257 def test_push_numpy_nocopy(self):
258 import numpy
258 import numpy
259 view = self.client[:]
259 view = self.client[:]
260 a = numpy.arange(64)
260 a = numpy.arange(64)
261 view['A'] = a
261 view['A'] = a
262 @interactive
262 @interactive
263 def check_writeable(x):
263 def check_writeable(x):
264 return x.flags.writeable
264 return x.flags.writeable
265
265
266 for flag in view.apply_sync(check_writeable, pmod.Reference('A')):
266 for flag in view.apply_sync(check_writeable, pmod.Reference('A')):
267 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
267 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
268
268
269 view.push(dict(B=a))
269 view.push(dict(B=a))
270 for flag in view.apply_sync(check_writeable, pmod.Reference('B')):
270 for flag in view.apply_sync(check_writeable, pmod.Reference('B')):
271 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
271 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
272
272
273 @skip_without('numpy')
273 @skip_without('numpy')
274 def test_apply_numpy(self):
274 def test_apply_numpy(self):
275 """view.apply(f, ndarray)"""
275 """view.apply(f, ndarray)"""
276 import numpy
276 import numpy
277 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
277 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
278
278
279 A = numpy.random.random((100,100))
279 A = numpy.random.random((100,100))
280 view = self.client[-1]
280 view = self.client[-1]
281 for dt in [ 'int32', 'uint8', 'float32', 'float64' ]:
281 for dt in [ 'int32', 'uint8', 'float32', 'float64' ]:
282 B = A.astype(dt)
282 B = A.astype(dt)
283 C = view.apply_sync(lambda x:x, B)
283 C = view.apply_sync(lambda x:x, B)
284 assert_array_equal(B,C)
284 assert_array_equal(B,C)
285
285
286 @skip_without('numpy')
286 @skip_without('numpy')
287 def test_push_pull_recarray(self):
287 def test_push_pull_recarray(self):
288 """push/pull recarrays"""
288 """push/pull recarrays"""
289 import numpy
289 import numpy
290 from numpy.testing.utils import assert_array_equal
290 from numpy.testing.utils import assert_array_equal
291
291
292 view = self.client[-1]
292 view = self.client[-1]
293
293
294 R = numpy.array([
294 R = numpy.array([
295 (1, 'hi', 0.),
295 (1, 'hi', 0.),
296 (2**30, 'there', 2.5),
296 (2**30, 'there', 2.5),
297 (-99999, 'world', -12345.6789),
297 (-99999, 'world', -12345.6789),
298 ], [('n', int), ('s', '|S10'), ('f', float)])
298 ], [('n', int), ('s', '|S10'), ('f', float)])
299
299
300 view['RR'] = R
300 view['RR'] = R
301 R2 = view['RR']
301 R2 = view['RR']
302
302
303 r_dtype, r_shape = view.apply_sync(interactive(lambda : (RR.dtype, RR.shape)))
303 r_dtype, r_shape = view.apply_sync(interactive(lambda : (RR.dtype, RR.shape)))
304 self.assertEquals(r_dtype, R.dtype)
304 self.assertEqual(r_dtype, R.dtype)
305 self.assertEquals(r_shape, R.shape)
305 self.assertEqual(r_shape, R.shape)
306 self.assertEquals(R2.dtype, R.dtype)
306 self.assertEqual(R2.dtype, R.dtype)
307 self.assertEquals(R2.shape, R.shape)
307 self.assertEqual(R2.shape, R.shape)
308 assert_array_equal(R2, R)
308 assert_array_equal(R2, R)
309
309
310 def test_map(self):
310 def test_map(self):
311 view = self.client[:]
311 view = self.client[:]
312 def f(x):
312 def f(x):
313 return x**2
313 return x**2
314 data = range(16)
314 data = range(16)
315 r = view.map_sync(f, data)
315 r = view.map_sync(f, data)
316 self.assertEquals(r, map(f, data))
316 self.assertEqual(r, map(f, data))
317
317
318 def test_map_iterable(self):
318 def test_map_iterable(self):
319 """test map on iterables (direct)"""
319 """test map on iterables (direct)"""
320 view = self.client[:]
320 view = self.client[:]
321 # 101 is prime, so it won't be evenly distributed
321 # 101 is prime, so it won't be evenly distributed
322 arr = range(101)
322 arr = range(101)
323 # ensure it will be an iterator, even in Python 3
323 # ensure it will be an iterator, even in Python 3
324 it = iter(arr)
324 it = iter(arr)
325 r = view.map_sync(lambda x:x, arr)
325 r = view.map_sync(lambda x:x, arr)
326 self.assertEquals(r, list(arr))
326 self.assertEqual(r, list(arr))
327
327
328 def test_scatterGatherNonblocking(self):
328 def test_scatterGatherNonblocking(self):
329 data = range(16)
329 data = range(16)
330 view = self.client[:]
330 view = self.client[:]
331 view.scatter('a', data, block=False)
331 view.scatter('a', data, block=False)
332 ar = view.gather('a', block=False)
332 ar = view.gather('a', block=False)
333 self.assertEquals(ar.get(), data)
333 self.assertEqual(ar.get(), data)
334
334
335 @skip_without('numpy')
335 @skip_without('numpy')
336 def test_scatter_gather_numpy_nonblocking(self):
336 def test_scatter_gather_numpy_nonblocking(self):
337 import numpy
337 import numpy
338 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
338 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
339 a = numpy.arange(64)
339 a = numpy.arange(64)
340 view = self.client[:]
340 view = self.client[:]
341 ar = view.scatter('a', a, block=False)
341 ar = view.scatter('a', a, block=False)
342 self.assertTrue(isinstance(ar, AsyncResult))
342 self.assertTrue(isinstance(ar, AsyncResult))
343 amr = view.gather('a', block=False)
343 amr = view.gather('a', block=False)
344 self.assertTrue(isinstance(amr, AsyncMapResult))
344 self.assertTrue(isinstance(amr, AsyncMapResult))
345 assert_array_equal(amr.get(), a)
345 assert_array_equal(amr.get(), a)
346
346
347 def test_execute(self):
347 def test_execute(self):
348 view = self.client[:]
348 view = self.client[:]
349 # self.client.debug=True
349 # self.client.debug=True
350 execute = view.execute
350 execute = view.execute
351 ar = execute('c=30', block=False)
351 ar = execute('c=30', block=False)
352 self.assertTrue(isinstance(ar, AsyncResult))
352 self.assertTrue(isinstance(ar, AsyncResult))
353 ar = execute('d=[0,1,2]', block=False)
353 ar = execute('d=[0,1,2]', block=False)
354 self.client.wait(ar, 1)
354 self.client.wait(ar, 1)
355 self.assertEquals(len(ar.get()), len(self.client))
355 self.assertEqual(len(ar.get()), len(self.client))
356 for c in view['c']:
356 for c in view['c']:
357 self.assertEquals(c, 30)
357 self.assertEqual(c, 30)
358
358
359 def test_abort(self):
359 def test_abort(self):
360 view = self.client[-1]
360 view = self.client[-1]
361 ar = view.execute('import time; time.sleep(1)', block=False)
361 ar = view.execute('import time; time.sleep(1)', block=False)
362 ar2 = view.apply_async(lambda : 2)
362 ar2 = view.apply_async(lambda : 2)
363 ar3 = view.apply_async(lambda : 3)
363 ar3 = view.apply_async(lambda : 3)
364 view.abort(ar2)
364 view.abort(ar2)
365 view.abort(ar3.msg_ids)
365 view.abort(ar3.msg_ids)
366 self.assertRaises(error.TaskAborted, ar2.get)
366 self.assertRaises(error.TaskAborted, ar2.get)
367 self.assertRaises(error.TaskAborted, ar3.get)
367 self.assertRaises(error.TaskAborted, ar3.get)
368
368
369 def test_abort_all(self):
369 def test_abort_all(self):
370 """view.abort() aborts all outstanding tasks"""
370 """view.abort() aborts all outstanding tasks"""
371 view = self.client[-1]
371 view = self.client[-1]
372 ars = [ view.apply_async(time.sleep, 0.25) for i in range(10) ]
372 ars = [ view.apply_async(time.sleep, 0.25) for i in range(10) ]
373 view.abort()
373 view.abort()
374 view.wait(timeout=5)
374 view.wait(timeout=5)
375 for ar in ars[5:]:
375 for ar in ars[5:]:
376 self.assertRaises(error.TaskAborted, ar.get)
376 self.assertRaises(error.TaskAborted, ar.get)
377
377
378 def test_temp_flags(self):
378 def test_temp_flags(self):
379 view = self.client[-1]
379 view = self.client[-1]
380 view.block=True
380 view.block=True
381 with view.temp_flags(block=False):
381 with view.temp_flags(block=False):
382 self.assertFalse(view.block)
382 self.assertFalse(view.block)
383 self.assertTrue(view.block)
383 self.assertTrue(view.block)
384
384
385 @dec.known_failure_py3
385 @dec.known_failure_py3
386 def test_importer(self):
386 def test_importer(self):
387 view = self.client[-1]
387 view = self.client[-1]
388 view.clear(block=True)
388 view.clear(block=True)
389 with view.importer:
389 with view.importer:
390 import re
390 import re
391
391
392 @interactive
392 @interactive
393 def findall(pat, s):
393 def findall(pat, s):
394 # this globals() step isn't necessary in real code
394 # this globals() step isn't necessary in real code
395 # only to prevent a closure in the test
395 # only to prevent a closure in the test
396 re = globals()['re']
396 re = globals()['re']
397 return re.findall(pat, s)
397 return re.findall(pat, s)
398
398
399 self.assertEquals(view.apply_sync(findall, '\w+', 'hello world'), 'hello world'.split())
399 self.assertEqual(view.apply_sync(findall, '\w+', 'hello world'), 'hello world'.split())
400
400
401 def test_unicode_execute(self):
401 def test_unicode_execute(self):
402 """test executing unicode strings"""
402 """test executing unicode strings"""
403 v = self.client[-1]
403 v = self.client[-1]
404 v.block=True
404 v.block=True
405 if sys.version_info[0] >= 3:
405 if sys.version_info[0] >= 3:
406 code="a='é'"
406 code="a='é'"
407 else:
407 else:
408 code=u"a=u'é'"
408 code=u"a=u'é'"
409 v.execute(code)
409 v.execute(code)
410 self.assertEquals(v['a'], u'é')
410 self.assertEqual(v['a'], u'é')
411
411
412 def test_unicode_apply_result(self):
412 def test_unicode_apply_result(self):
413 """test unicode apply results"""
413 """test unicode apply results"""
414 v = self.client[-1]
414 v = self.client[-1]
415 r = v.apply_sync(lambda : u'é')
415 r = v.apply_sync(lambda : u'é')
416 self.assertEquals(r, u'é')
416 self.assertEqual(r, u'é')
417
417
418 def test_unicode_apply_arg(self):
418 def test_unicode_apply_arg(self):
419 """test passing unicode arguments to apply"""
419 """test passing unicode arguments to apply"""
420 v = self.client[-1]
420 v = self.client[-1]
421
421
422 @interactive
422 @interactive
423 def check_unicode(a, check):
423 def check_unicode(a, check):
424 assert isinstance(a, unicode), "%r is not unicode"%a
424 assert isinstance(a, unicode), "%r is not unicode"%a
425 assert isinstance(check, bytes), "%r is not bytes"%check
425 assert isinstance(check, bytes), "%r is not bytes"%check
426 assert a.encode('utf8') == check, "%s != %s"%(a,check)
426 assert a.encode('utf8') == check, "%s != %s"%(a,check)
427
427
428 for s in [ u'é', u'ßø®∫',u'asdf' ]:
428 for s in [ u'é', u'ßø®∫',u'asdf' ]:
429 try:
429 try:
430 v.apply_sync(check_unicode, s, s.encode('utf8'))
430 v.apply_sync(check_unicode, s, s.encode('utf8'))
431 except error.RemoteError as e:
431 except error.RemoteError as e:
432 if e.ename == 'AssertionError':
432 if e.ename == 'AssertionError':
433 self.fail(e.evalue)
433 self.fail(e.evalue)
434 else:
434 else:
435 raise e
435 raise e
436
436
437 def test_map_reference(self):
437 def test_map_reference(self):
438 """view.map(<Reference>, *seqs) should work"""
438 """view.map(<Reference>, *seqs) should work"""
439 v = self.client[:]
439 v = self.client[:]
440 v.scatter('n', self.client.ids, flatten=True)
440 v.scatter('n', self.client.ids, flatten=True)
441 v.execute("f = lambda x,y: x*y")
441 v.execute("f = lambda x,y: x*y")
442 rf = pmod.Reference('f')
442 rf = pmod.Reference('f')
443 nlist = list(range(10))
443 nlist = list(range(10))
444 mlist = nlist[::-1]
444 mlist = nlist[::-1]
445 expected = [ m*n for m,n in zip(mlist, nlist) ]
445 expected = [ m*n for m,n in zip(mlist, nlist) ]
446 result = v.map_sync(rf, mlist, nlist)
446 result = v.map_sync(rf, mlist, nlist)
447 self.assertEquals(result, expected)
447 self.assertEqual(result, expected)
448
448
449 def test_apply_reference(self):
449 def test_apply_reference(self):
450 """view.apply(<Reference>, *args) should work"""
450 """view.apply(<Reference>, *args) should work"""
451 v = self.client[:]
451 v = self.client[:]
452 v.scatter('n', self.client.ids, flatten=True)
452 v.scatter('n', self.client.ids, flatten=True)
453 v.execute("f = lambda x: n*x")
453 v.execute("f = lambda x: n*x")
454 rf = pmod.Reference('f')
454 rf = pmod.Reference('f')
455 result = v.apply_sync(rf, 5)
455 result = v.apply_sync(rf, 5)
456 expected = [ 5*id for id in self.client.ids ]
456 expected = [ 5*id for id in self.client.ids ]
457 self.assertEquals(result, expected)
457 self.assertEqual(result, expected)
458
458
459 def test_eval_reference(self):
459 def test_eval_reference(self):
460 v = self.client[self.client.ids[0]]
460 v = self.client[self.client.ids[0]]
461 v['g'] = range(5)
461 v['g'] = range(5)
462 rg = pmod.Reference('g[0]')
462 rg = pmod.Reference('g[0]')
463 echo = lambda x:x
463 echo = lambda x:x
464 self.assertEquals(v.apply_sync(echo, rg), 0)
464 self.assertEqual(v.apply_sync(echo, rg), 0)
465
465
466 def test_reference_nameerror(self):
466 def test_reference_nameerror(self):
467 v = self.client[self.client.ids[0]]
467 v = self.client[self.client.ids[0]]
468 r = pmod.Reference('elvis_has_left')
468 r = pmod.Reference('elvis_has_left')
469 echo = lambda x:x
469 echo = lambda x:x
470 self.assertRaisesRemote(NameError, v.apply_sync, echo, r)
470 self.assertRaisesRemote(NameError, v.apply_sync, echo, r)
471
471
472 def test_single_engine_map(self):
472 def test_single_engine_map(self):
473 e0 = self.client[self.client.ids[0]]
473 e0 = self.client[self.client.ids[0]]
474 r = range(5)
474 r = range(5)
475 check = [ -1*i for i in r ]
475 check = [ -1*i for i in r ]
476 result = e0.map_sync(lambda x: -1*x, r)
476 result = e0.map_sync(lambda x: -1*x, r)
477 self.assertEquals(result, check)
477 self.assertEqual(result, check)
478
478
479 def test_len(self):
479 def test_len(self):
480 """len(view) makes sense"""
480 """len(view) makes sense"""
481 e0 = self.client[self.client.ids[0]]
481 e0 = self.client[self.client.ids[0]]
482 yield self.assertEquals(len(e0), 1)
482 yield self.assertEqual(len(e0), 1)
483 v = self.client[:]
483 v = self.client[:]
484 yield self.assertEquals(len(v), len(self.client.ids))
484 yield self.assertEqual(len(v), len(self.client.ids))
485 v = self.client.direct_view('all')
485 v = self.client.direct_view('all')
486 yield self.assertEquals(len(v), len(self.client.ids))
486 yield self.assertEqual(len(v), len(self.client.ids))
487 v = self.client[:2]
487 v = self.client[:2]
488 yield self.assertEquals(len(v), 2)
488 yield self.assertEqual(len(v), 2)
489 v = self.client[:1]
489 v = self.client[:1]
490 yield self.assertEquals(len(v), 1)
490 yield self.assertEqual(len(v), 1)
491 v = self.client.load_balanced_view()
491 v = self.client.load_balanced_view()
492 yield self.assertEquals(len(v), len(self.client.ids))
492 yield self.assertEqual(len(v), len(self.client.ids))
493 # parametric tests seem to require manual closing?
493 # parametric tests seem to require manual closing?
494 self.client.close()
494 self.client.close()
495
495
496
496
497 # begin execute tests
497 # begin execute tests
498
498
499 def test_execute_reply(self):
499 def test_execute_reply(self):
500 e0 = self.client[self.client.ids[0]]
500 e0 = self.client[self.client.ids[0]]
501 e0.block = True
501 e0.block = True
502 ar = e0.execute("5", silent=False)
502 ar = e0.execute("5", silent=False)
503 er = ar.get()
503 er = ar.get()
504 self.assertEquals(str(er), "<ExecuteReply[%i]: 5>" % er.execution_count)
504 self.assertEqual(str(er), "<ExecuteReply[%i]: 5>" % er.execution_count)
505 self.assertEquals(er.pyout['data']['text/plain'], '5')
505 self.assertEqual(er.pyout['data']['text/plain'], '5')
506
506
507 def test_execute_reply_stdout(self):
507 def test_execute_reply_stdout(self):
508 e0 = self.client[self.client.ids[0]]
508 e0 = self.client[self.client.ids[0]]
509 e0.block = True
509 e0.block = True
510 ar = e0.execute("print (5)", silent=False)
510 ar = e0.execute("print (5)", silent=False)
511 er = ar.get()
511 er = ar.get()
512 self.assertEquals(er.stdout.strip(), '5')
512 self.assertEqual(er.stdout.strip(), '5')
513
513
514 def test_execute_pyout(self):
514 def test_execute_pyout(self):
515 """execute triggers pyout with silent=False"""
515 """execute triggers pyout with silent=False"""
516 view = self.client[:]
516 view = self.client[:]
517 ar = view.execute("5", silent=False, block=True)
517 ar = view.execute("5", silent=False, block=True)
518
518
519 expected = [{'text/plain' : '5'}] * len(view)
519 expected = [{'text/plain' : '5'}] * len(view)
520 mimes = [ out['data'] for out in ar.pyout ]
520 mimes = [ out['data'] for out in ar.pyout ]
521 self.assertEquals(mimes, expected)
521 self.assertEqual(mimes, expected)
522
522
523 def test_execute_silent(self):
523 def test_execute_silent(self):
524 """execute does not trigger pyout with silent=True"""
524 """execute does not trigger pyout with silent=True"""
525 view = self.client[:]
525 view = self.client[:]
526 ar = view.execute("5", block=True)
526 ar = view.execute("5", block=True)
527 expected = [None] * len(view)
527 expected = [None] * len(view)
528 self.assertEquals(ar.pyout, expected)
528 self.assertEqual(ar.pyout, expected)
529
529
530 def test_execute_magic(self):
530 def test_execute_magic(self):
531 """execute accepts IPython commands"""
531 """execute accepts IPython commands"""
532 view = self.client[:]
532 view = self.client[:]
533 view.execute("a = 5")
533 view.execute("a = 5")
534 ar = view.execute("%whos", block=True)
534 ar = view.execute("%whos", block=True)
535 # this will raise, if that failed
535 # this will raise, if that failed
536 ar.get(5)
536 ar.get(5)
537 for stdout in ar.stdout:
537 for stdout in ar.stdout:
538 lines = stdout.splitlines()
538 lines = stdout.splitlines()
539 self.assertEquals(lines[0].split(), ['Variable', 'Type', 'Data/Info'])
539 self.assertEqual(lines[0].split(), ['Variable', 'Type', 'Data/Info'])
540 found = False
540 found = False
541 for line in lines[2:]:
541 for line in lines[2:]:
542 split = line.split()
542 split = line.split()
543 if split == ['a', 'int', '5']:
543 if split == ['a', 'int', '5']:
544 found = True
544 found = True
545 break
545 break
546 self.assertTrue(found, "whos output wrong: %s" % stdout)
546 self.assertTrue(found, "whos output wrong: %s" % stdout)
547
547
548 def test_execute_displaypub(self):
548 def test_execute_displaypub(self):
549 """execute tracks display_pub output"""
549 """execute tracks display_pub output"""
550 view = self.client[:]
550 view = self.client[:]
551 view.execute("from IPython.core.display import *")
551 view.execute("from IPython.core.display import *")
552 ar = view.execute("[ display(i) for i in range(5) ]", block=True)
552 ar = view.execute("[ display(i) for i in range(5) ]", block=True)
553
553
554 expected = [ {u'text/plain' : unicode(j)} for j in range(5) ]
554 expected = [ {u'text/plain' : unicode(j)} for j in range(5) ]
555 for outputs in ar.outputs:
555 for outputs in ar.outputs:
556 mimes = [ out['data'] for out in outputs ]
556 mimes = [ out['data'] for out in outputs ]
557 self.assertEquals(mimes, expected)
557 self.assertEqual(mimes, expected)
558
558
559 def test_apply_displaypub(self):
559 def test_apply_displaypub(self):
560 """apply tracks display_pub output"""
560 """apply tracks display_pub output"""
561 view = self.client[:]
561 view = self.client[:]
562 view.execute("from IPython.core.display import *")
562 view.execute("from IPython.core.display import *")
563
563
564 @interactive
564 @interactive
565 def publish():
565 def publish():
566 [ display(i) for i in range(5) ]
566 [ display(i) for i in range(5) ]
567
567
568 ar = view.apply_async(publish)
568 ar = view.apply_async(publish)
569 ar.get(5)
569 ar.get(5)
570 expected = [ {u'text/plain' : unicode(j)} for j in range(5) ]
570 expected = [ {u'text/plain' : unicode(j)} for j in range(5) ]
571 for outputs in ar.outputs:
571 for outputs in ar.outputs:
572 mimes = [ out['data'] for out in outputs ]
572 mimes = [ out['data'] for out in outputs ]
573 self.assertEquals(mimes, expected)
573 self.assertEqual(mimes, expected)
574
574
575 def test_execute_raises(self):
575 def test_execute_raises(self):
576 """exceptions in execute requests raise appropriately"""
576 """exceptions in execute requests raise appropriately"""
577 view = self.client[-1]
577 view = self.client[-1]
578 ar = view.execute("1/0")
578 ar = view.execute("1/0")
579 self.assertRaisesRemote(ZeroDivisionError, ar.get, 2)
579 self.assertRaisesRemote(ZeroDivisionError, ar.get, 2)
580
580
581 @dec.skipif_not_matplotlib
581 @dec.skipif_not_matplotlib
582 def test_magic_pylab(self):
582 def test_magic_pylab(self):
583 """%pylab works on engines"""
583 """%pylab works on engines"""
584 view = self.client[-1]
584 view = self.client[-1]
585 ar = view.execute("%pylab inline")
585 ar = view.execute("%pylab inline")
586 # at least check if this raised:
586 # at least check if this raised:
587 reply = ar.get(5)
587 reply = ar.get(5)
588 # include imports, in case user config
588 # include imports, in case user config
589 ar = view.execute("plot(rand(100))", silent=False)
589 ar = view.execute("plot(rand(100))", silent=False)
590 reply = ar.get(5)
590 reply = ar.get(5)
591 self.assertEquals(len(reply.outputs), 1)
591 self.assertEqual(len(reply.outputs), 1)
592 output = reply.outputs[0]
592 output = reply.outputs[0]
593 self.assertTrue("data" in output)
593 self.assertTrue("data" in output)
594 data = output['data']
594 data = output['data']
595 self.assertTrue("image/png" in data)
595 self.assertTrue("image/png" in data)
596
596
597
597
@@ -1,131 +1,131 b''
1 # encoding: utf-8
1 # encoding: utf-8
2 """
2 """
3 Tests for platutils.py
3 Tests for platutils.py
4 """
4 """
5
5
6 #-----------------------------------------------------------------------------
6 #-----------------------------------------------------------------------------
7 # Copyright (C) 2008-2011 The IPython Development Team
7 # Copyright (C) 2008-2011 The IPython Development Team
8 #
8 #
9 # Distributed under the terms of the BSD License. The full license is in
9 # Distributed under the terms of the BSD License. The full license is in
10 # the file COPYING, distributed as part of this software.
10 # the file COPYING, distributed as part of this software.
11 #-----------------------------------------------------------------------------
11 #-----------------------------------------------------------------------------
12
12
13 #-----------------------------------------------------------------------------
13 #-----------------------------------------------------------------------------
14 # Imports
14 # Imports
15 #-----------------------------------------------------------------------------
15 #-----------------------------------------------------------------------------
16
16
17 import sys
17 import sys
18 from unittest import TestCase
18 from unittest import TestCase
19
19
20 import nose.tools as nt
20 import nose.tools as nt
21
21
22 from IPython.utils.process import (find_cmd, FindCmdError, arg_split,
22 from IPython.utils.process import (find_cmd, FindCmdError, arg_split,
23 system, getoutput, getoutputerror)
23 system, getoutput, getoutputerror)
24 from IPython.testing import decorators as dec
24 from IPython.testing import decorators as dec
25 from IPython.testing import tools as tt
25 from IPython.testing import tools as tt
26
26
27 #-----------------------------------------------------------------------------
27 #-----------------------------------------------------------------------------
28 # Tests
28 # Tests
29 #-----------------------------------------------------------------------------
29 #-----------------------------------------------------------------------------
30
30
31 def test_find_cmd_python():
31 def test_find_cmd_python():
32 """Make sure we find sys.exectable for python."""
32 """Make sure we find sys.exectable for python."""
33 nt.assert_equals(find_cmd('python'), sys.executable)
33 nt.assert_equals(find_cmd('python'), sys.executable)
34
34
35
35
36 @dec.skip_win32
36 @dec.skip_win32
37 def test_find_cmd_ls():
37 def test_find_cmd_ls():
38 """Make sure we can find the full path to ls."""
38 """Make sure we can find the full path to ls."""
39 path = find_cmd('ls')
39 path = find_cmd('ls')
40 nt.assert_true(path.endswith('ls'))
40 nt.assert_true(path.endswith('ls'))
41
41
42
42
43 def has_pywin32():
43 def has_pywin32():
44 try:
44 try:
45 import win32api
45 import win32api
46 except ImportError:
46 except ImportError:
47 return False
47 return False
48 return True
48 return True
49
49
50
50
51 @dec.onlyif(has_pywin32, "This test requires win32api to run")
51 @dec.onlyif(has_pywin32, "This test requires win32api to run")
52 def test_find_cmd_pythonw():
52 def test_find_cmd_pythonw():
53 """Try to find pythonw on Windows."""
53 """Try to find pythonw on Windows."""
54 path = find_cmd('pythonw')
54 path = find_cmd('pythonw')
55 nt.assert_true(path.endswith('pythonw.exe'))
55 nt.assert_true(path.endswith('pythonw.exe'))
56
56
57
57
58 @dec.onlyif(lambda : sys.platform != 'win32' or has_pywin32(),
58 @dec.onlyif(lambda : sys.platform != 'win32' or has_pywin32(),
59 "This test runs on posix or in win32 with win32api installed")
59 "This test runs on posix or in win32 with win32api installed")
60 def test_find_cmd_fail():
60 def test_find_cmd_fail():
61 """Make sure that FindCmdError is raised if we can't find the cmd."""
61 """Make sure that FindCmdError is raised if we can't find the cmd."""
62 nt.assert_raises(FindCmdError,find_cmd,'asdfasdf')
62 nt.assert_raises(FindCmdError,find_cmd,'asdfasdf')
63
63
64
64
65 @dec.skip_win32
65 @dec.skip_win32
66 def test_arg_split():
66 def test_arg_split():
67 """Ensure that argument lines are correctly split like in a shell."""
67 """Ensure that argument lines are correctly split like in a shell."""
68 tests = [['hi', ['hi']],
68 tests = [['hi', ['hi']],
69 [u'hi', [u'hi']],
69 [u'hi', [u'hi']],
70 ['hello there', ['hello', 'there']],
70 ['hello there', ['hello', 'there']],
71 # \u01ce == \N{LATIN SMALL LETTER A WITH CARON}
71 # \u01ce == \N{LATIN SMALL LETTER A WITH CARON}
72 # Do not use \N because the tests crash with syntax error in
72 # Do not use \N because the tests crash with syntax error in
73 # some cases, for example windows python2.6.
73 # some cases, for example windows python2.6.
74 [u'h\u01cello', [u'h\u01cello']],
74 [u'h\u01cello', [u'h\u01cello']],
75 ['something "with quotes"', ['something', '"with quotes"']],
75 ['something "with quotes"', ['something', '"with quotes"']],
76 ]
76 ]
77 for argstr, argv in tests:
77 for argstr, argv in tests:
78 nt.assert_equal(arg_split(argstr), argv)
78 nt.assert_equal(arg_split(argstr), argv)
79
79
80 @dec.skip_if_not_win32
80 @dec.skip_if_not_win32
81 def test_arg_split_win32():
81 def test_arg_split_win32():
82 """Ensure that argument lines are correctly split like in a shell."""
82 """Ensure that argument lines are correctly split like in a shell."""
83 tests = [['hi', ['hi']],
83 tests = [['hi', ['hi']],
84 [u'hi', [u'hi']],
84 [u'hi', [u'hi']],
85 ['hello there', ['hello', 'there']],
85 ['hello there', ['hello', 'there']],
86 [u'h\u01cello', [u'h\u01cello']],
86 [u'h\u01cello', [u'h\u01cello']],
87 ['something "with quotes"', ['something', 'with quotes']],
87 ['something "with quotes"', ['something', 'with quotes']],
88 ]
88 ]
89 for argstr, argv in tests:
89 for argstr, argv in tests:
90 nt.assert_equal(arg_split(argstr), argv)
90 nt.assert_equal(arg_split(argstr), argv)
91
91
92
92
93 class SubProcessTestCase(TestCase, tt.TempFileMixin):
93 class SubProcessTestCase(TestCase, tt.TempFileMixin):
94 def setUp(self):
94 def setUp(self):
95 """Make a valid python temp file."""
95 """Make a valid python temp file."""
96 lines = ["from __future__ import print_function",
96 lines = ["from __future__ import print_function",
97 "import sys",
97 "import sys",
98 "print('on stdout', end='', file=sys.stdout)",
98 "print('on stdout', end='', file=sys.stdout)",
99 "print('on stderr', end='', file=sys.stderr)",
99 "print('on stderr', end='', file=sys.stderr)",
100 "sys.stdout.flush()",
100 "sys.stdout.flush()",
101 "sys.stderr.flush()"]
101 "sys.stderr.flush()"]
102 self.mktmp('\n'.join(lines))
102 self.mktmp('\n'.join(lines))
103
103
104 def test_system(self):
104 def test_system(self):
105 status = system('python "%s"' % self.fname)
105 status = system('python "%s"' % self.fname)
106 self.assertEquals(status, 0)
106 self.assertEqual(status, 0)
107
107
108 def test_system_quotes(self):
108 def test_system_quotes(self):
109 status = system('python -c "import sys"')
109 status = system('python -c "import sys"')
110 self.assertEquals(status, 0)
110 self.assertEqual(status, 0)
111
111
112 def test_getoutput(self):
112 def test_getoutput(self):
113 out = getoutput('python "%s"' % self.fname)
113 out = getoutput('python "%s"' % self.fname)
114 self.assertEquals(out, 'on stdout')
114 self.assertEqual(out, 'on stdout')
115
115
116 def test_getoutput_quoted(self):
116 def test_getoutput_quoted(self):
117 out = getoutput('python -c "print (1)"')
117 out = getoutput('python -c "print (1)"')
118 self.assertEquals(out.strip(), '1')
118 self.assertEqual(out.strip(), '1')
119
119
120 #Invalid quoting on windows
120 #Invalid quoting on windows
121 @dec.skip_win32
121 @dec.skip_win32
122 def test_getoutput_quoted2(self):
122 def test_getoutput_quoted2(self):
123 out = getoutput("python -c 'print (1)'")
123 out = getoutput("python -c 'print (1)'")
124 self.assertEquals(out.strip(), '1')
124 self.assertEqual(out.strip(), '1')
125 out = getoutput("python -c 'print (\"1\")'")
125 out = getoutput("python -c 'print (\"1\")'")
126 self.assertEquals(out.strip(), '1')
126 self.assertEqual(out.strip(), '1')
127
127
128 def test_getoutput(self):
128 def test_getoutput(self):
129 out, err = getoutputerror('python "%s"' % self.fname)
129 out, err = getoutputerror('python "%s"' % self.fname)
130 self.assertEquals(out, 'on stdout')
130 self.assertEqual(out, 'on stdout')
131 self.assertEquals(err, 'on stderr')
131 self.assertEqual(err, 'on stderr')
@@ -1,908 +1,908 b''
1 # encoding: utf-8
1 # encoding: utf-8
2 """
2 """
3 Tests for IPython.utils.traitlets.
3 Tests for IPython.utils.traitlets.
4
4
5 Authors:
5 Authors:
6
6
7 * Brian Granger
7 * Brian Granger
8 * Enthought, Inc. Some of the code in this file comes from enthought.traits
8 * Enthought, Inc. Some of the code in this file comes from enthought.traits
9 and is licensed under the BSD license. Also, many of the ideas also come
9 and is licensed under the BSD license. Also, many of the ideas also come
10 from enthought.traits even though our implementation is very different.
10 from enthought.traits even though our implementation is very different.
11 """
11 """
12
12
13 #-----------------------------------------------------------------------------
13 #-----------------------------------------------------------------------------
14 # Copyright (C) 2008-2011 The IPython Development Team
14 # Copyright (C) 2008-2011 The IPython Development Team
15 #
15 #
16 # Distributed under the terms of the BSD License. The full license is in
16 # Distributed under the terms of the BSD License. The full license is in
17 # the file COPYING, distributed as part of this software.
17 # the file COPYING, distributed as part of this software.
18 #-----------------------------------------------------------------------------
18 #-----------------------------------------------------------------------------
19
19
20 #-----------------------------------------------------------------------------
20 #-----------------------------------------------------------------------------
21 # Imports
21 # Imports
22 #-----------------------------------------------------------------------------
22 #-----------------------------------------------------------------------------
23
23
24 import re
24 import re
25 import sys
25 import sys
26 from unittest import TestCase
26 from unittest import TestCase
27
27
28 from nose import SkipTest
28 from nose import SkipTest
29
29
30 from IPython.utils.traitlets import (
30 from IPython.utils.traitlets import (
31 HasTraits, MetaHasTraits, TraitType, Any, CBytes,
31 HasTraits, MetaHasTraits, TraitType, Any, CBytes,
32 Int, Long, Integer, Float, Complex, Bytes, Unicode, TraitError,
32 Int, Long, Integer, Float, Complex, Bytes, Unicode, TraitError,
33 Undefined, Type, This, Instance, TCPAddress, List, Tuple,
33 Undefined, Type, This, Instance, TCPAddress, List, Tuple,
34 ObjectName, DottedObjectName, CRegExp
34 ObjectName, DottedObjectName, CRegExp
35 )
35 )
36 from IPython.utils import py3compat
36 from IPython.utils import py3compat
37 from IPython.testing.decorators import skipif
37 from IPython.testing.decorators import skipif
38
38
39 #-----------------------------------------------------------------------------
39 #-----------------------------------------------------------------------------
40 # Helper classes for testing
40 # Helper classes for testing
41 #-----------------------------------------------------------------------------
41 #-----------------------------------------------------------------------------
42
42
43
43
44 class HasTraitsStub(HasTraits):
44 class HasTraitsStub(HasTraits):
45
45
46 def _notify_trait(self, name, old, new):
46 def _notify_trait(self, name, old, new):
47 self._notify_name = name
47 self._notify_name = name
48 self._notify_old = old
48 self._notify_old = old
49 self._notify_new = new
49 self._notify_new = new
50
50
51
51
52 #-----------------------------------------------------------------------------
52 #-----------------------------------------------------------------------------
53 # Test classes
53 # Test classes
54 #-----------------------------------------------------------------------------
54 #-----------------------------------------------------------------------------
55
55
56
56
57 class TestTraitType(TestCase):
57 class TestTraitType(TestCase):
58
58
59 def test_get_undefined(self):
59 def test_get_undefined(self):
60 class A(HasTraits):
60 class A(HasTraits):
61 a = TraitType
61 a = TraitType
62 a = A()
62 a = A()
63 self.assertEquals(a.a, Undefined)
63 self.assertEqual(a.a, Undefined)
64
64
65 def test_set(self):
65 def test_set(self):
66 class A(HasTraitsStub):
66 class A(HasTraitsStub):
67 a = TraitType
67 a = TraitType
68
68
69 a = A()
69 a = A()
70 a.a = 10
70 a.a = 10
71 self.assertEquals(a.a, 10)
71 self.assertEqual(a.a, 10)
72 self.assertEquals(a._notify_name, 'a')
72 self.assertEqual(a._notify_name, 'a')
73 self.assertEquals(a._notify_old, Undefined)
73 self.assertEqual(a._notify_old, Undefined)
74 self.assertEquals(a._notify_new, 10)
74 self.assertEqual(a._notify_new, 10)
75
75
76 def test_validate(self):
76 def test_validate(self):
77 class MyTT(TraitType):
77 class MyTT(TraitType):
78 def validate(self, inst, value):
78 def validate(self, inst, value):
79 return -1
79 return -1
80 class A(HasTraitsStub):
80 class A(HasTraitsStub):
81 tt = MyTT
81 tt = MyTT
82
82
83 a = A()
83 a = A()
84 a.tt = 10
84 a.tt = 10
85 self.assertEquals(a.tt, -1)
85 self.assertEqual(a.tt, -1)
86
86
87 def test_default_validate(self):
87 def test_default_validate(self):
88 class MyIntTT(TraitType):
88 class MyIntTT(TraitType):
89 def validate(self, obj, value):
89 def validate(self, obj, value):
90 if isinstance(value, int):
90 if isinstance(value, int):
91 return value
91 return value
92 self.error(obj, value)
92 self.error(obj, value)
93 class A(HasTraits):
93 class A(HasTraits):
94 tt = MyIntTT(10)
94 tt = MyIntTT(10)
95 a = A()
95 a = A()
96 self.assertEquals(a.tt, 10)
96 self.assertEqual(a.tt, 10)
97
97
98 # Defaults are validated when the HasTraits is instantiated
98 # Defaults are validated when the HasTraits is instantiated
99 class B(HasTraits):
99 class B(HasTraits):
100 tt = MyIntTT('bad default')
100 tt = MyIntTT('bad default')
101 self.assertRaises(TraitError, B)
101 self.assertRaises(TraitError, B)
102
102
103 def test_is_valid_for(self):
103 def test_is_valid_for(self):
104 class MyTT(TraitType):
104 class MyTT(TraitType):
105 def is_valid_for(self, value):
105 def is_valid_for(self, value):
106 return True
106 return True
107 class A(HasTraits):
107 class A(HasTraits):
108 tt = MyTT
108 tt = MyTT
109
109
110 a = A()
110 a = A()
111 a.tt = 10
111 a.tt = 10
112 self.assertEquals(a.tt, 10)
112 self.assertEqual(a.tt, 10)
113
113
114 def test_value_for(self):
114 def test_value_for(self):
115 class MyTT(TraitType):
115 class MyTT(TraitType):
116 def value_for(self, value):
116 def value_for(self, value):
117 return 20
117 return 20
118 class A(HasTraits):
118 class A(HasTraits):
119 tt = MyTT
119 tt = MyTT
120
120
121 a = A()
121 a = A()
122 a.tt = 10
122 a.tt = 10
123 self.assertEquals(a.tt, 20)
123 self.assertEqual(a.tt, 20)
124
124
125 def test_info(self):
125 def test_info(self):
126 class A(HasTraits):
126 class A(HasTraits):
127 tt = TraitType
127 tt = TraitType
128 a = A()
128 a = A()
129 self.assertEquals(A.tt.info(), 'any value')
129 self.assertEqual(A.tt.info(), 'any value')
130
130
131 def test_error(self):
131 def test_error(self):
132 class A(HasTraits):
132 class A(HasTraits):
133 tt = TraitType
133 tt = TraitType
134 a = A()
134 a = A()
135 self.assertRaises(TraitError, A.tt.error, a, 10)
135 self.assertRaises(TraitError, A.tt.error, a, 10)
136
136
137 def test_dynamic_initializer(self):
137 def test_dynamic_initializer(self):
138 class A(HasTraits):
138 class A(HasTraits):
139 x = Int(10)
139 x = Int(10)
140 def _x_default(self):
140 def _x_default(self):
141 return 11
141 return 11
142 class B(A):
142 class B(A):
143 x = Int(20)
143 x = Int(20)
144 class C(A):
144 class C(A):
145 def _x_default(self):
145 def _x_default(self):
146 return 21
146 return 21
147
147
148 a = A()
148 a = A()
149 self.assertEquals(a._trait_values, {})
149 self.assertEqual(a._trait_values, {})
150 self.assertEquals(a._trait_dyn_inits.keys(), ['x'])
150 self.assertEqual(a._trait_dyn_inits.keys(), ['x'])
151 self.assertEquals(a.x, 11)
151 self.assertEqual(a.x, 11)
152 self.assertEquals(a._trait_values, {'x': 11})
152 self.assertEqual(a._trait_values, {'x': 11})
153 b = B()
153 b = B()
154 self.assertEquals(b._trait_values, {'x': 20})
154 self.assertEqual(b._trait_values, {'x': 20})
155 self.assertEquals(a._trait_dyn_inits.keys(), ['x'])
155 self.assertEqual(a._trait_dyn_inits.keys(), ['x'])
156 self.assertEquals(b.x, 20)
156 self.assertEqual(b.x, 20)
157 c = C()
157 c = C()
158 self.assertEquals(c._trait_values, {})
158 self.assertEqual(c._trait_values, {})
159 self.assertEquals(a._trait_dyn_inits.keys(), ['x'])
159 self.assertEqual(a._trait_dyn_inits.keys(), ['x'])
160 self.assertEquals(c.x, 21)
160 self.assertEqual(c.x, 21)
161 self.assertEquals(c._trait_values, {'x': 21})
161 self.assertEqual(c._trait_values, {'x': 21})
162 # Ensure that the base class remains unmolested when the _default
162 # Ensure that the base class remains unmolested when the _default
163 # initializer gets overridden in a subclass.
163 # initializer gets overridden in a subclass.
164 a = A()
164 a = A()
165 c = C()
165 c = C()
166 self.assertEquals(a._trait_values, {})
166 self.assertEqual(a._trait_values, {})
167 self.assertEquals(a._trait_dyn_inits.keys(), ['x'])
167 self.assertEqual(a._trait_dyn_inits.keys(), ['x'])
168 self.assertEquals(a.x, 11)
168 self.assertEqual(a.x, 11)
169 self.assertEquals(a._trait_values, {'x': 11})
169 self.assertEqual(a._trait_values, {'x': 11})
170
170
171
171
172
172
173 class TestHasTraitsMeta(TestCase):
173 class TestHasTraitsMeta(TestCase):
174
174
175 def test_metaclass(self):
175 def test_metaclass(self):
176 self.assertEquals(type(HasTraits), MetaHasTraits)
176 self.assertEqual(type(HasTraits), MetaHasTraits)
177
177
178 class A(HasTraits):
178 class A(HasTraits):
179 a = Int
179 a = Int
180
180
181 a = A()
181 a = A()
182 self.assertEquals(type(a.__class__), MetaHasTraits)
182 self.assertEqual(type(a.__class__), MetaHasTraits)
183 self.assertEquals(a.a,0)
183 self.assertEqual(a.a,0)
184 a.a = 10
184 a.a = 10
185 self.assertEquals(a.a,10)
185 self.assertEqual(a.a,10)
186
186
187 class B(HasTraits):
187 class B(HasTraits):
188 b = Int()
188 b = Int()
189
189
190 b = B()
190 b = B()
191 self.assertEquals(b.b,0)
191 self.assertEqual(b.b,0)
192 b.b = 10
192 b.b = 10
193 self.assertEquals(b.b,10)
193 self.assertEqual(b.b,10)
194
194
195 class C(HasTraits):
195 class C(HasTraits):
196 c = Int(30)
196 c = Int(30)
197
197
198 c = C()
198 c = C()
199 self.assertEquals(c.c,30)
199 self.assertEqual(c.c,30)
200 c.c = 10
200 c.c = 10
201 self.assertEquals(c.c,10)
201 self.assertEqual(c.c,10)
202
202
203 def test_this_class(self):
203 def test_this_class(self):
204 class A(HasTraits):
204 class A(HasTraits):
205 t = This()
205 t = This()
206 tt = This()
206 tt = This()
207 class B(A):
207 class B(A):
208 tt = This()
208 tt = This()
209 ttt = This()
209 ttt = This()
210 self.assertEquals(A.t.this_class, A)
210 self.assertEqual(A.t.this_class, A)
211 self.assertEquals(B.t.this_class, A)
211 self.assertEqual(B.t.this_class, A)
212 self.assertEquals(B.tt.this_class, B)
212 self.assertEqual(B.tt.this_class, B)
213 self.assertEquals(B.ttt.this_class, B)
213 self.assertEqual(B.ttt.this_class, B)
214
214
215 class TestHasTraitsNotify(TestCase):
215 class TestHasTraitsNotify(TestCase):
216
216
217 def setUp(self):
217 def setUp(self):
218 self._notify1 = []
218 self._notify1 = []
219 self._notify2 = []
219 self._notify2 = []
220
220
221 def notify1(self, name, old, new):
221 def notify1(self, name, old, new):
222 self._notify1.append((name, old, new))
222 self._notify1.append((name, old, new))
223
223
224 def notify2(self, name, old, new):
224 def notify2(self, name, old, new):
225 self._notify2.append((name, old, new))
225 self._notify2.append((name, old, new))
226
226
227 def test_notify_all(self):
227 def test_notify_all(self):
228
228
229 class A(HasTraits):
229 class A(HasTraits):
230 a = Int
230 a = Int
231 b = Float
231 b = Float
232
232
233 a = A()
233 a = A()
234 a.on_trait_change(self.notify1)
234 a.on_trait_change(self.notify1)
235 a.a = 0
235 a.a = 0
236 self.assertEquals(len(self._notify1),0)
236 self.assertEqual(len(self._notify1),0)
237 a.b = 0.0
237 a.b = 0.0
238 self.assertEquals(len(self._notify1),0)
238 self.assertEqual(len(self._notify1),0)
239 a.a = 10
239 a.a = 10
240 self.assert_(('a',0,10) in self._notify1)
240 self.assert_(('a',0,10) in self._notify1)
241 a.b = 10.0
241 a.b = 10.0
242 self.assert_(('b',0.0,10.0) in self._notify1)
242 self.assert_(('b',0.0,10.0) in self._notify1)
243 self.assertRaises(TraitError,setattr,a,'a','bad string')
243 self.assertRaises(TraitError,setattr,a,'a','bad string')
244 self.assertRaises(TraitError,setattr,a,'b','bad string')
244 self.assertRaises(TraitError,setattr,a,'b','bad string')
245 self._notify1 = []
245 self._notify1 = []
246 a.on_trait_change(self.notify1,remove=True)
246 a.on_trait_change(self.notify1,remove=True)
247 a.a = 20
247 a.a = 20
248 a.b = 20.0
248 a.b = 20.0
249 self.assertEquals(len(self._notify1),0)
249 self.assertEqual(len(self._notify1),0)
250
250
251 def test_notify_one(self):
251 def test_notify_one(self):
252
252
253 class A(HasTraits):
253 class A(HasTraits):
254 a = Int
254 a = Int
255 b = Float
255 b = Float
256
256
257 a = A()
257 a = A()
258 a.on_trait_change(self.notify1, 'a')
258 a.on_trait_change(self.notify1, 'a')
259 a.a = 0
259 a.a = 0
260 self.assertEquals(len(self._notify1),0)
260 self.assertEqual(len(self._notify1),0)
261 a.a = 10
261 a.a = 10
262 self.assert_(('a',0,10) in self._notify1)
262 self.assert_(('a',0,10) in self._notify1)
263 self.assertRaises(TraitError,setattr,a,'a','bad string')
263 self.assertRaises(TraitError,setattr,a,'a','bad string')
264
264
265 def test_subclass(self):
265 def test_subclass(self):
266
266
267 class A(HasTraits):
267 class A(HasTraits):
268 a = Int
268 a = Int
269
269
270 class B(A):
270 class B(A):
271 b = Float
271 b = Float
272
272
273 b = B()
273 b = B()
274 self.assertEquals(b.a,0)
274 self.assertEqual(b.a,0)
275 self.assertEquals(b.b,0.0)
275 self.assertEqual(b.b,0.0)
276 b.a = 100
276 b.a = 100
277 b.b = 100.0
277 b.b = 100.0
278 self.assertEquals(b.a,100)
278 self.assertEqual(b.a,100)
279 self.assertEquals(b.b,100.0)
279 self.assertEqual(b.b,100.0)
280
280
281 def test_notify_subclass(self):
281 def test_notify_subclass(self):
282
282
283 class A(HasTraits):
283 class A(HasTraits):
284 a = Int
284 a = Int
285
285
286 class B(A):
286 class B(A):
287 b = Float
287 b = Float
288
288
289 b = B()
289 b = B()
290 b.on_trait_change(self.notify1, 'a')
290 b.on_trait_change(self.notify1, 'a')
291 b.on_trait_change(self.notify2, 'b')
291 b.on_trait_change(self.notify2, 'b')
292 b.a = 0
292 b.a = 0
293 b.b = 0.0
293 b.b = 0.0
294 self.assertEquals(len(self._notify1),0)
294 self.assertEqual(len(self._notify1),0)
295 self.assertEquals(len(self._notify2),0)
295 self.assertEqual(len(self._notify2),0)
296 b.a = 10
296 b.a = 10
297 b.b = 10.0
297 b.b = 10.0
298 self.assert_(('a',0,10) in self._notify1)
298 self.assert_(('a',0,10) in self._notify1)
299 self.assert_(('b',0.0,10.0) in self._notify2)
299 self.assert_(('b',0.0,10.0) in self._notify2)
300
300
301 def test_static_notify(self):
301 def test_static_notify(self):
302
302
303 class A(HasTraits):
303 class A(HasTraits):
304 a = Int
304 a = Int
305 _notify1 = []
305 _notify1 = []
306 def _a_changed(self, name, old, new):
306 def _a_changed(self, name, old, new):
307 self._notify1.append((name, old, new))
307 self._notify1.append((name, old, new))
308
308
309 a = A()
309 a = A()
310 a.a = 0
310 a.a = 0
311 # This is broken!!!
311 # This is broken!!!
312 self.assertEquals(len(a._notify1),0)
312 self.assertEqual(len(a._notify1),0)
313 a.a = 10
313 a.a = 10
314 self.assert_(('a',0,10) in a._notify1)
314 self.assert_(('a',0,10) in a._notify1)
315
315
316 class B(A):
316 class B(A):
317 b = Float
317 b = Float
318 _notify2 = []
318 _notify2 = []
319 def _b_changed(self, name, old, new):
319 def _b_changed(self, name, old, new):
320 self._notify2.append((name, old, new))
320 self._notify2.append((name, old, new))
321
321
322 b = B()
322 b = B()
323 b.a = 10
323 b.a = 10
324 b.b = 10.0
324 b.b = 10.0
325 self.assert_(('a',0,10) in b._notify1)
325 self.assert_(('a',0,10) in b._notify1)
326 self.assert_(('b',0.0,10.0) in b._notify2)
326 self.assert_(('b',0.0,10.0) in b._notify2)
327
327
328 def test_notify_args(self):
328 def test_notify_args(self):
329
329
330 def callback0():
330 def callback0():
331 self.cb = ()
331 self.cb = ()
332 def callback1(name):
332 def callback1(name):
333 self.cb = (name,)
333 self.cb = (name,)
334 def callback2(name, new):
334 def callback2(name, new):
335 self.cb = (name, new)
335 self.cb = (name, new)
336 def callback3(name, old, new):
336 def callback3(name, old, new):
337 self.cb = (name, old, new)
337 self.cb = (name, old, new)
338
338
339 class A(HasTraits):
339 class A(HasTraits):
340 a = Int
340 a = Int
341
341
342 a = A()
342 a = A()
343 a.on_trait_change(callback0, 'a')
343 a.on_trait_change(callback0, 'a')
344 a.a = 10
344 a.a = 10
345 self.assertEquals(self.cb,())
345 self.assertEqual(self.cb,())
346 a.on_trait_change(callback0, 'a', remove=True)
346 a.on_trait_change(callback0, 'a', remove=True)
347
347
348 a.on_trait_change(callback1, 'a')
348 a.on_trait_change(callback1, 'a')
349 a.a = 100
349 a.a = 100
350 self.assertEquals(self.cb,('a',))
350 self.assertEqual(self.cb,('a',))
351 a.on_trait_change(callback1, 'a', remove=True)
351 a.on_trait_change(callback1, 'a', remove=True)
352
352
353 a.on_trait_change(callback2, 'a')
353 a.on_trait_change(callback2, 'a')
354 a.a = 1000
354 a.a = 1000
355 self.assertEquals(self.cb,('a',1000))
355 self.assertEqual(self.cb,('a',1000))
356 a.on_trait_change(callback2, 'a', remove=True)
356 a.on_trait_change(callback2, 'a', remove=True)
357
357
358 a.on_trait_change(callback3, 'a')
358 a.on_trait_change(callback3, 'a')
359 a.a = 10000
359 a.a = 10000
360 self.assertEquals(self.cb,('a',1000,10000))
360 self.assertEqual(self.cb,('a',1000,10000))
361 a.on_trait_change(callback3, 'a', remove=True)
361 a.on_trait_change(callback3, 'a', remove=True)
362
362
363 self.assertEquals(len(a._trait_notifiers['a']),0)
363 self.assertEqual(len(a._trait_notifiers['a']),0)
364
364
365
365
366 class TestHasTraits(TestCase):
366 class TestHasTraits(TestCase):
367
367
368 def test_trait_names(self):
368 def test_trait_names(self):
369 class A(HasTraits):
369 class A(HasTraits):
370 i = Int
370 i = Int
371 f = Float
371 f = Float
372 a = A()
372 a = A()
373 self.assertEquals(sorted(a.trait_names()),['f','i'])
373 self.assertEqual(sorted(a.trait_names()),['f','i'])
374 self.assertEquals(sorted(A.class_trait_names()),['f','i'])
374 self.assertEqual(sorted(A.class_trait_names()),['f','i'])
375
375
376 def test_trait_metadata(self):
376 def test_trait_metadata(self):
377 class A(HasTraits):
377 class A(HasTraits):
378 i = Int(config_key='MY_VALUE')
378 i = Int(config_key='MY_VALUE')
379 a = A()
379 a = A()
380 self.assertEquals(a.trait_metadata('i','config_key'), 'MY_VALUE')
380 self.assertEqual(a.trait_metadata('i','config_key'), 'MY_VALUE')
381
381
382 def test_traits(self):
382 def test_traits(self):
383 class A(HasTraits):
383 class A(HasTraits):
384 i = Int
384 i = Int
385 f = Float
385 f = Float
386 a = A()
386 a = A()
387 self.assertEquals(a.traits(), dict(i=A.i, f=A.f))
387 self.assertEqual(a.traits(), dict(i=A.i, f=A.f))
388 self.assertEquals(A.class_traits(), dict(i=A.i, f=A.f))
388 self.assertEqual(A.class_traits(), dict(i=A.i, f=A.f))
389
389
390 def test_traits_metadata(self):
390 def test_traits_metadata(self):
391 class A(HasTraits):
391 class A(HasTraits):
392 i = Int(config_key='VALUE1', other_thing='VALUE2')
392 i = Int(config_key='VALUE1', other_thing='VALUE2')
393 f = Float(config_key='VALUE3', other_thing='VALUE2')
393 f = Float(config_key='VALUE3', other_thing='VALUE2')
394 j = Int(0)
394 j = Int(0)
395 a = A()
395 a = A()
396 self.assertEquals(a.traits(), dict(i=A.i, f=A.f, j=A.j))
396 self.assertEqual(a.traits(), dict(i=A.i, f=A.f, j=A.j))
397 traits = a.traits(config_key='VALUE1', other_thing='VALUE2')
397 traits = a.traits(config_key='VALUE1', other_thing='VALUE2')
398 self.assertEquals(traits, dict(i=A.i))
398 self.assertEqual(traits, dict(i=A.i))
399
399
400 # This passes, but it shouldn't because I am replicating a bug in
400 # This passes, but it shouldn't because I am replicating a bug in
401 # traits.
401 # traits.
402 traits = a.traits(config_key=lambda v: True)
402 traits = a.traits(config_key=lambda v: True)
403 self.assertEquals(traits, dict(i=A.i, f=A.f, j=A.j))
403 self.assertEqual(traits, dict(i=A.i, f=A.f, j=A.j))
404
404
405 def test_init(self):
405 def test_init(self):
406 class A(HasTraits):
406 class A(HasTraits):
407 i = Int()
407 i = Int()
408 x = Float()
408 x = Float()
409 a = A(i=1, x=10.0)
409 a = A(i=1, x=10.0)
410 self.assertEquals(a.i, 1)
410 self.assertEqual(a.i, 1)
411 self.assertEquals(a.x, 10.0)
411 self.assertEqual(a.x, 10.0)
412
412
413 #-----------------------------------------------------------------------------
413 #-----------------------------------------------------------------------------
414 # Tests for specific trait types
414 # Tests for specific trait types
415 #-----------------------------------------------------------------------------
415 #-----------------------------------------------------------------------------
416
416
417
417
418 class TestType(TestCase):
418 class TestType(TestCase):
419
419
420 def test_default(self):
420 def test_default(self):
421
421
422 class B(object): pass
422 class B(object): pass
423 class A(HasTraits):
423 class A(HasTraits):
424 klass = Type
424 klass = Type
425
425
426 a = A()
426 a = A()
427 self.assertEquals(a.klass, None)
427 self.assertEqual(a.klass, None)
428
428
429 a.klass = B
429 a.klass = B
430 self.assertEquals(a.klass, B)
430 self.assertEqual(a.klass, B)
431 self.assertRaises(TraitError, setattr, a, 'klass', 10)
431 self.assertRaises(TraitError, setattr, a, 'klass', 10)
432
432
433 def test_value(self):
433 def test_value(self):
434
434
435 class B(object): pass
435 class B(object): pass
436 class C(object): pass
436 class C(object): pass
437 class A(HasTraits):
437 class A(HasTraits):
438 klass = Type(B)
438 klass = Type(B)
439
439
440 a = A()
440 a = A()
441 self.assertEquals(a.klass, B)
441 self.assertEqual(a.klass, B)
442 self.assertRaises(TraitError, setattr, a, 'klass', C)
442 self.assertRaises(TraitError, setattr, a, 'klass', C)
443 self.assertRaises(TraitError, setattr, a, 'klass', object)
443 self.assertRaises(TraitError, setattr, a, 'klass', object)
444 a.klass = B
444 a.klass = B
445
445
446 def test_allow_none(self):
446 def test_allow_none(self):
447
447
448 class B(object): pass
448 class B(object): pass
449 class C(B): pass
449 class C(B): pass
450 class A(HasTraits):
450 class A(HasTraits):
451 klass = Type(B, allow_none=False)
451 klass = Type(B, allow_none=False)
452
452
453 a = A()
453 a = A()
454 self.assertEquals(a.klass, B)
454 self.assertEqual(a.klass, B)
455 self.assertRaises(TraitError, setattr, a, 'klass', None)
455 self.assertRaises(TraitError, setattr, a, 'klass', None)
456 a.klass = C
456 a.klass = C
457 self.assertEquals(a.klass, C)
457 self.assertEqual(a.klass, C)
458
458
459 def test_validate_klass(self):
459 def test_validate_klass(self):
460
460
461 class A(HasTraits):
461 class A(HasTraits):
462 klass = Type('no strings allowed')
462 klass = Type('no strings allowed')
463
463
464 self.assertRaises(ImportError, A)
464 self.assertRaises(ImportError, A)
465
465
466 class A(HasTraits):
466 class A(HasTraits):
467 klass = Type('rub.adub.Duck')
467 klass = Type('rub.adub.Duck')
468
468
469 self.assertRaises(ImportError, A)
469 self.assertRaises(ImportError, A)
470
470
471 def test_validate_default(self):
471 def test_validate_default(self):
472
472
473 class B(object): pass
473 class B(object): pass
474 class A(HasTraits):
474 class A(HasTraits):
475 klass = Type('bad default', B)
475 klass = Type('bad default', B)
476
476
477 self.assertRaises(ImportError, A)
477 self.assertRaises(ImportError, A)
478
478
479 class C(HasTraits):
479 class C(HasTraits):
480 klass = Type(None, B, allow_none=False)
480 klass = Type(None, B, allow_none=False)
481
481
482 self.assertRaises(TraitError, C)
482 self.assertRaises(TraitError, C)
483
483
484 def test_str_klass(self):
484 def test_str_klass(self):
485
485
486 class A(HasTraits):
486 class A(HasTraits):
487 klass = Type('IPython.utils.ipstruct.Struct')
487 klass = Type('IPython.utils.ipstruct.Struct')
488
488
489 from IPython.utils.ipstruct import Struct
489 from IPython.utils.ipstruct import Struct
490 a = A()
490 a = A()
491 a.klass = Struct
491 a.klass = Struct
492 self.assertEquals(a.klass, Struct)
492 self.assertEqual(a.klass, Struct)
493
493
494 self.assertRaises(TraitError, setattr, a, 'klass', 10)
494 self.assertRaises(TraitError, setattr, a, 'klass', 10)
495
495
496 class TestInstance(TestCase):
496 class TestInstance(TestCase):
497
497
498 def test_basic(self):
498 def test_basic(self):
499 class Foo(object): pass
499 class Foo(object): pass
500 class Bar(Foo): pass
500 class Bar(Foo): pass
501 class Bah(object): pass
501 class Bah(object): pass
502
502
503 class A(HasTraits):
503 class A(HasTraits):
504 inst = Instance(Foo)
504 inst = Instance(Foo)
505
505
506 a = A()
506 a = A()
507 self.assert_(a.inst is None)
507 self.assert_(a.inst is None)
508 a.inst = Foo()
508 a.inst = Foo()
509 self.assert_(isinstance(a.inst, Foo))
509 self.assert_(isinstance(a.inst, Foo))
510 a.inst = Bar()
510 a.inst = Bar()
511 self.assert_(isinstance(a.inst, Foo))
511 self.assert_(isinstance(a.inst, Foo))
512 self.assertRaises(TraitError, setattr, a, 'inst', Foo)
512 self.assertRaises(TraitError, setattr, a, 'inst', Foo)
513 self.assertRaises(TraitError, setattr, a, 'inst', Bar)
513 self.assertRaises(TraitError, setattr, a, 'inst', Bar)
514 self.assertRaises(TraitError, setattr, a, 'inst', Bah())
514 self.assertRaises(TraitError, setattr, a, 'inst', Bah())
515
515
516 def test_unique_default_value(self):
516 def test_unique_default_value(self):
517 class Foo(object): pass
517 class Foo(object): pass
518 class A(HasTraits):
518 class A(HasTraits):
519 inst = Instance(Foo,(),{})
519 inst = Instance(Foo,(),{})
520
520
521 a = A()
521 a = A()
522 b = A()
522 b = A()
523 self.assert_(a.inst is not b.inst)
523 self.assert_(a.inst is not b.inst)
524
524
525 def test_args_kw(self):
525 def test_args_kw(self):
526 class Foo(object):
526 class Foo(object):
527 def __init__(self, c): self.c = c
527 def __init__(self, c): self.c = c
528 class Bar(object): pass
528 class Bar(object): pass
529 class Bah(object):
529 class Bah(object):
530 def __init__(self, c, d):
530 def __init__(self, c, d):
531 self.c = c; self.d = d
531 self.c = c; self.d = d
532
532
533 class A(HasTraits):
533 class A(HasTraits):
534 inst = Instance(Foo, (10,))
534 inst = Instance(Foo, (10,))
535 a = A()
535 a = A()
536 self.assertEquals(a.inst.c, 10)
536 self.assertEqual(a.inst.c, 10)
537
537
538 class B(HasTraits):
538 class B(HasTraits):
539 inst = Instance(Bah, args=(10,), kw=dict(d=20))
539 inst = Instance(Bah, args=(10,), kw=dict(d=20))
540 b = B()
540 b = B()
541 self.assertEquals(b.inst.c, 10)
541 self.assertEqual(b.inst.c, 10)
542 self.assertEquals(b.inst.d, 20)
542 self.assertEqual(b.inst.d, 20)
543
543
544 class C(HasTraits):
544 class C(HasTraits):
545 inst = Instance(Foo)
545 inst = Instance(Foo)
546 c = C()
546 c = C()
547 self.assert_(c.inst is None)
547 self.assert_(c.inst is None)
548
548
549 def test_bad_default(self):
549 def test_bad_default(self):
550 class Foo(object): pass
550 class Foo(object): pass
551
551
552 class A(HasTraits):
552 class A(HasTraits):
553 inst = Instance(Foo, allow_none=False)
553 inst = Instance(Foo, allow_none=False)
554
554
555 self.assertRaises(TraitError, A)
555 self.assertRaises(TraitError, A)
556
556
557 def test_instance(self):
557 def test_instance(self):
558 class Foo(object): pass
558 class Foo(object): pass
559
559
560 def inner():
560 def inner():
561 class A(HasTraits):
561 class A(HasTraits):
562 inst = Instance(Foo())
562 inst = Instance(Foo())
563
563
564 self.assertRaises(TraitError, inner)
564 self.assertRaises(TraitError, inner)
565
565
566
566
567 class TestThis(TestCase):
567 class TestThis(TestCase):
568
568
569 def test_this_class(self):
569 def test_this_class(self):
570 class Foo(HasTraits):
570 class Foo(HasTraits):
571 this = This
571 this = This
572
572
573 f = Foo()
573 f = Foo()
574 self.assertEquals(f.this, None)
574 self.assertEqual(f.this, None)
575 g = Foo()
575 g = Foo()
576 f.this = g
576 f.this = g
577 self.assertEquals(f.this, g)
577 self.assertEqual(f.this, g)
578 self.assertRaises(TraitError, setattr, f, 'this', 10)
578 self.assertRaises(TraitError, setattr, f, 'this', 10)
579
579
580 def test_this_inst(self):
580 def test_this_inst(self):
581 class Foo(HasTraits):
581 class Foo(HasTraits):
582 this = This()
582 this = This()
583
583
584 f = Foo()
584 f = Foo()
585 f.this = Foo()
585 f.this = Foo()
586 self.assert_(isinstance(f.this, Foo))
586 self.assert_(isinstance(f.this, Foo))
587
587
588 def test_subclass(self):
588 def test_subclass(self):
589 class Foo(HasTraits):
589 class Foo(HasTraits):
590 t = This()
590 t = This()
591 class Bar(Foo):
591 class Bar(Foo):
592 pass
592 pass
593 f = Foo()
593 f = Foo()
594 b = Bar()
594 b = Bar()
595 f.t = b
595 f.t = b
596 b.t = f
596 b.t = f
597 self.assertEquals(f.t, b)
597 self.assertEqual(f.t, b)
598 self.assertEquals(b.t, f)
598 self.assertEqual(b.t, f)
599
599
600 def test_subclass_override(self):
600 def test_subclass_override(self):
601 class Foo(HasTraits):
601 class Foo(HasTraits):
602 t = This()
602 t = This()
603 class Bar(Foo):
603 class Bar(Foo):
604 t = This()
604 t = This()
605 f = Foo()
605 f = Foo()
606 b = Bar()
606 b = Bar()
607 f.t = b
607 f.t = b
608 self.assertEquals(f.t, b)
608 self.assertEqual(f.t, b)
609 self.assertRaises(TraitError, setattr, b, 't', f)
609 self.assertRaises(TraitError, setattr, b, 't', f)
610
610
611 class TraitTestBase(TestCase):
611 class TraitTestBase(TestCase):
612 """A best testing class for basic trait types."""
612 """A best testing class for basic trait types."""
613
613
614 def assign(self, value):
614 def assign(self, value):
615 self.obj.value = value
615 self.obj.value = value
616
616
617 def coerce(self, value):
617 def coerce(self, value):
618 return value
618 return value
619
619
620 def test_good_values(self):
620 def test_good_values(self):
621 if hasattr(self, '_good_values'):
621 if hasattr(self, '_good_values'):
622 for value in self._good_values:
622 for value in self._good_values:
623 self.assign(value)
623 self.assign(value)
624 self.assertEquals(self.obj.value, self.coerce(value))
624 self.assertEqual(self.obj.value, self.coerce(value))
625
625
626 def test_bad_values(self):
626 def test_bad_values(self):
627 if hasattr(self, '_bad_values'):
627 if hasattr(self, '_bad_values'):
628 for value in self._bad_values:
628 for value in self._bad_values:
629 try:
629 try:
630 self.assertRaises(TraitError, self.assign, value)
630 self.assertRaises(TraitError, self.assign, value)
631 except AssertionError:
631 except AssertionError:
632 assert False, value
632 assert False, value
633
633
634 def test_default_value(self):
634 def test_default_value(self):
635 if hasattr(self, '_default_value'):
635 if hasattr(self, '_default_value'):
636 self.assertEquals(self._default_value, self.obj.value)
636 self.assertEqual(self._default_value, self.obj.value)
637
637
638 def tearDown(self):
638 def tearDown(self):
639 # restore default value after tests, if set
639 # restore default value after tests, if set
640 if hasattr(self, '_default_value'):
640 if hasattr(self, '_default_value'):
641 self.obj.value = self._default_value
641 self.obj.value = self._default_value
642
642
643
643
644 class AnyTrait(HasTraits):
644 class AnyTrait(HasTraits):
645
645
646 value = Any
646 value = Any
647
647
648 class AnyTraitTest(TraitTestBase):
648 class AnyTraitTest(TraitTestBase):
649
649
650 obj = AnyTrait()
650 obj = AnyTrait()
651
651
652 _default_value = None
652 _default_value = None
653 _good_values = [10.0, 'ten', u'ten', [10], {'ten': 10},(10,), None, 1j]
653 _good_values = [10.0, 'ten', u'ten', [10], {'ten': 10},(10,), None, 1j]
654 _bad_values = []
654 _bad_values = []
655
655
656
656
657 class IntTrait(HasTraits):
657 class IntTrait(HasTraits):
658
658
659 value = Int(99)
659 value = Int(99)
660
660
661 class TestInt(TraitTestBase):
661 class TestInt(TraitTestBase):
662
662
663 obj = IntTrait()
663 obj = IntTrait()
664 _default_value = 99
664 _default_value = 99
665 _good_values = [10, -10]
665 _good_values = [10, -10]
666 _bad_values = ['ten', u'ten', [10], {'ten': 10},(10,), None, 1j,
666 _bad_values = ['ten', u'ten', [10], {'ten': 10},(10,), None, 1j,
667 10.1, -10.1, '10L', '-10L', '10.1', '-10.1', u'10L',
667 10.1, -10.1, '10L', '-10L', '10.1', '-10.1', u'10L',
668 u'-10L', u'10.1', u'-10.1', '10', '-10', u'10', u'-10']
668 u'-10L', u'10.1', u'-10.1', '10', '-10', u'10', u'-10']
669 if not py3compat.PY3:
669 if not py3compat.PY3:
670 _bad_values.extend([10L, -10L, 10*sys.maxint, -10*sys.maxint])
670 _bad_values.extend([10L, -10L, 10*sys.maxint, -10*sys.maxint])
671
671
672
672
673 class LongTrait(HasTraits):
673 class LongTrait(HasTraits):
674
674
675 value = Long(99L)
675 value = Long(99L)
676
676
677 class TestLong(TraitTestBase):
677 class TestLong(TraitTestBase):
678
678
679 obj = LongTrait()
679 obj = LongTrait()
680
680
681 _default_value = 99L
681 _default_value = 99L
682 _good_values = [10, -10, 10L, -10L]
682 _good_values = [10, -10, 10L, -10L]
683 _bad_values = ['ten', u'ten', [10], [10l], {'ten': 10},(10,),(10L,),
683 _bad_values = ['ten', u'ten', [10], [10l], {'ten': 10},(10,),(10L,),
684 None, 1j, 10.1, -10.1, '10', '-10', '10L', '-10L', '10.1',
684 None, 1j, 10.1, -10.1, '10', '-10', '10L', '-10L', '10.1',
685 '-10.1', u'10', u'-10', u'10L', u'-10L', u'10.1',
685 '-10.1', u'10', u'-10', u'10L', u'-10L', u'10.1',
686 u'-10.1']
686 u'-10.1']
687 if not py3compat.PY3:
687 if not py3compat.PY3:
688 # maxint undefined on py3, because int == long
688 # maxint undefined on py3, because int == long
689 _good_values.extend([10*sys.maxint, -10*sys.maxint])
689 _good_values.extend([10*sys.maxint, -10*sys.maxint])
690
690
691 @skipif(py3compat.PY3, "not relevant on py3")
691 @skipif(py3compat.PY3, "not relevant on py3")
692 def test_cast_small(self):
692 def test_cast_small(self):
693 """Long casts ints to long"""
693 """Long casts ints to long"""
694 self.obj.value = 10
694 self.obj.value = 10
695 self.assertEquals(type(self.obj.value), long)
695 self.assertEqual(type(self.obj.value), long)
696
696
697
697
698 class IntegerTrait(HasTraits):
698 class IntegerTrait(HasTraits):
699 value = Integer(1)
699 value = Integer(1)
700
700
701 class TestInteger(TestLong):
701 class TestInteger(TestLong):
702 obj = IntegerTrait()
702 obj = IntegerTrait()
703 _default_value = 1
703 _default_value = 1
704
704
705 def coerce(self, n):
705 def coerce(self, n):
706 return int(n)
706 return int(n)
707
707
708 @skipif(py3compat.PY3, "not relevant on py3")
708 @skipif(py3compat.PY3, "not relevant on py3")
709 def test_cast_small(self):
709 def test_cast_small(self):
710 """Integer casts small longs to int"""
710 """Integer casts small longs to int"""
711 if py3compat.PY3:
711 if py3compat.PY3:
712 raise SkipTest("not relevant on py3")
712 raise SkipTest("not relevant on py3")
713
713
714 self.obj.value = 100L
714 self.obj.value = 100L
715 self.assertEquals(type(self.obj.value), int)
715 self.assertEqual(type(self.obj.value), int)
716
716
717
717
718 class FloatTrait(HasTraits):
718 class FloatTrait(HasTraits):
719
719
720 value = Float(99.0)
720 value = Float(99.0)
721
721
722 class TestFloat(TraitTestBase):
722 class TestFloat(TraitTestBase):
723
723
724 obj = FloatTrait()
724 obj = FloatTrait()
725
725
726 _default_value = 99.0
726 _default_value = 99.0
727 _good_values = [10, -10, 10.1, -10.1]
727 _good_values = [10, -10, 10.1, -10.1]
728 _bad_values = ['ten', u'ten', [10], {'ten': 10},(10,), None,
728 _bad_values = ['ten', u'ten', [10], {'ten': 10},(10,), None,
729 1j, '10', '-10', '10L', '-10L', '10.1', '-10.1', u'10',
729 1j, '10', '-10', '10L', '-10L', '10.1', '-10.1', u'10',
730 u'-10', u'10L', u'-10L', u'10.1', u'-10.1']
730 u'-10', u'10L', u'-10L', u'10.1', u'-10.1']
731 if not py3compat.PY3:
731 if not py3compat.PY3:
732 _bad_values.extend([10L, -10L])
732 _bad_values.extend([10L, -10L])
733
733
734
734
735 class ComplexTrait(HasTraits):
735 class ComplexTrait(HasTraits):
736
736
737 value = Complex(99.0-99.0j)
737 value = Complex(99.0-99.0j)
738
738
739 class TestComplex(TraitTestBase):
739 class TestComplex(TraitTestBase):
740
740
741 obj = ComplexTrait()
741 obj = ComplexTrait()
742
742
743 _default_value = 99.0-99.0j
743 _default_value = 99.0-99.0j
744 _good_values = [10, -10, 10.1, -10.1, 10j, 10+10j, 10-10j,
744 _good_values = [10, -10, 10.1, -10.1, 10j, 10+10j, 10-10j,
745 10.1j, 10.1+10.1j, 10.1-10.1j]
745 10.1j, 10.1+10.1j, 10.1-10.1j]
746 _bad_values = [u'10L', u'-10L', 'ten', [10], {'ten': 10},(10,), None]
746 _bad_values = [u'10L', u'-10L', 'ten', [10], {'ten': 10},(10,), None]
747 if not py3compat.PY3:
747 if not py3compat.PY3:
748 _bad_values.extend([10L, -10L])
748 _bad_values.extend([10L, -10L])
749
749
750
750
751 class BytesTrait(HasTraits):
751 class BytesTrait(HasTraits):
752
752
753 value = Bytes(b'string')
753 value = Bytes(b'string')
754
754
755 class TestBytes(TraitTestBase):
755 class TestBytes(TraitTestBase):
756
756
757 obj = BytesTrait()
757 obj = BytesTrait()
758
758
759 _default_value = b'string'
759 _default_value = b'string'
760 _good_values = [b'10', b'-10', b'10L',
760 _good_values = [b'10', b'-10', b'10L',
761 b'-10L', b'10.1', b'-10.1', b'string']
761 b'-10L', b'10.1', b'-10.1', b'string']
762 _bad_values = [10, -10, 10L, -10L, 10.1, -10.1, 1j, [10],
762 _bad_values = [10, -10, 10L, -10L, 10.1, -10.1, 1j, [10],
763 ['ten'],{'ten': 10},(10,), None, u'string']
763 ['ten'],{'ten': 10},(10,), None, u'string']
764
764
765
765
766 class UnicodeTrait(HasTraits):
766 class UnicodeTrait(HasTraits):
767
767
768 value = Unicode(u'unicode')
768 value = Unicode(u'unicode')
769
769
770 class TestUnicode(TraitTestBase):
770 class TestUnicode(TraitTestBase):
771
771
772 obj = UnicodeTrait()
772 obj = UnicodeTrait()
773
773
774 _default_value = u'unicode'
774 _default_value = u'unicode'
775 _good_values = ['10', '-10', '10L', '-10L', '10.1',
775 _good_values = ['10', '-10', '10L', '-10L', '10.1',
776 '-10.1', '', u'', 'string', u'string', u"€"]
776 '-10.1', '', u'', 'string', u'string', u"€"]
777 _bad_values = [10, -10, 10L, -10L, 10.1, -10.1, 1j,
777 _bad_values = [10, -10, 10L, -10L, 10.1, -10.1, 1j,
778 [10], ['ten'], [u'ten'], {'ten': 10},(10,), None]
778 [10], ['ten'], [u'ten'], {'ten': 10},(10,), None]
779
779
780
780
781 class ObjectNameTrait(HasTraits):
781 class ObjectNameTrait(HasTraits):
782 value = ObjectName("abc")
782 value = ObjectName("abc")
783
783
784 class TestObjectName(TraitTestBase):
784 class TestObjectName(TraitTestBase):
785 obj = ObjectNameTrait()
785 obj = ObjectNameTrait()
786
786
787 _default_value = "abc"
787 _default_value = "abc"
788 _good_values = ["a", "gh", "g9", "g_", "_G", u"a345_"]
788 _good_values = ["a", "gh", "g9", "g_", "_G", u"a345_"]
789 _bad_values = [1, "", u"€", "9g", "!", "#abc", "aj@", "a.b", "a()", "a[0]",
789 _bad_values = [1, "", u"€", "9g", "!", "#abc", "aj@", "a.b", "a()", "a[0]",
790 object(), object]
790 object(), object]
791 if sys.version_info[0] < 3:
791 if sys.version_info[0] < 3:
792 _bad_values.append(u"þ")
792 _bad_values.append(u"þ")
793 else:
793 else:
794 _good_values.append(u"þ") # þ=1 is valid in Python 3 (PEP 3131).
794 _good_values.append(u"þ") # þ=1 is valid in Python 3 (PEP 3131).
795
795
796
796
797 class DottedObjectNameTrait(HasTraits):
797 class DottedObjectNameTrait(HasTraits):
798 value = DottedObjectName("a.b")
798 value = DottedObjectName("a.b")
799
799
800 class TestDottedObjectName(TraitTestBase):
800 class TestDottedObjectName(TraitTestBase):
801 obj = DottedObjectNameTrait()
801 obj = DottedObjectNameTrait()
802
802
803 _default_value = "a.b"
803 _default_value = "a.b"
804 _good_values = ["A", "y.t", "y765.__repr__", "os.path.join", u"os.path.join"]
804 _good_values = ["A", "y.t", "y765.__repr__", "os.path.join", u"os.path.join"]
805 _bad_values = [1, u"abc.€", "_.@", ".", ".abc", "abc.", ".abc."]
805 _bad_values = [1, u"abc.€", "_.@", ".", ".abc", "abc.", ".abc."]
806 if sys.version_info[0] < 3:
806 if sys.version_info[0] < 3:
807 _bad_values.append(u"t.þ")
807 _bad_values.append(u"t.þ")
808 else:
808 else:
809 _good_values.append(u"t.þ")
809 _good_values.append(u"t.þ")
810
810
811
811
812 class TCPAddressTrait(HasTraits):
812 class TCPAddressTrait(HasTraits):
813
813
814 value = TCPAddress()
814 value = TCPAddress()
815
815
816 class TestTCPAddress(TraitTestBase):
816 class TestTCPAddress(TraitTestBase):
817
817
818 obj = TCPAddressTrait()
818 obj = TCPAddressTrait()
819
819
820 _default_value = ('127.0.0.1',0)
820 _default_value = ('127.0.0.1',0)
821 _good_values = [('localhost',0),('192.168.0.1',1000),('www.google.com',80)]
821 _good_values = [('localhost',0),('192.168.0.1',1000),('www.google.com',80)]
822 _bad_values = [(0,0),('localhost',10.0),('localhost',-1)]
822 _bad_values = [(0,0),('localhost',10.0),('localhost',-1)]
823
823
824 class ListTrait(HasTraits):
824 class ListTrait(HasTraits):
825
825
826 value = List(Int)
826 value = List(Int)
827
827
828 class TestList(TraitTestBase):
828 class TestList(TraitTestBase):
829
829
830 obj = ListTrait()
830 obj = ListTrait()
831
831
832 _default_value = []
832 _default_value = []
833 _good_values = [[], [1], range(10)]
833 _good_values = [[], [1], range(10)]
834 _bad_values = [10, [1,'a'], 'a', (1,2)]
834 _bad_values = [10, [1,'a'], 'a', (1,2)]
835
835
836 class LenListTrait(HasTraits):
836 class LenListTrait(HasTraits):
837
837
838 value = List(Int, [0], minlen=1, maxlen=2)
838 value = List(Int, [0], minlen=1, maxlen=2)
839
839
840 class TestLenList(TraitTestBase):
840 class TestLenList(TraitTestBase):
841
841
842 obj = LenListTrait()
842 obj = LenListTrait()
843
843
844 _default_value = [0]
844 _default_value = [0]
845 _good_values = [[1], range(2)]
845 _good_values = [[1], range(2)]
846 _bad_values = [10, [1,'a'], 'a', (1,2), [], range(3)]
846 _bad_values = [10, [1,'a'], 'a', (1,2), [], range(3)]
847
847
848 class TupleTrait(HasTraits):
848 class TupleTrait(HasTraits):
849
849
850 value = Tuple(Int)
850 value = Tuple(Int)
851
851
852 class TestTupleTrait(TraitTestBase):
852 class TestTupleTrait(TraitTestBase):
853
853
854 obj = TupleTrait()
854 obj = TupleTrait()
855
855
856 _default_value = None
856 _default_value = None
857 _good_values = [(1,), None,(0,)]
857 _good_values = [(1,), None,(0,)]
858 _bad_values = [10, (1,2), [1],('a'), ()]
858 _bad_values = [10, (1,2), [1],('a'), ()]
859
859
860 def test_invalid_args(self):
860 def test_invalid_args(self):
861 self.assertRaises(TypeError, Tuple, 5)
861 self.assertRaises(TypeError, Tuple, 5)
862 self.assertRaises(TypeError, Tuple, default_value='hello')
862 self.assertRaises(TypeError, Tuple, default_value='hello')
863 t = Tuple(Int, CBytes, default_value=(1,5))
863 t = Tuple(Int, CBytes, default_value=(1,5))
864
864
865 class LooseTupleTrait(HasTraits):
865 class LooseTupleTrait(HasTraits):
866
866
867 value = Tuple((1,2,3))
867 value = Tuple((1,2,3))
868
868
869 class TestLooseTupleTrait(TraitTestBase):
869 class TestLooseTupleTrait(TraitTestBase):
870
870
871 obj = LooseTupleTrait()
871 obj = LooseTupleTrait()
872
872
873 _default_value = (1,2,3)
873 _default_value = (1,2,3)
874 _good_values = [(1,), None, (0,), tuple(range(5)), tuple('hello'), ('a',5), ()]
874 _good_values = [(1,), None, (0,), tuple(range(5)), tuple('hello'), ('a',5), ()]
875 _bad_values = [10, 'hello', [1], []]
875 _bad_values = [10, 'hello', [1], []]
876
876
877 def test_invalid_args(self):
877 def test_invalid_args(self):
878 self.assertRaises(TypeError, Tuple, 5)
878 self.assertRaises(TypeError, Tuple, 5)
879 self.assertRaises(TypeError, Tuple, default_value='hello')
879 self.assertRaises(TypeError, Tuple, default_value='hello')
880 t = Tuple(Int, CBytes, default_value=(1,5))
880 t = Tuple(Int, CBytes, default_value=(1,5))
881
881
882
882
883 class MultiTupleTrait(HasTraits):
883 class MultiTupleTrait(HasTraits):
884
884
885 value = Tuple(Int, Bytes, default_value=[99,b'bottles'])
885 value = Tuple(Int, Bytes, default_value=[99,b'bottles'])
886
886
887 class TestMultiTuple(TraitTestBase):
887 class TestMultiTuple(TraitTestBase):
888
888
889 obj = MultiTupleTrait()
889 obj = MultiTupleTrait()
890
890
891 _default_value = (99,b'bottles')
891 _default_value = (99,b'bottles')
892 _good_values = [(1,b'a'), (2,b'b')]
892 _good_values = [(1,b'a'), (2,b'b')]
893 _bad_values = ((),10, b'a', (1,b'a',3), (b'a',1), (1, u'a'))
893 _bad_values = ((),10, b'a', (1,b'a',3), (b'a',1), (1, u'a'))
894
894
895 class CRegExpTrait(HasTraits):
895 class CRegExpTrait(HasTraits):
896
896
897 value = CRegExp(r'')
897 value = CRegExp(r'')
898
898
899 class TestCRegExp(TraitTestBase):
899 class TestCRegExp(TraitTestBase):
900
900
901 def coerce(self, value):
901 def coerce(self, value):
902 return re.compile(value)
902 return re.compile(value)
903
903
904 obj = CRegExpTrait()
904 obj = CRegExpTrait()
905
905
906 _default_value = re.compile(r'')
906 _default_value = re.compile(r'')
907 _good_values = [r'\d+', re.compile(r'\d+')]
907 _good_values = [r'\d+', re.compile(r'\d+')]
908 _bad_values = [r'(', None, ()]
908 _bad_values = [r'(', None, ()]
@@ -1,212 +1,212 b''
1 """test building messages with streamsession"""
1 """test building messages with streamsession"""
2
2
3 #-------------------------------------------------------------------------------
3 #-------------------------------------------------------------------------------
4 # Copyright (C) 2011 The IPython Development Team
4 # Copyright (C) 2011 The IPython Development Team
5 #
5 #
6 # Distributed under the terms of the BSD License. The full license is in
6 # Distributed under the terms of the BSD License. The full license is in
7 # the file COPYING, distributed as part of this software.
7 # the file COPYING, distributed as part of this software.
8 #-------------------------------------------------------------------------------
8 #-------------------------------------------------------------------------------
9
9
10 #-------------------------------------------------------------------------------
10 #-------------------------------------------------------------------------------
11 # Imports
11 # Imports
12 #-------------------------------------------------------------------------------
12 #-------------------------------------------------------------------------------
13
13
14 import os
14 import os
15 import uuid
15 import uuid
16 import zmq
16 import zmq
17
17
18 from zmq.tests import BaseZMQTestCase
18 from zmq.tests import BaseZMQTestCase
19 from zmq.eventloop.zmqstream import ZMQStream
19 from zmq.eventloop.zmqstream import ZMQStream
20
20
21 from IPython.zmq import session as ss
21 from IPython.zmq import session as ss
22
22
23 class SessionTestCase(BaseZMQTestCase):
23 class SessionTestCase(BaseZMQTestCase):
24
24
25 def setUp(self):
25 def setUp(self):
26 BaseZMQTestCase.setUp(self)
26 BaseZMQTestCase.setUp(self)
27 self.session = ss.Session()
27 self.session = ss.Session()
28
28
29
29
30 class MockSocket(zmq.Socket):
30 class MockSocket(zmq.Socket):
31
31
32 def __init__(self, *args, **kwargs):
32 def __init__(self, *args, **kwargs):
33 super(MockSocket,self).__init__(*args,**kwargs)
33 super(MockSocket,self).__init__(*args,**kwargs)
34 self.data = []
34 self.data = []
35
35
36 def send_multipart(self, msgparts, *args, **kwargs):
36 def send_multipart(self, msgparts, *args, **kwargs):
37 self.data.extend(msgparts)
37 self.data.extend(msgparts)
38
38
39 def send(self, part, *args, **kwargs):
39 def send(self, part, *args, **kwargs):
40 self.data.append(part)
40 self.data.append(part)
41
41
42 def recv_multipart(self, *args, **kwargs):
42 def recv_multipart(self, *args, **kwargs):
43 return self.data
43 return self.data
44
44
45 class TestSession(SessionTestCase):
45 class TestSession(SessionTestCase):
46
46
47 def test_msg(self):
47 def test_msg(self):
48 """message format"""
48 """message format"""
49 msg = self.session.msg('execute')
49 msg = self.session.msg('execute')
50 thekeys = set('header parent_header content msg_type msg_id'.split())
50 thekeys = set('header parent_header content msg_type msg_id'.split())
51 s = set(msg.keys())
51 s = set(msg.keys())
52 self.assertEquals(s, thekeys)
52 self.assertEqual(s, thekeys)
53 self.assertTrue(isinstance(msg['content'],dict))
53 self.assertTrue(isinstance(msg['content'],dict))
54 self.assertTrue(isinstance(msg['header'],dict))
54 self.assertTrue(isinstance(msg['header'],dict))
55 self.assertTrue(isinstance(msg['parent_header'],dict))
55 self.assertTrue(isinstance(msg['parent_header'],dict))
56 self.assertTrue(isinstance(msg['msg_id'],str))
56 self.assertTrue(isinstance(msg['msg_id'],str))
57 self.assertTrue(isinstance(msg['msg_type'],str))
57 self.assertTrue(isinstance(msg['msg_type'],str))
58 self.assertEquals(msg['header']['msg_type'], 'execute')
58 self.assertEqual(msg['header']['msg_type'], 'execute')
59 self.assertEquals(msg['msg_type'], 'execute')
59 self.assertEqual(msg['msg_type'], 'execute')
60
60
61 def test_serialize(self):
61 def test_serialize(self):
62 msg = self.session.msg('execute', content=dict(a=10, b=1.1))
62 msg = self.session.msg('execute', content=dict(a=10, b=1.1))
63 msg_list = self.session.serialize(msg, ident=b'foo')
63 msg_list = self.session.serialize(msg, ident=b'foo')
64 ident, msg_list = self.session.feed_identities(msg_list)
64 ident, msg_list = self.session.feed_identities(msg_list)
65 new_msg = self.session.unserialize(msg_list)
65 new_msg = self.session.unserialize(msg_list)
66 self.assertEquals(ident[0], b'foo')
66 self.assertEqual(ident[0], b'foo')
67 self.assertEquals(new_msg['msg_id'],msg['msg_id'])
67 self.assertEqual(new_msg['msg_id'],msg['msg_id'])
68 self.assertEquals(new_msg['msg_type'],msg['msg_type'])
68 self.assertEqual(new_msg['msg_type'],msg['msg_type'])
69 self.assertEquals(new_msg['header'],msg['header'])
69 self.assertEqual(new_msg['header'],msg['header'])
70 self.assertEquals(new_msg['content'],msg['content'])
70 self.assertEqual(new_msg['content'],msg['content'])
71 self.assertEquals(new_msg['parent_header'],msg['parent_header'])
71 self.assertEqual(new_msg['parent_header'],msg['parent_header'])
72 # ensure floats don't come out as Decimal:
72 # ensure floats don't come out as Decimal:
73 self.assertEquals(type(new_msg['content']['b']),type(new_msg['content']['b']))
73 self.assertEqual(type(new_msg['content']['b']),type(new_msg['content']['b']))
74
74
75 def test_send(self):
75 def test_send(self):
76 socket = MockSocket(zmq.Context.instance(),zmq.PAIR)
76 socket = MockSocket(zmq.Context.instance(),zmq.PAIR)
77
77
78 msg = self.session.msg('execute', content=dict(a=10))
78 msg = self.session.msg('execute', content=dict(a=10))
79 self.session.send(socket, msg, ident=b'foo', buffers=[b'bar'])
79 self.session.send(socket, msg, ident=b'foo', buffers=[b'bar'])
80 ident, msg_list = self.session.feed_identities(socket.data)
80 ident, msg_list = self.session.feed_identities(socket.data)
81 new_msg = self.session.unserialize(msg_list)
81 new_msg = self.session.unserialize(msg_list)
82 self.assertEquals(ident[0], b'foo')
82 self.assertEqual(ident[0], b'foo')
83 self.assertEquals(new_msg['msg_id'],msg['msg_id'])
83 self.assertEqual(new_msg['msg_id'],msg['msg_id'])
84 self.assertEquals(new_msg['msg_type'],msg['msg_type'])
84 self.assertEqual(new_msg['msg_type'],msg['msg_type'])
85 self.assertEquals(new_msg['header'],msg['header'])
85 self.assertEqual(new_msg['header'],msg['header'])
86 self.assertEquals(new_msg['content'],msg['content'])
86 self.assertEqual(new_msg['content'],msg['content'])
87 self.assertEquals(new_msg['parent_header'],msg['parent_header'])
87 self.assertEqual(new_msg['parent_header'],msg['parent_header'])
88 self.assertEquals(new_msg['buffers'],[b'bar'])
88 self.assertEqual(new_msg['buffers'],[b'bar'])
89
89
90 socket.data = []
90 socket.data = []
91
91
92 content = msg['content']
92 content = msg['content']
93 header = msg['header']
93 header = msg['header']
94 parent = msg['parent_header']
94 parent = msg['parent_header']
95 msg_type = header['msg_type']
95 msg_type = header['msg_type']
96 self.session.send(socket, None, content=content, parent=parent,
96 self.session.send(socket, None, content=content, parent=parent,
97 header=header, ident=b'foo', buffers=[b'bar'])
97 header=header, ident=b'foo', buffers=[b'bar'])
98 ident, msg_list = self.session.feed_identities(socket.data)
98 ident, msg_list = self.session.feed_identities(socket.data)
99 new_msg = self.session.unserialize(msg_list)
99 new_msg = self.session.unserialize(msg_list)
100 self.assertEquals(ident[0], b'foo')
100 self.assertEqual(ident[0], b'foo')
101 self.assertEquals(new_msg['msg_id'],msg['msg_id'])
101 self.assertEqual(new_msg['msg_id'],msg['msg_id'])
102 self.assertEquals(new_msg['msg_type'],msg['msg_type'])
102 self.assertEqual(new_msg['msg_type'],msg['msg_type'])
103 self.assertEquals(new_msg['header'],msg['header'])
103 self.assertEqual(new_msg['header'],msg['header'])
104 self.assertEquals(new_msg['content'],msg['content'])
104 self.assertEqual(new_msg['content'],msg['content'])
105 self.assertEquals(new_msg['parent_header'],msg['parent_header'])
105 self.assertEqual(new_msg['parent_header'],msg['parent_header'])
106 self.assertEquals(new_msg['buffers'],[b'bar'])
106 self.assertEqual(new_msg['buffers'],[b'bar'])
107
107
108 socket.data = []
108 socket.data = []
109
109
110 self.session.send(socket, msg, ident=b'foo', buffers=[b'bar'])
110 self.session.send(socket, msg, ident=b'foo', buffers=[b'bar'])
111 ident, new_msg = self.session.recv(socket)
111 ident, new_msg = self.session.recv(socket)
112 self.assertEquals(ident[0], b'foo')
112 self.assertEqual(ident[0], b'foo')
113 self.assertEquals(new_msg['msg_id'],msg['msg_id'])
113 self.assertEqual(new_msg['msg_id'],msg['msg_id'])
114 self.assertEquals(new_msg['msg_type'],msg['msg_type'])
114 self.assertEqual(new_msg['msg_type'],msg['msg_type'])
115 self.assertEquals(new_msg['header'],msg['header'])
115 self.assertEqual(new_msg['header'],msg['header'])
116 self.assertEquals(new_msg['content'],msg['content'])
116 self.assertEqual(new_msg['content'],msg['content'])
117 self.assertEquals(new_msg['parent_header'],msg['parent_header'])
117 self.assertEqual(new_msg['parent_header'],msg['parent_header'])
118 self.assertEquals(new_msg['buffers'],[b'bar'])
118 self.assertEqual(new_msg['buffers'],[b'bar'])
119
119
120 socket.close()
120 socket.close()
121
121
122 def test_args(self):
122 def test_args(self):
123 """initialization arguments for Session"""
123 """initialization arguments for Session"""
124 s = self.session
124 s = self.session
125 self.assertTrue(s.pack is ss.default_packer)
125 self.assertTrue(s.pack is ss.default_packer)
126 self.assertTrue(s.unpack is ss.default_unpacker)
126 self.assertTrue(s.unpack is ss.default_unpacker)
127 self.assertEquals(s.username, os.environ.get('USER', u'username'))
127 self.assertEqual(s.username, os.environ.get('USER', u'username'))
128
128
129 s = ss.Session()
129 s = ss.Session()
130 self.assertEquals(s.username, os.environ.get('USER', u'username'))
130 self.assertEqual(s.username, os.environ.get('USER', u'username'))
131
131
132 self.assertRaises(TypeError, ss.Session, pack='hi')
132 self.assertRaises(TypeError, ss.Session, pack='hi')
133 self.assertRaises(TypeError, ss.Session, unpack='hi')
133 self.assertRaises(TypeError, ss.Session, unpack='hi')
134 u = str(uuid.uuid4())
134 u = str(uuid.uuid4())
135 s = ss.Session(username=u'carrot', session=u)
135 s = ss.Session(username=u'carrot', session=u)
136 self.assertEquals(s.session, u)
136 self.assertEqual(s.session, u)
137 self.assertEquals(s.username, u'carrot')
137 self.assertEqual(s.username, u'carrot')
138
138
139 def test_tracking(self):
139 def test_tracking(self):
140 """test tracking messages"""
140 """test tracking messages"""
141 a,b = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
141 a,b = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
142 s = self.session
142 s = self.session
143 stream = ZMQStream(a)
143 stream = ZMQStream(a)
144 msg = s.send(a, 'hello', track=False)
144 msg = s.send(a, 'hello', track=False)
145 self.assertTrue(msg['tracker'] is None)
145 self.assertTrue(msg['tracker'] is None)
146 msg = s.send(a, 'hello', track=True)
146 msg = s.send(a, 'hello', track=True)
147 self.assertTrue(isinstance(msg['tracker'], zmq.MessageTracker))
147 self.assertTrue(isinstance(msg['tracker'], zmq.MessageTracker))
148 M = zmq.Message(b'hi there', track=True)
148 M = zmq.Message(b'hi there', track=True)
149 msg = s.send(a, 'hello', buffers=[M], track=True)
149 msg = s.send(a, 'hello', buffers=[M], track=True)
150 t = msg['tracker']
150 t = msg['tracker']
151 self.assertTrue(isinstance(t, zmq.MessageTracker))
151 self.assertTrue(isinstance(t, zmq.MessageTracker))
152 self.assertRaises(zmq.NotDone, t.wait, .1)
152 self.assertRaises(zmq.NotDone, t.wait, .1)
153 del M
153 del M
154 t.wait(1) # this will raise
154 t.wait(1) # this will raise
155
155
156
156
157 # def test_rekey(self):
157 # def test_rekey(self):
158 # """rekeying dict around json str keys"""
158 # """rekeying dict around json str keys"""
159 # d = {'0': uuid.uuid4(), 0:uuid.uuid4()}
159 # d = {'0': uuid.uuid4(), 0:uuid.uuid4()}
160 # self.assertRaises(KeyError, ss.rekey, d)
160 # self.assertRaises(KeyError, ss.rekey, d)
161 #
161 #
162 # d = {'0': uuid.uuid4(), 1:uuid.uuid4(), 'asdf':uuid.uuid4()}
162 # d = {'0': uuid.uuid4(), 1:uuid.uuid4(), 'asdf':uuid.uuid4()}
163 # d2 = {0:d['0'],1:d[1],'asdf':d['asdf']}
163 # d2 = {0:d['0'],1:d[1],'asdf':d['asdf']}
164 # rd = ss.rekey(d)
164 # rd = ss.rekey(d)
165 # self.assertEquals(d2,rd)
165 # self.assertEqual(d2,rd)
166 #
166 #
167 # d = {'1.5':uuid.uuid4(),'1':uuid.uuid4()}
167 # d = {'1.5':uuid.uuid4(),'1':uuid.uuid4()}
168 # d2 = {1.5:d['1.5'],1:d['1']}
168 # d2 = {1.5:d['1.5'],1:d['1']}
169 # rd = ss.rekey(d)
169 # rd = ss.rekey(d)
170 # self.assertEquals(d2,rd)
170 # self.assertEqual(d2,rd)
171 #
171 #
172 # d = {'1.0':uuid.uuid4(),'1':uuid.uuid4()}
172 # d = {'1.0':uuid.uuid4(),'1':uuid.uuid4()}
173 # self.assertRaises(KeyError, ss.rekey, d)
173 # self.assertRaises(KeyError, ss.rekey, d)
174 #
174 #
175 def test_unique_msg_ids(self):
175 def test_unique_msg_ids(self):
176 """test that messages receive unique ids"""
176 """test that messages receive unique ids"""
177 ids = set()
177 ids = set()
178 for i in range(2**12):
178 for i in range(2**12):
179 h = self.session.msg_header('test')
179 h = self.session.msg_header('test')
180 msg_id = h['msg_id']
180 msg_id = h['msg_id']
181 self.assertTrue(msg_id not in ids)
181 self.assertTrue(msg_id not in ids)
182 ids.add(msg_id)
182 ids.add(msg_id)
183
183
184 def test_feed_identities(self):
184 def test_feed_identities(self):
185 """scrub the front for zmq IDENTITIES"""
185 """scrub the front for zmq IDENTITIES"""
186 theids = "engine client other".split()
186 theids = "engine client other".split()
187 content = dict(code='whoda',stuff=object())
187 content = dict(code='whoda',stuff=object())
188 themsg = self.session.msg('execute',content=content)
188 themsg = self.session.msg('execute',content=content)
189 pmsg = theids
189 pmsg = theids
190
190
191 def test_session_id(self):
191 def test_session_id(self):
192 session = ss.Session()
192 session = ss.Session()
193 # get bs before us
193 # get bs before us
194 bs = session.bsession
194 bs = session.bsession
195 us = session.session
195 us = session.session
196 self.assertEquals(us.encode('ascii'), bs)
196 self.assertEqual(us.encode('ascii'), bs)
197 session = ss.Session()
197 session = ss.Session()
198 # get us before bs
198 # get us before bs
199 us = session.session
199 us = session.session
200 bs = session.bsession
200 bs = session.bsession
201 self.assertEquals(us.encode('ascii'), bs)
201 self.assertEqual(us.encode('ascii'), bs)
202 # change propagates:
202 # change propagates:
203 session.session = 'something else'
203 session.session = 'something else'
204 bs = session.bsession
204 bs = session.bsession
205 us = session.session
205 us = session.session
206 self.assertEquals(us.encode('ascii'), bs)
206 self.assertEqual(us.encode('ascii'), bs)
207 session = ss.Session(session='stuff')
207 session = ss.Session(session='stuff')
208 # get us before bs
208 # get us before bs
209 self.assertEquals(session.bsession, session.session.encode('ascii'))
209 self.assertEqual(session.bsession, session.session.encode('ascii'))
210 self.assertEquals(b'stuff', session.bsession)
210 self.assertEqual(b'stuff', session.bsession)
211
211
212
212
General Comments 0
You need to be logged in to leave comments. Login now