##// END OF EJS Templates
s/assertEquals/assertEqual/
Bradley M. Froehle -
Show More
@@ -1,175 +1,175 b''
1 1 """
2 2 Tests for IPython.config.application.Application
3 3
4 4 Authors:
5 5
6 6 * Brian Granger
7 7 """
8 8
9 9 #-----------------------------------------------------------------------------
10 10 # Copyright (C) 2008-2011 The IPython Development Team
11 11 #
12 12 # Distributed under the terms of the BSD License. The full license is in
13 13 # the file COPYING, distributed as part of this software.
14 14 #-----------------------------------------------------------------------------
15 15
16 16 #-----------------------------------------------------------------------------
17 17 # Imports
18 18 #-----------------------------------------------------------------------------
19 19
20 20 import logging
21 21 from unittest import TestCase
22 22
23 23 from IPython.config.configurable import Configurable
24 24 from IPython.config.loader import Config
25 25
26 26 from IPython.config.application import (
27 27 Application
28 28 )
29 29
30 30 from IPython.utils.traitlets import (
31 31 Bool, Unicode, Integer, Float, List, Dict
32 32 )
33 33
34 34 #-----------------------------------------------------------------------------
35 35 # Code
36 36 #-----------------------------------------------------------------------------
37 37
38 38 class Foo(Configurable):
39 39
40 40 i = Integer(0, config=True, help="The integer i.")
41 41 j = Integer(1, config=True, help="The integer j.")
42 42 name = Unicode(u'Brian', config=True, help="First name.")
43 43
44 44
45 45 class Bar(Configurable):
46 46
47 47 b = Integer(0, config=True, help="The integer b.")
48 48 enabled = Bool(True, config=True, help="Enable bar.")
49 49
50 50
51 51 class MyApp(Application):
52 52
53 53 name = Unicode(u'myapp')
54 54 running = Bool(False, config=True,
55 55 help="Is the app running?")
56 56 classes = List([Bar, Foo])
57 57 config_file = Unicode(u'', config=True,
58 58 help="Load this config file")
59 59
60 60 aliases = Dict({
61 61 'i' : 'Foo.i',
62 62 'j' : 'Foo.j',
63 63 'name' : 'Foo.name',
64 64 'enabled' : 'Bar.enabled',
65 65 'log-level' : 'Application.log_level',
66 66 })
67 67
68 68 flags = Dict(dict(enable=({'Bar': {'enabled' : True}}, "Set Bar.enabled to True"),
69 69 disable=({'Bar': {'enabled' : False}}, "Set Bar.enabled to False"),
70 70 crit=({'Application' : {'log_level' : logging.CRITICAL}},
71 71 "set level=CRITICAL"),
72 72 ))
73 73
74 74 def init_foo(self):
75 75 self.foo = Foo(config=self.config)
76 76
77 77 def init_bar(self):
78 78 self.bar = Bar(config=self.config)
79 79
80 80
81 81 class TestApplication(TestCase):
82 82
83 83 def test_basic(self):
84 84 app = MyApp()
85 self.assertEquals(app.name, u'myapp')
86 self.assertEquals(app.running, False)
87 self.assertEquals(app.classes, [MyApp,Bar,Foo])
88 self.assertEquals(app.config_file, u'')
85 self.assertEqual(app.name, u'myapp')
86 self.assertEqual(app.running, False)
87 self.assertEqual(app.classes, [MyApp,Bar,Foo])
88 self.assertEqual(app.config_file, u'')
89 89
90 90 def test_config(self):
91 91 app = MyApp()
92 92 app.parse_command_line(["--i=10","--Foo.j=10","--enabled=False","--log-level=50"])
93 93 config = app.config
94 self.assertEquals(config.Foo.i, 10)
95 self.assertEquals(config.Foo.j, 10)
96 self.assertEquals(config.Bar.enabled, False)
97 self.assertEquals(config.MyApp.log_level,50)
94 self.assertEqual(config.Foo.i, 10)
95 self.assertEqual(config.Foo.j, 10)
96 self.assertEqual(config.Bar.enabled, False)
97 self.assertEqual(config.MyApp.log_level,50)
98 98
99 99 def test_config_propagation(self):
100 100 app = MyApp()
101 101 app.parse_command_line(["--i=10","--Foo.j=10","--enabled=False","--log-level=50"])
102 102 app.init_foo()
103 103 app.init_bar()
104 self.assertEquals(app.foo.i, 10)
105 self.assertEquals(app.foo.j, 10)
106 self.assertEquals(app.bar.enabled, False)
104 self.assertEqual(app.foo.i, 10)
105 self.assertEqual(app.foo.j, 10)
106 self.assertEqual(app.bar.enabled, False)
107 107
108 108 def test_flags(self):
109 109 app = MyApp()
110 110 app.parse_command_line(["--disable"])
111 111 app.init_bar()
112 self.assertEquals(app.bar.enabled, False)
112 self.assertEqual(app.bar.enabled, False)
113 113 app.parse_command_line(["--enable"])
114 114 app.init_bar()
115 self.assertEquals(app.bar.enabled, True)
115 self.assertEqual(app.bar.enabled, True)
116 116
117 117 def test_aliases(self):
118 118 app = MyApp()
119 119 app.parse_command_line(["--i=5", "--j=10"])
120 120 app.init_foo()
121 self.assertEquals(app.foo.i, 5)
121 self.assertEqual(app.foo.i, 5)
122 122 app.init_foo()
123 self.assertEquals(app.foo.j, 10)
123 self.assertEqual(app.foo.j, 10)
124 124
125 125 def test_flag_clobber(self):
126 126 """test that setting flags doesn't clobber existing settings"""
127 127 app = MyApp()
128 128 app.parse_command_line(["--Bar.b=5", "--disable"])
129 129 app.init_bar()
130 self.assertEquals(app.bar.enabled, False)
131 self.assertEquals(app.bar.b, 5)
130 self.assertEqual(app.bar.enabled, False)
131 self.assertEqual(app.bar.b, 5)
132 132 app.parse_command_line(["--enable", "--Bar.b=10"])
133 133 app.init_bar()
134 self.assertEquals(app.bar.enabled, True)
135 self.assertEquals(app.bar.b, 10)
134 self.assertEqual(app.bar.enabled, True)
135 self.assertEqual(app.bar.b, 10)
136 136
137 137 def test_flatten_flags(self):
138 138 cfg = Config()
139 139 cfg.MyApp.log_level = logging.WARN
140 140 app = MyApp()
141 141 app.update_config(cfg)
142 self.assertEquals(app.log_level, logging.WARN)
143 self.assertEquals(app.config.MyApp.log_level, logging.WARN)
142 self.assertEqual(app.log_level, logging.WARN)
143 self.assertEqual(app.config.MyApp.log_level, logging.WARN)
144 144 app.initialize(["--crit"])
145 self.assertEquals(app.log_level, logging.CRITICAL)
145 self.assertEqual(app.log_level, logging.CRITICAL)
146 146 # this would be app.config.Application.log_level if it failed:
147 self.assertEquals(app.config.MyApp.log_level, logging.CRITICAL)
147 self.assertEqual(app.config.MyApp.log_level, logging.CRITICAL)
148 148
149 149 def test_flatten_aliases(self):
150 150 cfg = Config()
151 151 cfg.MyApp.log_level = logging.WARN
152 152 app = MyApp()
153 153 app.update_config(cfg)
154 self.assertEquals(app.log_level, logging.WARN)
155 self.assertEquals(app.config.MyApp.log_level, logging.WARN)
154 self.assertEqual(app.log_level, logging.WARN)
155 self.assertEqual(app.config.MyApp.log_level, logging.WARN)
156 156 app.initialize(["--log-level", "CRITICAL"])
157 self.assertEquals(app.log_level, logging.CRITICAL)
157 self.assertEqual(app.log_level, logging.CRITICAL)
158 158 # this would be app.config.Application.log_level if it failed:
159 self.assertEquals(app.config.MyApp.log_level, "CRITICAL")
159 self.assertEqual(app.config.MyApp.log_level, "CRITICAL")
160 160
161 161 def test_extra_args(self):
162 162 app = MyApp()
163 163 app.parse_command_line(["--Bar.b=5", 'extra', "--disable", 'args'])
164 164 app.init_bar()
165 self.assertEquals(app.bar.enabled, False)
166 self.assertEquals(app.bar.b, 5)
167 self.assertEquals(app.extra_args, ['extra', 'args'])
165 self.assertEqual(app.bar.enabled, False)
166 self.assertEqual(app.bar.b, 5)
167 self.assertEqual(app.extra_args, ['extra', 'args'])
168 168 app = MyApp()
169 169 app.parse_command_line(["--Bar.b=5", '--', 'extra', "--disable", 'args'])
170 170 app.init_bar()
171 self.assertEquals(app.bar.enabled, True)
172 self.assertEquals(app.bar.b, 5)
173 self.assertEquals(app.extra_args, ['extra', '--disable', 'args'])
171 self.assertEqual(app.bar.enabled, True)
172 self.assertEqual(app.bar.b, 5)
173 self.assertEqual(app.extra_args, ['extra', '--disable', 'args'])
174 174
175 175
@@ -1,183 +1,183 b''
1 1 # encoding: utf-8
2 2 """
3 3 Tests for IPython.config.configurable
4 4
5 5 Authors:
6 6
7 7 * Brian Granger
8 8 * Fernando Perez (design help)
9 9 """
10 10
11 11 #-----------------------------------------------------------------------------
12 12 # Copyright (C) 2008-2011 The IPython Development Team
13 13 #
14 14 # Distributed under the terms of the BSD License. The full license is in
15 15 # the file COPYING, distributed as part of this software.
16 16 #-----------------------------------------------------------------------------
17 17
18 18 #-----------------------------------------------------------------------------
19 19 # Imports
20 20 #-----------------------------------------------------------------------------
21 21
22 22 from unittest import TestCase
23 23
24 24 from IPython.config.configurable import (
25 25 Configurable,
26 26 SingletonConfigurable
27 27 )
28 28
29 29 from IPython.utils.traitlets import (
30 30 Integer, Float, Unicode
31 31 )
32 32
33 33 from IPython.config.loader import Config
34 34 from IPython.utils.py3compat import PY3
35 35
36 36 #-----------------------------------------------------------------------------
37 37 # Test cases
38 38 #-----------------------------------------------------------------------------
39 39
40 40
41 41 class MyConfigurable(Configurable):
42 42 a = Integer(1, config=True, help="The integer a.")
43 43 b = Float(1.0, config=True, help="The integer b.")
44 44 c = Unicode('no config')
45 45
46 46
47 47 mc_help=u"""MyConfigurable options
48 48 ----------------------
49 49 --MyConfigurable.a=<Integer>
50 50 Default: 1
51 51 The integer a.
52 52 --MyConfigurable.b=<Float>
53 53 Default: 1.0
54 54 The integer b."""
55 55
56 56 mc_help_inst=u"""MyConfigurable options
57 57 ----------------------
58 58 --MyConfigurable.a=<Integer>
59 59 Current: 5
60 60 The integer a.
61 61 --MyConfigurable.b=<Float>
62 62 Current: 4.0
63 63 The integer b."""
64 64
65 65 # On Python 3, the Integer trait is a synonym for Int
66 66 if PY3:
67 67 mc_help = mc_help.replace(u"<Integer>", u"<Int>")
68 68 mc_help_inst = mc_help_inst.replace(u"<Integer>", u"<Int>")
69 69
70 70 class Foo(Configurable):
71 71 a = Integer(0, config=True, help="The integer a.")
72 72 b = Unicode('nope', config=True)
73 73
74 74
75 75 class Bar(Foo):
76 76 b = Unicode('gotit', config=False, help="The string b.")
77 77 c = Float(config=True, help="The string c.")
78 78
79 79
80 80 class TestConfigurable(TestCase):
81 81
82 82 def test_default(self):
83 83 c1 = Configurable()
84 84 c2 = Configurable(config=c1.config)
85 85 c3 = Configurable(config=c2.config)
86 self.assertEquals(c1.config, c2.config)
87 self.assertEquals(c2.config, c3.config)
86 self.assertEqual(c1.config, c2.config)
87 self.assertEqual(c2.config, c3.config)
88 88
89 89 def test_custom(self):
90 90 config = Config()
91 91 config.foo = 'foo'
92 92 config.bar = 'bar'
93 93 c1 = Configurable(config=config)
94 94 c2 = Configurable(config=c1.config)
95 95 c3 = Configurable(config=c2.config)
96 self.assertEquals(c1.config, config)
97 self.assertEquals(c2.config, config)
98 self.assertEquals(c3.config, config)
96 self.assertEqual(c1.config, config)
97 self.assertEqual(c2.config, config)
98 self.assertEqual(c3.config, config)
99 99 # Test that copies are not made
100 100 self.assert_(c1.config is config)
101 101 self.assert_(c2.config is config)
102 102 self.assert_(c3.config is config)
103 103 self.assert_(c1.config is c2.config)
104 104 self.assert_(c2.config is c3.config)
105 105
106 106 def test_inheritance(self):
107 107 config = Config()
108 108 config.MyConfigurable.a = 2
109 109 config.MyConfigurable.b = 2.0
110 110 c1 = MyConfigurable(config=config)
111 111 c2 = MyConfigurable(config=c1.config)
112 self.assertEquals(c1.a, config.MyConfigurable.a)
113 self.assertEquals(c1.b, config.MyConfigurable.b)
114 self.assertEquals(c2.a, config.MyConfigurable.a)
115 self.assertEquals(c2.b, config.MyConfigurable.b)
112 self.assertEqual(c1.a, config.MyConfigurable.a)
113 self.assertEqual(c1.b, config.MyConfigurable.b)
114 self.assertEqual(c2.a, config.MyConfigurable.a)
115 self.assertEqual(c2.b, config.MyConfigurable.b)
116 116
117 117 def test_parent(self):
118 118 config = Config()
119 119 config.Foo.a = 10
120 120 config.Foo.b = "wow"
121 121 config.Bar.b = 'later'
122 122 config.Bar.c = 100.0
123 123 f = Foo(config=config)
124 124 b = Bar(config=f.config)
125 self.assertEquals(f.a, 10)
126 self.assertEquals(f.b, 'wow')
127 self.assertEquals(b.b, 'gotit')
128 self.assertEquals(b.c, 100.0)
125 self.assertEqual(f.a, 10)
126 self.assertEqual(f.b, 'wow')
127 self.assertEqual(b.b, 'gotit')
128 self.assertEqual(b.c, 100.0)
129 129
130 130 def test_override1(self):
131 131 config = Config()
132 132 config.MyConfigurable.a = 2
133 133 config.MyConfigurable.b = 2.0
134 134 c = MyConfigurable(a=3, config=config)
135 self.assertEquals(c.a, 3)
136 self.assertEquals(c.b, config.MyConfigurable.b)
137 self.assertEquals(c.c, 'no config')
135 self.assertEqual(c.a, 3)
136 self.assertEqual(c.b, config.MyConfigurable.b)
137 self.assertEqual(c.c, 'no config')
138 138
139 139 def test_override2(self):
140 140 config = Config()
141 141 config.Foo.a = 1
142 142 config.Bar.b = 'or' # Up above b is config=False, so this won't do it.
143 143 config.Bar.c = 10.0
144 144 c = Bar(config=config)
145 self.assertEquals(c.a, config.Foo.a)
146 self.assertEquals(c.b, 'gotit')
147 self.assertEquals(c.c, config.Bar.c)
145 self.assertEqual(c.a, config.Foo.a)
146 self.assertEqual(c.b, 'gotit')
147 self.assertEqual(c.c, config.Bar.c)
148 148 c = Bar(a=2, b='and', c=20.0, config=config)
149 self.assertEquals(c.a, 2)
150 self.assertEquals(c.b, 'and')
151 self.assertEquals(c.c, 20.0)
149 self.assertEqual(c.a, 2)
150 self.assertEqual(c.b, 'and')
151 self.assertEqual(c.c, 20.0)
152 152
153 153 def test_help(self):
154 self.assertEquals(MyConfigurable.class_get_help(), mc_help)
154 self.assertEqual(MyConfigurable.class_get_help(), mc_help)
155 155
156 156 def test_help_inst(self):
157 157 inst = MyConfigurable(a=5, b=4)
158 self.assertEquals(MyConfigurable.class_get_help(inst), mc_help_inst)
158 self.assertEqual(MyConfigurable.class_get_help(inst), mc_help_inst)
159 159
160 160
161 161 class TestSingletonConfigurable(TestCase):
162 162
163 163 def test_instance(self):
164 164 from IPython.config.configurable import SingletonConfigurable
165 165 class Foo(SingletonConfigurable): pass
166 self.assertEquals(Foo.initialized(), False)
166 self.assertEqual(Foo.initialized(), False)
167 167 foo = Foo.instance()
168 self.assertEquals(Foo.initialized(), True)
169 self.assertEquals(foo, Foo.instance())
170 self.assertEquals(SingletonConfigurable._instance, None)
168 self.assertEqual(Foo.initialized(), True)
169 self.assertEqual(foo, Foo.instance())
170 self.assertEqual(SingletonConfigurable._instance, None)
171 171
172 172 def test_inheritance(self):
173 173 class Bar(SingletonConfigurable): pass
174 174 class Bam(Bar): pass
175 self.assertEquals(Bar.initialized(), False)
176 self.assertEquals(Bam.initialized(), False)
175 self.assertEqual(Bar.initialized(), False)
176 self.assertEqual(Bam.initialized(), False)
177 177 bam = Bam.instance()
178 178 bam == Bar.instance()
179 self.assertEquals(Bar.initialized(), True)
180 self.assertEquals(Bam.initialized(), True)
181 self.assertEquals(bam, Bam._instance)
182 self.assertEquals(bam, Bar._instance)
183 self.assertEquals(SingletonConfigurable._instance, None)
179 self.assertEqual(Bar.initialized(), True)
180 self.assertEqual(Bam.initialized(), True)
181 self.assertEqual(bam, Bam._instance)
182 self.assertEqual(bam, Bar._instance)
183 self.assertEqual(SingletonConfigurable._instance, None)
@@ -1,263 +1,263 b''
1 1 # encoding: utf-8
2 2 """
3 3 Tests for IPython.config.loader
4 4
5 5 Authors:
6 6
7 7 * Brian Granger
8 8 * Fernando Perez (design help)
9 9 """
10 10
11 11 #-----------------------------------------------------------------------------
12 12 # Copyright (C) 2008-2011 The IPython Development Team
13 13 #
14 14 # Distributed under the terms of the BSD License. The full license is in
15 15 # the file COPYING, distributed as part of this software.
16 16 #-----------------------------------------------------------------------------
17 17
18 18 #-----------------------------------------------------------------------------
19 19 # Imports
20 20 #-----------------------------------------------------------------------------
21 21
22 22 import os
23 23 import sys
24 24 from tempfile import mkstemp
25 25 from unittest import TestCase
26 26
27 27 from nose import SkipTest
28 28
29 29 from IPython.testing.tools import mute_warn
30 30
31 31 from IPython.utils.traitlets import Unicode
32 32 from IPython.config.configurable import Configurable
33 33 from IPython.config.loader import (
34 34 Config,
35 35 PyFileConfigLoader,
36 36 KeyValueConfigLoader,
37 37 ArgParseConfigLoader,
38 38 KVArgParseConfigLoader,
39 39 ConfigError
40 40 )
41 41
42 42 #-----------------------------------------------------------------------------
43 43 # Actual tests
44 44 #-----------------------------------------------------------------------------
45 45
46 46
47 47 pyfile = """
48 48 c = get_config()
49 49 c.a=10
50 50 c.b=20
51 51 c.Foo.Bar.value=10
52 52 c.Foo.Bam.value=list(range(10)) # list() is just so it's the same on Python 3
53 53 c.D.C.value='hi there'
54 54 """
55 55
56 56 class TestPyFileCL(TestCase):
57 57
58 58 def test_basic(self):
59 59 fd, fname = mkstemp('.py')
60 60 f = os.fdopen(fd, 'w')
61 61 f.write(pyfile)
62 62 f.close()
63 63 # Unlink the file
64 64 cl = PyFileConfigLoader(fname)
65 65 config = cl.load_config()
66 self.assertEquals(config.a, 10)
67 self.assertEquals(config.b, 20)
68 self.assertEquals(config.Foo.Bar.value, 10)
69 self.assertEquals(config.Foo.Bam.value, range(10))
70 self.assertEquals(config.D.C.value, 'hi there')
66 self.assertEqual(config.a, 10)
67 self.assertEqual(config.b, 20)
68 self.assertEqual(config.Foo.Bar.value, 10)
69 self.assertEqual(config.Foo.Bam.value, range(10))
70 self.assertEqual(config.D.C.value, 'hi there')
71 71
72 72 class MyLoader1(ArgParseConfigLoader):
73 73 def _add_arguments(self, aliases=None, flags=None):
74 74 p = self.parser
75 75 p.add_argument('-f', '--foo', dest='Global.foo', type=str)
76 76 p.add_argument('-b', dest='MyClass.bar', type=int)
77 77 p.add_argument('-n', dest='n', action='store_true')
78 78 p.add_argument('Global.bam', type=str)
79 79
80 80 class MyLoader2(ArgParseConfigLoader):
81 81 def _add_arguments(self, aliases=None, flags=None):
82 82 subparsers = self.parser.add_subparsers(dest='subparser_name')
83 83 subparser1 = subparsers.add_parser('1')
84 84 subparser1.add_argument('-x',dest='Global.x')
85 85 subparser2 = subparsers.add_parser('2')
86 86 subparser2.add_argument('y')
87 87
88 88 class TestArgParseCL(TestCase):
89 89
90 90 def test_basic(self):
91 91 cl = MyLoader1()
92 92 config = cl.load_config('-f hi -b 10 -n wow'.split())
93 self.assertEquals(config.Global.foo, 'hi')
94 self.assertEquals(config.MyClass.bar, 10)
95 self.assertEquals(config.n, True)
96 self.assertEquals(config.Global.bam, 'wow')
93 self.assertEqual(config.Global.foo, 'hi')
94 self.assertEqual(config.MyClass.bar, 10)
95 self.assertEqual(config.n, True)
96 self.assertEqual(config.Global.bam, 'wow')
97 97 config = cl.load_config(['wow'])
98 self.assertEquals(config.keys(), ['Global'])
99 self.assertEquals(config.Global.keys(), ['bam'])
100 self.assertEquals(config.Global.bam, 'wow')
98 self.assertEqual(config.keys(), ['Global'])
99 self.assertEqual(config.Global.keys(), ['bam'])
100 self.assertEqual(config.Global.bam, 'wow')
101 101
102 102 def test_add_arguments(self):
103 103 cl = MyLoader2()
104 104 config = cl.load_config('2 frobble'.split())
105 self.assertEquals(config.subparser_name, '2')
106 self.assertEquals(config.y, 'frobble')
105 self.assertEqual(config.subparser_name, '2')
106 self.assertEqual(config.y, 'frobble')
107 107 config = cl.load_config('1 -x frobble'.split())
108 self.assertEquals(config.subparser_name, '1')
109 self.assertEquals(config.Global.x, 'frobble')
108 self.assertEqual(config.subparser_name, '1')
109 self.assertEqual(config.Global.x, 'frobble')
110 110
111 111 def test_argv(self):
112 112 cl = MyLoader1(argv='-f hi -b 10 -n wow'.split())
113 113 config = cl.load_config()
114 self.assertEquals(config.Global.foo, 'hi')
115 self.assertEquals(config.MyClass.bar, 10)
116 self.assertEquals(config.n, True)
117 self.assertEquals(config.Global.bam, 'wow')
114 self.assertEqual(config.Global.foo, 'hi')
115 self.assertEqual(config.MyClass.bar, 10)
116 self.assertEqual(config.n, True)
117 self.assertEqual(config.Global.bam, 'wow')
118 118
119 119
120 120 class TestKeyValueCL(TestCase):
121 121 klass = KeyValueConfigLoader
122 122
123 123 def test_basic(self):
124 124 cl = self.klass()
125 125 argv = ['--'+s.strip('c.') for s in pyfile.split('\n')[2:-1]]
126 126 with mute_warn():
127 127 config = cl.load_config(argv)
128 self.assertEquals(config.a, 10)
129 self.assertEquals(config.b, 20)
130 self.assertEquals(config.Foo.Bar.value, 10)
131 self.assertEquals(config.Foo.Bam.value, range(10))
132 self.assertEquals(config.D.C.value, 'hi there')
128 self.assertEqual(config.a, 10)
129 self.assertEqual(config.b, 20)
130 self.assertEqual(config.Foo.Bar.value, 10)
131 self.assertEqual(config.Foo.Bam.value, range(10))
132 self.assertEqual(config.D.C.value, 'hi there')
133 133
134 134 def test_expanduser(self):
135 135 cl = self.klass()
136 136 argv = ['--a=~/1/2/3', '--b=~', '--c=~/', '--d="~/"']
137 137 with mute_warn():
138 138 config = cl.load_config(argv)
139 self.assertEquals(config.a, os.path.expanduser('~/1/2/3'))
140 self.assertEquals(config.b, os.path.expanduser('~'))
141 self.assertEquals(config.c, os.path.expanduser('~/'))
142 self.assertEquals(config.d, '~/')
139 self.assertEqual(config.a, os.path.expanduser('~/1/2/3'))
140 self.assertEqual(config.b, os.path.expanduser('~'))
141 self.assertEqual(config.c, os.path.expanduser('~/'))
142 self.assertEqual(config.d, '~/')
143 143
144 144 def test_extra_args(self):
145 145 cl = self.klass()
146 146 with mute_warn():
147 147 config = cl.load_config(['--a=5', 'b', '--c=10', 'd'])
148 self.assertEquals(cl.extra_args, ['b', 'd'])
149 self.assertEquals(config.a, 5)
150 self.assertEquals(config.c, 10)
148 self.assertEqual(cl.extra_args, ['b', 'd'])
149 self.assertEqual(config.a, 5)
150 self.assertEqual(config.c, 10)
151 151 with mute_warn():
152 152 config = cl.load_config(['--', '--a=5', '--c=10'])
153 self.assertEquals(cl.extra_args, ['--a=5', '--c=10'])
153 self.assertEqual(cl.extra_args, ['--a=5', '--c=10'])
154 154
155 155 def test_unicode_args(self):
156 156 cl = self.klass()
157 157 argv = [u'--a=épsîlön']
158 158 with mute_warn():
159 159 config = cl.load_config(argv)
160 self.assertEquals(config.a, u'épsîlön')
160 self.assertEqual(config.a, u'épsîlön')
161 161
162 162 def test_unicode_bytes_args(self):
163 163 uarg = u'--a=é'
164 164 try:
165 165 barg = uarg.encode(sys.stdin.encoding)
166 166 except (TypeError, UnicodeEncodeError):
167 167 raise SkipTest("sys.stdin.encoding can't handle 'é'")
168 168
169 169 cl = self.klass()
170 170 with mute_warn():
171 171 config = cl.load_config([barg])
172 self.assertEquals(config.a, u'é')
172 self.assertEqual(config.a, u'é')
173 173
174 174 def test_unicode_alias(self):
175 175 cl = self.klass()
176 176 argv = [u'--a=épsîlön']
177 177 with mute_warn():
178 178 config = cl.load_config(argv, aliases=dict(a='A.a'))
179 self.assertEquals(config.A.a, u'épsîlön')
179 self.assertEqual(config.A.a, u'épsîlön')
180 180
181 181
182 182 class TestArgParseKVCL(TestKeyValueCL):
183 183 klass = KVArgParseConfigLoader
184 184
185 185 def test_expanduser2(self):
186 186 cl = self.klass()
187 187 argv = ['-a', '~/1/2/3', '--b', "'~/1/2/3'"]
188 188 with mute_warn():
189 189 config = cl.load_config(argv, aliases=dict(a='A.a', b='A.b'))
190 self.assertEquals(config.A.a, os.path.expanduser('~/1/2/3'))
191 self.assertEquals(config.A.b, '~/1/2/3')
190 self.assertEqual(config.A.a, os.path.expanduser('~/1/2/3'))
191 self.assertEqual(config.A.b, '~/1/2/3')
192 192
193 193 def test_eval(self):
194 194 cl = self.klass()
195 195 argv = ['-c', 'a=5']
196 196 with mute_warn():
197 197 config = cl.load_config(argv, aliases=dict(c='A.c'))
198 self.assertEquals(config.A.c, u"a=5")
198 self.assertEqual(config.A.c, u"a=5")
199 199
200 200
201 201 class TestConfig(TestCase):
202 202
203 203 def test_setget(self):
204 204 c = Config()
205 205 c.a = 10
206 self.assertEquals(c.a, 10)
207 self.assertEquals('b' in c, False)
206 self.assertEqual(c.a, 10)
207 self.assertEqual('b' in c, False)
208 208
209 209 def test_auto_section(self):
210 210 c = Config()
211 self.assertEquals('A' in c, True)
212 self.assertEquals(c._has_section('A'), False)
211 self.assertEqual('A' in c, True)
212 self.assertEqual(c._has_section('A'), False)
213 213 A = c.A
214 214 A.foo = 'hi there'
215 self.assertEquals(c._has_section('A'), True)
216 self.assertEquals(c.A.foo, 'hi there')
215 self.assertEqual(c._has_section('A'), True)
216 self.assertEqual(c.A.foo, 'hi there')
217 217 del c.A
218 self.assertEquals(len(c.A.keys()),0)
218 self.assertEqual(len(c.A.keys()),0)
219 219
220 220 def test_merge_doesnt_exist(self):
221 221 c1 = Config()
222 222 c2 = Config()
223 223 c2.bar = 10
224 224 c2.Foo.bar = 10
225 225 c1._merge(c2)
226 self.assertEquals(c1.Foo.bar, 10)
227 self.assertEquals(c1.bar, 10)
226 self.assertEqual(c1.Foo.bar, 10)
227 self.assertEqual(c1.bar, 10)
228 228 c2.Bar.bar = 10
229 229 c1._merge(c2)
230 self.assertEquals(c1.Bar.bar, 10)
230 self.assertEqual(c1.Bar.bar, 10)
231 231
232 232 def test_merge_exists(self):
233 233 c1 = Config()
234 234 c2 = Config()
235 235 c1.Foo.bar = 10
236 236 c1.Foo.bam = 30
237 237 c2.Foo.bar = 20
238 238 c2.Foo.wow = 40
239 239 c1._merge(c2)
240 self.assertEquals(c1.Foo.bam, 30)
241 self.assertEquals(c1.Foo.bar, 20)
242 self.assertEquals(c1.Foo.wow, 40)
240 self.assertEqual(c1.Foo.bam, 30)
241 self.assertEqual(c1.Foo.bar, 20)
242 self.assertEqual(c1.Foo.wow, 40)
243 243 c2.Foo.Bam.bam = 10
244 244 c1._merge(c2)
245 self.assertEquals(c1.Foo.Bam.bam, 10)
245 self.assertEqual(c1.Foo.Bam.bam, 10)
246 246
247 247 def test_deepcopy(self):
248 248 c1 = Config()
249 249 c1.Foo.bar = 10
250 250 c1.Foo.bam = 30
251 251 c1.a = 'asdf'
252 252 c1.b = range(10)
253 253 import copy
254 254 c2 = copy.deepcopy(c1)
255 self.assertEquals(c1, c2)
255 self.assertEqual(c1, c2)
256 256 self.assert_(c1 is not c2)
257 257 self.assert_(c1.Foo is not c2.Foo)
258 258
259 259 def test_builtin(self):
260 260 c1 = Config()
261 261 exec 'foo = True' in c1
262 self.assertEquals(c1.foo, True)
262 self.assertEqual(c1.foo, True)
263 263 self.assertRaises(ConfigError, setattr, c1, 'ValueError', 10)
@@ -1,407 +1,407 b''
1 1 # -*- coding: utf-8 -*-
2 2 """Tests for the key interactiveshell module.
3 3
4 4 Historically the main classes in interactiveshell have been under-tested. This
5 5 module should grow as many single-method tests as possible to trap many of the
6 6 recurring bugs we seem to encounter with high-level interaction.
7 7
8 8 Authors
9 9 -------
10 10 * Fernando Perez
11 11 """
12 12 #-----------------------------------------------------------------------------
13 13 # Copyright (C) 2011 The IPython Development Team
14 14 #
15 15 # Distributed under the terms of the BSD License. The full license is in
16 16 # the file COPYING, distributed as part of this software.
17 17 #-----------------------------------------------------------------------------
18 18
19 19 #-----------------------------------------------------------------------------
20 20 # Imports
21 21 #-----------------------------------------------------------------------------
22 22 # stdlib
23 23 import os
24 24 import shutil
25 25 import sys
26 26 import tempfile
27 27 import unittest
28 28 from os.path import join
29 29 from StringIO import StringIO
30 30
31 31 # third-party
32 32 import nose.tools as nt
33 33
34 34 # Our own
35 35 from IPython.testing.decorators import skipif
36 36 from IPython.utils import io
37 37
38 38 #-----------------------------------------------------------------------------
39 39 # Globals
40 40 #-----------------------------------------------------------------------------
41 41 # This is used by every single test, no point repeating it ad nauseam
42 42 ip = get_ipython()
43 43
44 44 #-----------------------------------------------------------------------------
45 45 # Tests
46 46 #-----------------------------------------------------------------------------
47 47
48 48 class InteractiveShellTestCase(unittest.TestCase):
49 49 def test_naked_string_cells(self):
50 50 """Test that cells with only naked strings are fully executed"""
51 51 # First, single-line inputs
52 52 ip.run_cell('"a"\n')
53 self.assertEquals(ip.user_ns['_'], 'a')
53 self.assertEqual(ip.user_ns['_'], 'a')
54 54 # And also multi-line cells
55 55 ip.run_cell('"""a\nb"""\n')
56 self.assertEquals(ip.user_ns['_'], 'a\nb')
56 self.assertEqual(ip.user_ns['_'], 'a\nb')
57 57
58 58 def test_run_empty_cell(self):
59 59 """Just make sure we don't get a horrible error with a blank
60 60 cell of input. Yes, I did overlook that."""
61 61 old_xc = ip.execution_count
62 62 ip.run_cell('')
63 self.assertEquals(ip.execution_count, old_xc)
63 self.assertEqual(ip.execution_count, old_xc)
64 64
65 65 def test_run_cell_multiline(self):
66 66 """Multi-block, multi-line cells must execute correctly.
67 67 """
68 68 src = '\n'.join(["x=1",
69 69 "y=2",
70 70 "if 1:",
71 71 " x += 1",
72 72 " y += 1",])
73 73 ip.run_cell(src)
74 self.assertEquals(ip.user_ns['x'], 2)
75 self.assertEquals(ip.user_ns['y'], 3)
74 self.assertEqual(ip.user_ns['x'], 2)
75 self.assertEqual(ip.user_ns['y'], 3)
76 76
77 77 def test_multiline_string_cells(self):
78 78 "Code sprinkled with multiline strings should execute (GH-306)"
79 79 ip.run_cell('tmp=0')
80 self.assertEquals(ip.user_ns['tmp'], 0)
80 self.assertEqual(ip.user_ns['tmp'], 0)
81 81 ip.run_cell('tmp=1;"""a\nb"""\n')
82 self.assertEquals(ip.user_ns['tmp'], 1)
82 self.assertEqual(ip.user_ns['tmp'], 1)
83 83
84 84 def test_dont_cache_with_semicolon(self):
85 85 "Ending a line with semicolon should not cache the returned object (GH-307)"
86 86 oldlen = len(ip.user_ns['Out'])
87 87 a = ip.run_cell('1;', store_history=True)
88 88 newlen = len(ip.user_ns['Out'])
89 self.assertEquals(oldlen, newlen)
89 self.assertEqual(oldlen, newlen)
90 90 #also test the default caching behavior
91 91 ip.run_cell('1', store_history=True)
92 92 newlen = len(ip.user_ns['Out'])
93 self.assertEquals(oldlen+1, newlen)
93 self.assertEqual(oldlen+1, newlen)
94 94
95 95 def test_In_variable(self):
96 96 "Verify that In variable grows with user input (GH-284)"
97 97 oldlen = len(ip.user_ns['In'])
98 98 ip.run_cell('1;', store_history=True)
99 99 newlen = len(ip.user_ns['In'])
100 self.assertEquals(oldlen+1, newlen)
101 self.assertEquals(ip.user_ns['In'][-1],'1;')
100 self.assertEqual(oldlen+1, newlen)
101 self.assertEqual(ip.user_ns['In'][-1],'1;')
102 102
103 103 def test_magic_names_in_string(self):
104 104 ip.run_cell('a = """\n%exit\n"""')
105 self.assertEquals(ip.user_ns['a'], '\n%exit\n')
105 self.assertEqual(ip.user_ns['a'], '\n%exit\n')
106 106
107 107 def test_alias_crash(self):
108 108 """Errors in prefilter can't crash IPython"""
109 109 ip.run_cell('%alias parts echo first %s second %s')
110 110 # capture stderr:
111 111 save_err = io.stderr
112 112 io.stderr = StringIO()
113 113 ip.run_cell('parts 1')
114 114 err = io.stderr.getvalue()
115 115 io.stderr = save_err
116 self.assertEquals(err.split(':')[0], 'ERROR')
116 self.assertEqual(err.split(':')[0], 'ERROR')
117 117
118 118 def test_trailing_newline(self):
119 119 """test that running !(command) does not raise a SyntaxError"""
120 120 ip.run_cell('!(true)\n', False)
121 121 ip.run_cell('!(true)\n\n\n', False)
122 122
123 123 def test_gh_597(self):
124 124 """Pretty-printing lists of objects with non-ascii reprs may cause
125 125 problems."""
126 126 class Spam(object):
127 127 def __repr__(self):
128 128 return "\xe9"*50
129 129 import IPython.core.formatters
130 130 f = IPython.core.formatters.PlainTextFormatter()
131 131 f([Spam(),Spam()])
132 132
133 133
134 134 def test_future_flags(self):
135 135 """Check that future flags are used for parsing code (gh-777)"""
136 136 ip.run_cell('from __future__ import print_function')
137 137 try:
138 138 ip.run_cell('prfunc_return_val = print(1,2, sep=" ")')
139 139 assert 'prfunc_return_val' in ip.user_ns
140 140 finally:
141 141 # Reset compiler flags so we don't mess up other tests.
142 142 ip.compile.reset_compiler_flags()
143 143
144 144 def test_future_unicode(self):
145 145 """Check that unicode_literals is imported from __future__ (gh #786)"""
146 146 try:
147 147 ip.run_cell(u'byte_str = "a"')
148 148 assert isinstance(ip.user_ns['byte_str'], str) # string literals are byte strings by default
149 149 ip.run_cell('from __future__ import unicode_literals')
150 150 ip.run_cell(u'unicode_str = "a"')
151 151 assert isinstance(ip.user_ns['unicode_str'], unicode) # strings literals are now unicode
152 152 finally:
153 153 # Reset compiler flags so we don't mess up other tests.
154 154 ip.compile.reset_compiler_flags()
155 155
156 156 def test_can_pickle(self):
157 157 "Can we pickle objects defined interactively (GH-29)"
158 158 ip = get_ipython()
159 159 ip.reset()
160 160 ip.run_cell(("class Mylist(list):\n"
161 161 " def __init__(self,x=[]):\n"
162 162 " list.__init__(self,x)"))
163 163 ip.run_cell("w=Mylist([1,2,3])")
164 164
165 165 from cPickle import dumps
166 166
167 167 # We need to swap in our main module - this is only necessary
168 168 # inside the test framework, because IPython puts the interactive module
169 169 # in place (but the test framework undoes this).
170 170 _main = sys.modules['__main__']
171 171 sys.modules['__main__'] = ip.user_module
172 172 try:
173 173 res = dumps(ip.user_ns["w"])
174 174 finally:
175 175 sys.modules['__main__'] = _main
176 176 self.assertTrue(isinstance(res, bytes))
177 177
178 178 def test_global_ns(self):
179 179 "Code in functions must be able to access variables outside them."
180 180 ip = get_ipython()
181 181 ip.run_cell("a = 10")
182 182 ip.run_cell(("def f(x):\n"
183 183 " return x + a"))
184 184 ip.run_cell("b = f(12)")
185 185 self.assertEqual(ip.user_ns["b"], 22)
186 186
187 187 def test_bad_custom_tb(self):
188 188 """Check that InteractiveShell is protected from bad custom exception handlers"""
189 189 from IPython.utils import io
190 190 save_stderr = io.stderr
191 191 try:
192 192 # capture stderr
193 193 io.stderr = StringIO()
194 194 ip.set_custom_exc((IOError,), lambda etype,value,tb: 1/0)
195 self.assertEquals(ip.custom_exceptions, (IOError,))
195 self.assertEqual(ip.custom_exceptions, (IOError,))
196 196 ip.run_cell(u'raise IOError("foo")')
197 self.assertEquals(ip.custom_exceptions, ())
197 self.assertEqual(ip.custom_exceptions, ())
198 198 self.assertTrue("Custom TB Handler failed" in io.stderr.getvalue())
199 199 finally:
200 200 io.stderr = save_stderr
201 201
202 202 def test_bad_custom_tb_return(self):
203 203 """Check that InteractiveShell is protected from bad return types in custom exception handlers"""
204 204 from IPython.utils import io
205 205 save_stderr = io.stderr
206 206 try:
207 207 # capture stderr
208 208 io.stderr = StringIO()
209 209 ip.set_custom_exc((NameError,),lambda etype,value,tb, tb_offset=None: 1)
210 self.assertEquals(ip.custom_exceptions, (NameError,))
210 self.assertEqual(ip.custom_exceptions, (NameError,))
211 211 ip.run_cell(u'a=abracadabra')
212 self.assertEquals(ip.custom_exceptions, ())
212 self.assertEqual(ip.custom_exceptions, ())
213 213 self.assertTrue("Custom TB Handler failed" in io.stderr.getvalue())
214 214 finally:
215 215 io.stderr = save_stderr
216 216
217 217 def test_drop_by_id(self):
218 218 myvars = {"a":object(), "b":object(), "c": object()}
219 219 ip.push(myvars, interactive=False)
220 220 for name in myvars:
221 221 assert name in ip.user_ns, name
222 222 assert name in ip.user_ns_hidden, name
223 223 ip.user_ns['b'] = 12
224 224 ip.drop_by_id(myvars)
225 225 for name in ["a", "c"]:
226 226 assert name not in ip.user_ns, name
227 227 assert name not in ip.user_ns_hidden, name
228 228 assert ip.user_ns['b'] == 12
229 229 ip.reset()
230 230
231 231 def test_var_expand(self):
232 232 ip.user_ns['f'] = u'Ca\xf1o'
233 233 self.assertEqual(ip.var_expand(u'echo $f'), u'echo Ca\xf1o')
234 234 self.assertEqual(ip.var_expand(u'echo {f}'), u'echo Ca\xf1o')
235 235 self.assertEqual(ip.var_expand(u'echo {f[:-1]}'), u'echo Ca\xf1')
236 236 self.assertEqual(ip.var_expand(u'echo {1*2}'), u'echo 2')
237 237
238 238 ip.user_ns['f'] = b'Ca\xc3\xb1o'
239 239 # This should not raise any exception:
240 240 ip.var_expand(u'echo $f')
241 241
242 242 def test_var_expand_local(self):
243 243 """Test local variable expansion in !system and %magic calls"""
244 244 # !system
245 245 ip.run_cell('def test():\n'
246 246 ' lvar = "ttt"\n'
247 247 ' ret = !echo {lvar}\n'
248 248 ' return ret[0]\n')
249 249 res = ip.user_ns['test']()
250 250 nt.assert_in('ttt', res)
251 251
252 252 # %magic
253 253 ip.run_cell('def makemacro():\n'
254 254 ' macroname = "macro_var_expand_locals"\n'
255 255 ' %macro {macroname} codestr\n')
256 256 ip.user_ns['codestr'] = "str(12)"
257 257 ip.run_cell('makemacro()')
258 258 nt.assert_in('macro_var_expand_locals', ip.user_ns)
259 259
260 260 def test_bad_var_expand(self):
261 261 """var_expand on invalid formats shouldn't raise"""
262 262 # SyntaxError
263 263 self.assertEqual(ip.var_expand(u"{'a':5}"), u"{'a':5}")
264 264 # NameError
265 265 self.assertEqual(ip.var_expand(u"{asdf}"), u"{asdf}")
266 266 # ZeroDivisionError
267 267 self.assertEqual(ip.var_expand(u"{1/0}"), u"{1/0}")
268 268
269 269 def test_silent_nopostexec(self):
270 270 """run_cell(silent=True) doesn't invoke post-exec funcs"""
271 271 d = dict(called=False)
272 272 def set_called():
273 273 d['called'] = True
274 274
275 275 ip.register_post_execute(set_called)
276 276 ip.run_cell("1", silent=True)
277 277 self.assertFalse(d['called'])
278 278 # double-check that non-silent exec did what we expected
279 279 # silent to avoid
280 280 ip.run_cell("1")
281 281 self.assertTrue(d['called'])
282 282 # remove post-exec
283 283 ip._post_execute.pop(set_called)
284 284
285 285 def test_silent_noadvance(self):
286 286 """run_cell(silent=True) doesn't advance execution_count"""
287 287 ec = ip.execution_count
288 288 # silent should force store_history=False
289 289 ip.run_cell("1", store_history=True, silent=True)
290 290
291 self.assertEquals(ec, ip.execution_count)
291 self.assertEqual(ec, ip.execution_count)
292 292 # double-check that non-silent exec did what we expected
293 293 # silent to avoid
294 294 ip.run_cell("1", store_history=True)
295 self.assertEquals(ec+1, ip.execution_count)
295 self.assertEqual(ec+1, ip.execution_count)
296 296
297 297 def test_silent_nodisplayhook(self):
298 298 """run_cell(silent=True) doesn't trigger displayhook"""
299 299 d = dict(called=False)
300 300
301 301 trap = ip.display_trap
302 302 save_hook = trap.hook
303 303
304 304 def failing_hook(*args, **kwargs):
305 305 d['called'] = True
306 306
307 307 try:
308 308 trap.hook = failing_hook
309 309 ip.run_cell("1", silent=True)
310 310 self.assertFalse(d['called'])
311 311 # double-check that non-silent exec did what we expected
312 312 # silent to avoid
313 313 ip.run_cell("1")
314 314 self.assertTrue(d['called'])
315 315 finally:
316 316 trap.hook = save_hook
317 317
318 318 @skipif(sys.version_info[0] >= 3, "softspace removed in py3")
319 319 def test_print_softspace(self):
320 320 """Verify that softspace is handled correctly when executing multiple
321 321 statements.
322 322
323 323 In [1]: print 1; print 2
324 324 1
325 325 2
326 326
327 327 In [2]: print 1,; print 2
328 328 1 2
329 329 """
330 330
331 331 def test_ofind_line_magic(self):
332 332 from IPython.core.magic import register_line_magic
333 333
334 334 @register_line_magic
335 335 def lmagic(line):
336 336 "A line magic"
337 337
338 338 # Get info on line magic
339 339 lfind = ip._ofind('lmagic')
340 340 info = dict(found=True, isalias=False, ismagic=True,
341 341 namespace = 'IPython internal', obj= lmagic.__wrapped__,
342 342 parent = None)
343 343 nt.assert_equal(lfind, info)
344 344
345 345 def test_ofind_cell_magic(self):
346 346 from IPython.core.magic import register_cell_magic
347 347
348 348 @register_cell_magic
349 349 def cmagic(line, cell):
350 350 "A cell magic"
351 351
352 352 # Get info on cell magic
353 353 find = ip._ofind('cmagic')
354 354 info = dict(found=True, isalias=False, ismagic=True,
355 355 namespace = 'IPython internal', obj= cmagic.__wrapped__,
356 356 parent = None)
357 357 nt.assert_equal(find, info)
358 358
359 359 def test_custom_exception(self):
360 360 called = []
361 361 def my_handler(shell, etype, value, tb, tb_offset=None):
362 362 called.append(etype)
363 363 shell.showtraceback((etype, value, tb), tb_offset=tb_offset)
364 364
365 365 ip.set_custom_exc((ValueError,), my_handler)
366 366 try:
367 367 ip.run_cell("raise ValueError('test')")
368 368 # Check that this was called, and only once.
369 369 self.assertEqual(called, [ValueError])
370 370 finally:
371 371 # Reset the custom exception hook
372 372 ip.set_custom_exc((), None)
373 373
374 374
375 375 class TestSafeExecfileNonAsciiPath(unittest.TestCase):
376 376
377 377 def setUp(self):
378 378 self.BASETESTDIR = tempfile.mkdtemp()
379 379 self.TESTDIR = join(self.BASETESTDIR, u"åäö")
380 380 os.mkdir(self.TESTDIR)
381 381 with open(join(self.TESTDIR, u"åäötestscript.py"), "w") as sfile:
382 382 sfile.write("pass\n")
383 383 self.oldpath = os.getcwdu()
384 384 os.chdir(self.TESTDIR)
385 385 self.fname = u"åäötestscript.py"
386 386
387 387 def tearDown(self):
388 388 os.chdir(self.oldpath)
389 389 shutil.rmtree(self.BASETESTDIR)
390 390
391 391 def test_1(self):
392 392 """Test safe_execfile with non-ascii path
393 393 """
394 394 ip.safe_execfile(self.fname, {}, raise_exceptions=True)
395 395
396 396
397 397 class TestSystemRaw(unittest.TestCase):
398 398 def test_1(self):
399 399 """Test system_raw with non-ascii cmd
400 400 """
401 401 cmd = ur'''python -c "'åäö'" '''
402 402 ip.system_raw(cmd)
403 403
404 404
405 405 def test__IPYTHON__():
406 406 # This shouldn't raise a NameError, that's all
407 407 __IPYTHON__
@@ -1,46 +1,46 b''
1 1 """Tests for plugin.py"""
2 2
3 3 #-----------------------------------------------------------------------------
4 4 # Imports
5 5 #-----------------------------------------------------------------------------
6 6
7 7 from unittest import TestCase
8 8
9 9 from IPython.core.plugin import Plugin, PluginManager
10 10
11 11 #-----------------------------------------------------------------------------
12 12 # Tests
13 13 #-----------------------------------------------------------------------------
14 14
15 15 class FooPlugin(Plugin):
16 16 pass
17 17
18 18
19 19 class BarPlugin(Plugin):
20 20 pass
21 21
22 22
23 23 class BadPlugin(object):
24 24 pass
25 25
26 26
27 27 class PluginTest(TestCase):
28 28
29 29 def setUp(self):
30 30 self.manager = PluginManager()
31 31
32 32 def test_register_get(self):
33 self.assertEquals(None, self.manager.get_plugin('foo'))
33 self.assertEqual(None, self.manager.get_plugin('foo'))
34 34 foo = FooPlugin()
35 35 self.manager.register_plugin('foo', foo)
36 self.assertEquals(foo, self.manager.get_plugin('foo'))
36 self.assertEqual(foo, self.manager.get_plugin('foo'))
37 37 bar = BarPlugin()
38 38 self.assertRaises(KeyError, self.manager.register_plugin, 'foo', bar)
39 39 bad = BadPlugin()
40 40 self.assertRaises(TypeError, self.manager.register_plugin, 'bad')
41 41
42 42 def test_unregister(self):
43 43 foo = FooPlugin()
44 44 self.manager.register_plugin('foo', foo)
45 45 self.manager.unregister_plugin('foo')
46 self.assertEquals(None, self.manager.get_plugin('foo'))
46 self.assertEqual(None, self.manager.get_plugin('foo'))
@@ -1,111 +1,111 b''
1 1 # -*- coding: utf-8
2 2 """Tests for prompt generation."""
3 3
4 4 import unittest
5 5
6 6 import os
7 7 import nose.tools as nt
8 8
9 9 from IPython.testing import tools as tt, decorators as dec
10 10 from IPython.core.prompts import PromptManager, LazyEvaluate
11 11 from IPython.testing.globalipapp import get_ipython
12 12 from IPython.utils import py3compat
13 13 from IPython.utils.tempdir import TemporaryDirectory
14 14
15 15 ip = get_ipython()
16 16
17 17
18 18 class PromptTests(unittest.TestCase):
19 19 def setUp(self):
20 20 self.pm = PromptManager(shell=ip, config=ip.config)
21 21
22 22 def test_multiline_prompt(self):
23 23 self.pm.in_template = "[In]\n>>>"
24 24 self.pm.render('in')
25 25 self.assertEqual(self.pm.width, 3)
26 26 self.assertEqual(self.pm.txtwidth, 3)
27 27
28 28 self.pm.in_template = '[In]\n'
29 29 self.pm.render('in')
30 30 self.assertEqual(self.pm.width, 0)
31 31 self.assertEqual(self.pm.txtwidth, 0)
32 32
33 33 def test_translate_abbreviations(self):
34 34 def do_translate(template):
35 35 self.pm.in_template = template
36 36 return self.pm.templates['in']
37 37
38 38 pairs = [(r'%n>', '{color.number}{count}{color.prompt}>'),
39 39 (r'\T', '{time}'),
40 40 (r'\n', '\n')
41 41 ]
42 42
43 43 tt.check_pairs(do_translate, pairs)
44 44
45 45 def test_user_ns(self):
46 46 self.pm.color_scheme = 'NoColor'
47 47 ip.ex("foo='bar'")
48 48 self.pm.in_template = "In [{foo}]"
49 49 prompt = self.pm.render('in')
50 self.assertEquals(prompt, u'In [bar]')
50 self.assertEqual(prompt, u'In [bar]')
51 51
52 52 def test_builtins(self):
53 53 self.pm.color_scheme = 'NoColor'
54 54 self.pm.in_template = "In [{int}]"
55 55 prompt = self.pm.render('in')
56 self.assertEquals(prompt, u"In [%r]" % int)
56 self.assertEqual(prompt, u"In [%r]" % int)
57 57
58 58 def test_undefined(self):
59 59 self.pm.color_scheme = 'NoColor'
60 60 self.pm.in_template = "In [{foo_dne}]"
61 61 prompt = self.pm.render('in')
62 self.assertEquals(prompt, u"In [<ERROR: 'foo_dne' not found>]")
62 self.assertEqual(prompt, u"In [<ERROR: 'foo_dne' not found>]")
63 63
64 64 def test_render(self):
65 65 self.pm.in_template = r'\#>'
66 66 self.assertEqual(self.pm.render('in',color=False), '%d>' % ip.execution_count)
67 67
68 68 def test_render_unicode_cwd(self):
69 69 save = os.getcwdu()
70 70 with TemporaryDirectory(u'ünicødé') as td:
71 71 os.chdir(td)
72 72 self.pm.in_template = r'\w [\#]'
73 73 p = self.pm.render('in', color=False)
74 self.assertEquals(p, u"%s [%i]" % (os.getcwdu(), ip.execution_count))
74 self.assertEqual(p, u"%s [%i]" % (os.getcwdu(), ip.execution_count))
75 75 os.chdir(save)
76 76
77 77 def test_lazy_eval_unicode(self):
78 78 u = u'ünicødé'
79 79 lz = LazyEvaluate(lambda : u)
80 80 # str(lz) would fail
81 self.assertEquals(unicode(lz), u)
82 self.assertEquals(format(lz), u)
81 self.assertEqual(unicode(lz), u)
82 self.assertEqual(format(lz), u)
83 83
84 84 def test_lazy_eval_nonascii_bytes(self):
85 85 u = u'ünicødé'
86 86 b = u.encode('utf8')
87 87 lz = LazyEvaluate(lambda : b)
88 88 # unicode(lz) would fail
89 self.assertEquals(str(lz), str(b))
90 self.assertEquals(format(lz), str(b))
89 self.assertEqual(str(lz), str(b))
90 self.assertEqual(format(lz), str(b))
91 91
92 92 def test_lazy_eval_float(self):
93 93 f = 0.503
94 94 lz = LazyEvaluate(lambda : f)
95 95
96 self.assertEquals(str(lz), str(f))
97 self.assertEquals(unicode(lz), unicode(f))
98 self.assertEquals(format(lz), str(f))
99 self.assertEquals(format(lz, '.1'), '0.5')
96 self.assertEqual(str(lz), str(f))
97 self.assertEqual(unicode(lz), unicode(f))
98 self.assertEqual(format(lz), str(f))
99 self.assertEqual(format(lz, '.1'), '0.5')
100 100
101 101 @dec.skip_win32
102 102 def test_cwd_x(self):
103 103 self.pm.in_template = r"\X0"
104 104 save = os.getcwdu()
105 105 os.chdir(os.path.expanduser('~'))
106 106 p = self.pm.render('in', color=False)
107 107 try:
108 self.assertEquals(p, '~')
108 self.assertEqual(p, '~')
109 109 finally:
110 110 os.chdir(save)
111 111
@@ -1,27 +1,27 b''
1 1 """Tests for the notebook kernel and session manager."""
2 2
3 3 from unittest import TestCase
4 4
5 5 from IPython.frontend.html.notebook.kernelmanager import MultiKernelManager
6 6
7 7 class TestKernelManager(TestCase):
8 8
9 9 def test_km_lifecycle(self):
10 10 km = MultiKernelManager()
11 11 kid = km.start_kernel()
12 12 self.assert_(kid in km)
13 self.assertEquals(len(km),1)
13 self.assertEqual(len(km),1)
14 14 km.kill_kernel(kid)
15 15 self.assert_(not kid in km)
16 16
17 17 kid = km.start_kernel()
18 self.assertEquals('127.0.0.1',km.get_kernel_ip(kid))
18 self.assertEqual('127.0.0.1',km.get_kernel_ip(kid))
19 19 port_dict = km.get_kernel_ports(kid)
20 20 self.assert_('stdin_port' in port_dict)
21 21 self.assert_('iopub_port' in port_dict)
22 22 self.assert_('shell_port' in port_dict)
23 23 self.assert_('hb_port' in port_dict)
24 24 km.get_kernel(kid)
25 25 km.kill_kernel(kid)
26 26
27 27
@@ -1,34 +1,34 b''
1 1 """Tests for the notebook manager."""
2 2
3 3 import os
4 4 from unittest import TestCase
5 5 from tempfile import NamedTemporaryFile
6 6
7 7 from IPython.utils.tempdir import TemporaryDirectory
8 8 from IPython.utils.traitlets import TraitError
9 9
10 10 from IPython.frontend.html.notebook.notebookmanager import NotebookManager
11 11
12 12 class TestNotebookManager(TestCase):
13 13
14 14 def test_nb_dir(self):
15 15 with TemporaryDirectory() as td:
16 16 km = NotebookManager(notebook_dir=td)
17 self.assertEquals(km.notebook_dir, td)
17 self.assertEqual(km.notebook_dir, td)
18 18
19 19 def test_create_nb_dir(self):
20 20 with TemporaryDirectory() as td:
21 21 nbdir = os.path.join(td, 'notebooks')
22 22 km = NotebookManager(notebook_dir=nbdir)
23 self.assertEquals(km.notebook_dir, nbdir)
23 self.assertEqual(km.notebook_dir, nbdir)
24 24
25 25 def test_missing_nb_dir(self):
26 26 with TemporaryDirectory() as td:
27 27 nbdir = os.path.join(td, 'notebook', 'dir', 'is', 'missing')
28 28 self.assertRaises(TraitError, NotebookManager, notebook_dir=nbdir)
29 29
30 30 def test_invalid_nb_dir(self):
31 31 with NamedTemporaryFile() as tf:
32 32 self.assertRaises(TraitError, NotebookManager, notebook_dir=tf.name)
33 33
34 34
@@ -1,171 +1,171 b''
1 1 # Standard library imports
2 2 import unittest
3 3
4 4 # Local imports
5 5 from IPython.frontend.qt.console.ansi_code_processor import AnsiCodeProcessor
6 6
7 7
8 8 class TestAnsiCodeProcessor(unittest.TestCase):
9 9
10 10 def setUp(self):
11 11 self.processor = AnsiCodeProcessor()
12 12
13 13 def test_clear(self):
14 14 """ Do control sequences for clearing the console work?
15 15 """
16 16 string = '\x1b[2J\x1b[K'
17 17 i = -1
18 18 for i, substring in enumerate(self.processor.split_string(string)):
19 19 if i == 0:
20 self.assertEquals(len(self.processor.actions), 1)
20 self.assertEqual(len(self.processor.actions), 1)
21 21 action = self.processor.actions[0]
22 self.assertEquals(action.action, 'erase')
23 self.assertEquals(action.area, 'screen')
24 self.assertEquals(action.erase_to, 'all')
22 self.assertEqual(action.action, 'erase')
23 self.assertEqual(action.area, 'screen')
24 self.assertEqual(action.erase_to, 'all')
25 25 elif i == 1:
26 self.assertEquals(len(self.processor.actions), 1)
26 self.assertEqual(len(self.processor.actions), 1)
27 27 action = self.processor.actions[0]
28 self.assertEquals(action.action, 'erase')
29 self.assertEquals(action.area, 'line')
30 self.assertEquals(action.erase_to, 'end')
28 self.assertEqual(action.action, 'erase')
29 self.assertEqual(action.area, 'line')
30 self.assertEqual(action.erase_to, 'end')
31 31 else:
32 32 self.fail('Too many substrings.')
33 self.assertEquals(i, 1, 'Too few substrings.')
33 self.assertEqual(i, 1, 'Too few substrings.')
34 34
35 35 def test_colors(self):
36 36 """ Do basic controls sequences for colors work?
37 37 """
38 38 string = 'first\x1b[34mblue\x1b[0mlast'
39 39 i = -1
40 40 for i, substring in enumerate(self.processor.split_string(string)):
41 41 if i == 0:
42 self.assertEquals(substring, 'first')
43 self.assertEquals(self.processor.foreground_color, None)
42 self.assertEqual(substring, 'first')
43 self.assertEqual(self.processor.foreground_color, None)
44 44 elif i == 1:
45 self.assertEquals(substring, 'blue')
46 self.assertEquals(self.processor.foreground_color, 4)
45 self.assertEqual(substring, 'blue')
46 self.assertEqual(self.processor.foreground_color, 4)
47 47 elif i == 2:
48 self.assertEquals(substring, 'last')
49 self.assertEquals(self.processor.foreground_color, None)
48 self.assertEqual(substring, 'last')
49 self.assertEqual(self.processor.foreground_color, None)
50 50 else:
51 51 self.fail('Too many substrings.')
52 self.assertEquals(i, 2, 'Too few substrings.')
52 self.assertEqual(i, 2, 'Too few substrings.')
53 53
54 54 def test_colors_xterm(self):
55 55 """ Do xterm-specific control sequences for colors work?
56 56 """
57 57 string = '\x1b]4;20;rgb:ff/ff/ff\x1b' \
58 58 '\x1b]4;25;rgbi:1.0/1.0/1.0\x1b'
59 59 substrings = list(self.processor.split_string(string))
60 60 desired = { 20 : (255, 255, 255),
61 61 25 : (255, 255, 255) }
62 self.assertEquals(self.processor.color_map, desired)
62 self.assertEqual(self.processor.color_map, desired)
63 63
64 64 string = '\x1b[38;5;20m\x1b[48;5;25m'
65 65 substrings = list(self.processor.split_string(string))
66 self.assertEquals(self.processor.foreground_color, 20)
67 self.assertEquals(self.processor.background_color, 25)
66 self.assertEqual(self.processor.foreground_color, 20)
67 self.assertEqual(self.processor.background_color, 25)
68 68
69 69 def test_scroll(self):
70 70 """ Do control sequences for scrolling the buffer work?
71 71 """
72 72 string = '\x1b[5S\x1b[T'
73 73 i = -1
74 74 for i, substring in enumerate(self.processor.split_string(string)):
75 75 if i == 0:
76 self.assertEquals(len(self.processor.actions), 1)
76 self.assertEqual(len(self.processor.actions), 1)
77 77 action = self.processor.actions[0]
78 self.assertEquals(action.action, 'scroll')
79 self.assertEquals(action.dir, 'up')
80 self.assertEquals(action.unit, 'line')
81 self.assertEquals(action.count, 5)
78 self.assertEqual(action.action, 'scroll')
79 self.assertEqual(action.dir, 'up')
80 self.assertEqual(action.unit, 'line')
81 self.assertEqual(action.count, 5)
82 82 elif i == 1:
83 self.assertEquals(len(self.processor.actions), 1)
83 self.assertEqual(len(self.processor.actions), 1)
84 84 action = self.processor.actions[0]
85 self.assertEquals(action.action, 'scroll')
86 self.assertEquals(action.dir, 'down')
87 self.assertEquals(action.unit, 'line')
88 self.assertEquals(action.count, 1)
85 self.assertEqual(action.action, 'scroll')
86 self.assertEqual(action.dir, 'down')
87 self.assertEqual(action.unit, 'line')
88 self.assertEqual(action.count, 1)
89 89 else:
90 90 self.fail('Too many substrings.')
91 self.assertEquals(i, 1, 'Too few substrings.')
91 self.assertEqual(i, 1, 'Too few substrings.')
92 92
93 93 def test_formfeed(self):
94 94 """ Are formfeed characters processed correctly?
95 95 """
96 96 string = '\f' # form feed
97 self.assertEquals(list(self.processor.split_string(string)), [''])
98 self.assertEquals(len(self.processor.actions), 1)
97 self.assertEqual(list(self.processor.split_string(string)), [''])
98 self.assertEqual(len(self.processor.actions), 1)
99 99 action = self.processor.actions[0]
100 self.assertEquals(action.action, 'scroll')
101 self.assertEquals(action.dir, 'down')
102 self.assertEquals(action.unit, 'page')
103 self.assertEquals(action.count, 1)
100 self.assertEqual(action.action, 'scroll')
101 self.assertEqual(action.dir, 'down')
102 self.assertEqual(action.unit, 'page')
103 self.assertEqual(action.count, 1)
104 104
105 105 def test_carriage_return(self):
106 106 """ Are carriage return characters processed correctly?
107 107 """
108 108 string = 'foo\rbar' # carriage return
109 109 splits = []
110 110 actions = []
111 111 for split in self.processor.split_string(string):
112 112 splits.append(split)
113 113 actions.append([action.action for action in self.processor.actions])
114 self.assertEquals(splits, ['foo', None, 'bar'])
115 self.assertEquals(actions, [[], ['carriage-return'], []])
114 self.assertEqual(splits, ['foo', None, 'bar'])
115 self.assertEqual(actions, [[], ['carriage-return'], []])
116 116
117 117 def test_carriage_return_newline(self):
118 118 """transform CRLF to LF"""
119 119 string = 'foo\rbar\r\ncat\r\n\n' # carriage return and newline
120 120 # only one CR action should occur, and '\r\n' should transform to '\n'
121 121 splits = []
122 122 actions = []
123 123 for split in self.processor.split_string(string):
124 124 splits.append(split)
125 125 actions.append([action.action for action in self.processor.actions])
126 self.assertEquals(splits, ['foo', None, 'bar', '\r\n', 'cat', '\r\n', '\n'])
127 self.assertEquals(actions, [[], ['carriage-return'], [], ['newline'], [], ['newline'], ['newline']])
126 self.assertEqual(splits, ['foo', None, 'bar', '\r\n', 'cat', '\r\n', '\n'])
127 self.assertEqual(actions, [[], ['carriage-return'], [], ['newline'], [], ['newline'], ['newline']])
128 128
129 129 def test_beep(self):
130 130 """ Are beep characters processed correctly?
131 131 """
132 132 string = 'foo\abar' # bell
133 133 splits = []
134 134 actions = []
135 135 for split in self.processor.split_string(string):
136 136 splits.append(split)
137 137 actions.append([action.action for action in self.processor.actions])
138 self.assertEquals(splits, ['foo', None, 'bar'])
139 self.assertEquals(actions, [[], ['beep'], []])
138 self.assertEqual(splits, ['foo', None, 'bar'])
139 self.assertEqual(actions, [[], ['beep'], []])
140 140
141 141 def test_backspace(self):
142 142 """ Are backspace characters processed correctly?
143 143 """
144 144 string = 'foo\bbar' # backspace
145 145 splits = []
146 146 actions = []
147 147 for split in self.processor.split_string(string):
148 148 splits.append(split)
149 149 actions.append([action.action for action in self.processor.actions])
150 self.assertEquals(splits, ['foo', None, 'bar'])
151 self.assertEquals(actions, [[], ['backspace'], []])
150 self.assertEqual(splits, ['foo', None, 'bar'])
151 self.assertEqual(actions, [[], ['backspace'], []])
152 152
153 153 def test_combined(self):
154 154 """ Are CR and BS characters processed correctly in combination?
155 155
156 156 BS is treated as a change in print position, rather than a
157 157 backwards character deletion. Therefore a BS at EOL is
158 158 effectively ignored.
159 159 """
160 160 string = 'abc\rdef\b' # CR and backspace
161 161 splits = []
162 162 actions = []
163 163 for split in self.processor.split_string(string):
164 164 splits.append(split)
165 165 actions.append([action.action for action in self.processor.actions])
166 self.assertEquals(splits, ['abc', None, 'def', None])
167 self.assertEquals(actions, [[], ['carriage-return'], [], ['backspace']])
166 self.assertEqual(splits, ['abc', None, 'def', None])
167 self.assertEqual(actions, [[], ['carriage-return'], [], ['backspace']])
168 168
169 169
170 170 if __name__ == '__main__':
171 171 unittest.main()
@@ -1,47 +1,47 b''
1 1 # Standard library imports
2 2 import unittest
3 3
4 4 # System library imports
5 5 from pygments.lexers import CLexer, CppLexer, PythonLexer
6 6
7 7 # Local imports
8 8 from IPython.frontend.qt.console.completion_lexer import CompletionLexer
9 9
10 10
11 11 class TestCompletionLexer(unittest.TestCase):
12 12
13 13 def testPython(self):
14 14 """ Does the CompletionLexer work for Python?
15 15 """
16 16 lexer = CompletionLexer(PythonLexer())
17 17
18 18 # Test simplest case.
19 self.assertEquals(lexer.get_context("foo.bar.baz"),
19 self.assertEqual(lexer.get_context("foo.bar.baz"),
20 20 [ "foo", "bar", "baz" ])
21 21
22 22 # Test trailing period.
23 self.assertEquals(lexer.get_context("foo.bar."), [ "foo", "bar", "" ])
23 self.assertEqual(lexer.get_context("foo.bar."), [ "foo", "bar", "" ])
24 24
25 25 # Test with prompt present.
26 self.assertEquals(lexer.get_context(">>> foo.bar.baz"),
26 self.assertEqual(lexer.get_context(">>> foo.bar.baz"),
27 27 [ "foo", "bar", "baz" ])
28 28
29 29 # Test spacing in name.
30 self.assertEquals(lexer.get_context("foo.bar. baz"), [ "baz" ])
30 self.assertEqual(lexer.get_context("foo.bar. baz"), [ "baz" ])
31 31
32 32 # Test parenthesis.
33 self.assertEquals(lexer.get_context("foo("), [])
33 self.assertEqual(lexer.get_context("foo("), [])
34 34
35 35 def testC(self):
36 36 """ Does the CompletionLexer work for C/C++?
37 37 """
38 38 lexer = CompletionLexer(CLexer())
39 self.assertEquals(lexer.get_context("foo.bar"), [ "foo", "bar" ])
40 self.assertEquals(lexer.get_context("foo->bar"), [ "foo", "bar" ])
39 self.assertEqual(lexer.get_context("foo.bar"), [ "foo", "bar" ])
40 self.assertEqual(lexer.get_context("foo->bar"), [ "foo", "bar" ])
41 41
42 42 lexer = CompletionLexer(CppLexer())
43 self.assertEquals(lexer.get_context("Foo::Bar"), [ "Foo", "Bar" ])
43 self.assertEqual(lexer.get_context("Foo::Bar"), [ "Foo", "Bar" ])
44 44
45 45
46 46 if __name__ == '__main__':
47 47 unittest.main()
@@ -1,42 +1,42 b''
1 1 # Standard library imports
2 2 import unittest
3 3
4 4 # System library imports
5 5 from IPython.external.qt import QtGui
6 6
7 7 # Local imports
8 8 from IPython.frontend.qt.console.console_widget import ConsoleWidget
9 9
10 10
11 11 class TestConsoleWidget(unittest.TestCase):
12 12
13 13 @classmethod
14 14 def setUpClass(cls):
15 15 """ Create the application for the test case.
16 16 """
17 17 cls._app = QtGui.QApplication.instance()
18 18 if cls._app is None:
19 19 cls._app = QtGui.QApplication([])
20 20 cls._app.setQuitOnLastWindowClosed(False)
21 21
22 22 @classmethod
23 23 def tearDownClass(cls):
24 24 """ Exit the application.
25 25 """
26 26 QtGui.QApplication.quit()
27 27
28 28 def test_special_characters(self):
29 29 """ Are special characters displayed correctly?
30 30 """
31 31 w = ConsoleWidget()
32 32 cursor = w._get_prompt_cursor()
33 33
34 34 test_inputs = ['xyz\b\b=\n', 'foo\b\nbar\n', 'foo\b\nbar\r\n', 'abc\rxyz\b\b=']
35 35 expected_outputs = [u'x=z\u2029', u'foo\u2029bar\u2029', u'foo\u2029bar\u2029', 'x=z']
36 36 for i, text in enumerate(test_inputs):
37 37 w._insert_plain_text(cursor, text)
38 38 cursor.select(cursor.Document)
39 39 selection = cursor.selectedText()
40 self.assertEquals(expected_outputs[i], selection)
40 self.assertEqual(expected_outputs[i], selection)
41 41 # clear all the text
42 42 cursor.insertText('')
@@ -1,171 +1,171 b''
1 1 # -*- coding: utf-8 -*-
2 2 """Tests for the key interactiveshell module.
3 3
4 4 Authors
5 5 -------
6 6 * Julian Taylor
7 7 """
8 8 #-----------------------------------------------------------------------------
9 9 # Copyright (C) 2011 The IPython Development Team
10 10 #
11 11 # Distributed under the terms of the BSD License. The full license is in
12 12 # the file COPYING, distributed as part of this software.
13 13 #-----------------------------------------------------------------------------
14 14
15 15 #-----------------------------------------------------------------------------
16 16 # Imports
17 17 #-----------------------------------------------------------------------------
18 18 # stdlib
19 19 import sys
20 20 import unittest
21 21
22 22 from IPython.testing.decorators import skipif
23 23 from IPython.utils import py3compat
24 24
25 25 class InteractiveShellTestCase(unittest.TestCase):
26 26 def rl_hist_entries(self, rl, n):
27 27 """Get last n readline history entries as a list"""
28 28 return [rl.get_history_item(rl.get_current_history_length() - x)
29 29 for x in range(n - 1, -1, -1)]
30 30
31 31 def test_runs_without_rl(self):
32 32 """Test that function does not throw without readline"""
33 33 ip = get_ipython()
34 34 ip.has_readline = False
35 35 ip.readline = None
36 36 ip._replace_rlhist_multiline(u'source', 0)
37 37
38 38 @skipif(not get_ipython().has_readline, 'no readline')
39 39 def test_runs_without_remove_history_item(self):
40 40 """Test that function does not throw on windows without
41 41 remove_history_item"""
42 42 ip = get_ipython()
43 43 if hasattr(ip.readline, 'remove_history_item'):
44 44 del ip.readline.remove_history_item
45 45 ip._replace_rlhist_multiline(u'source', 0)
46 46
47 47 @skipif(not get_ipython().has_readline, 'no readline')
48 48 @skipif(not hasattr(get_ipython().readline, 'remove_history_item'),
49 49 'no remove_history_item')
50 50 def test_replace_multiline_hist_disabled(self):
51 51 """Test that multiline replace does nothing if disabled"""
52 52 ip = get_ipython()
53 53 ip.multiline_history = False
54 54
55 55 ghist = [u'line1', u'line2']
56 56 for h in ghist:
57 57 ip.readline.add_history(h)
58 58 hlen_b4_cell = ip.readline.get_current_history_length()
59 59 hlen_b4_cell = ip._replace_rlhist_multiline(u'sourc€\nsource2',
60 60 hlen_b4_cell)
61 61
62 self.assertEquals(ip.readline.get_current_history_length(),
62 self.assertEqual(ip.readline.get_current_history_length(),
63 63 hlen_b4_cell)
64 64 hist = self.rl_hist_entries(ip.readline, 2)
65 self.assertEquals(hist, ghist)
65 self.assertEqual(hist, ghist)
66 66
67 67 @skipif(not get_ipython().has_readline, 'no readline')
68 68 @skipif(not hasattr(get_ipython().readline, 'remove_history_item'),
69 69 'no remove_history_item')
70 70 def test_replace_multiline_hist_adds(self):
71 71 """Test that multiline replace function adds history"""
72 72 ip = get_ipython()
73 73
74 74 hlen_b4_cell = ip.readline.get_current_history_length()
75 75 hlen_b4_cell = ip._replace_rlhist_multiline(u'sourc€', hlen_b4_cell)
76 76
77 self.assertEquals(hlen_b4_cell,
77 self.assertEqual(hlen_b4_cell,
78 78 ip.readline.get_current_history_length())
79 79
80 80 @skipif(not get_ipython().has_readline, 'no readline')
81 81 @skipif(not hasattr(get_ipython().readline, 'remove_history_item'),
82 82 'no remove_history_item')
83 83 def test_replace_multiline_hist_keeps_history(self):
84 84 """Test that multiline replace does not delete history"""
85 85 ip = get_ipython()
86 86 ip.multiline_history = True
87 87
88 88 ghist = [u'line1', u'line2']
89 89 for h in ghist:
90 90 ip.readline.add_history(h)
91 91
92 92 #start cell
93 93 hlen_b4_cell = ip.readline.get_current_history_length()
94 94 # nothing added to rl history, should do nothing
95 95 hlen_b4_cell = ip._replace_rlhist_multiline(u'sourc€\nsource2',
96 96 hlen_b4_cell)
97 97
98 self.assertEquals(ip.readline.get_current_history_length(),
98 self.assertEqual(ip.readline.get_current_history_length(),
99 99 hlen_b4_cell)
100 100 hist = self.rl_hist_entries(ip.readline, 2)
101 self.assertEquals(hist, ghist)
101 self.assertEqual(hist, ghist)
102 102
103 103
104 104 @skipif(not get_ipython().has_readline, 'no readline')
105 105 @skipif(not hasattr(get_ipython().readline, 'remove_history_item'),
106 106 'no remove_history_item')
107 107 def test_replace_multiline_hist_replaces_twice(self):
108 108 """Test that multiline entries are replaced twice"""
109 109 ip = get_ipython()
110 110 ip.multiline_history = True
111 111
112 112 ip.readline.add_history(u'line0')
113 113 #start cell
114 114 hlen_b4_cell = ip.readline.get_current_history_length()
115 115 ip.readline.add_history('l€ne1')
116 116 ip.readline.add_history('line2')
117 117 #replace cell with single line
118 118 hlen_b4_cell = ip._replace_rlhist_multiline(u'l€ne1\nline2',
119 119 hlen_b4_cell)
120 120 ip.readline.add_history('l€ne3')
121 121 ip.readline.add_history('line4')
122 122 #replace cell with single line
123 123 hlen_b4_cell = ip._replace_rlhist_multiline(u'l€ne3\nline4',
124 124 hlen_b4_cell)
125 125
126 self.assertEquals(ip.readline.get_current_history_length(),
126 self.assertEqual(ip.readline.get_current_history_length(),
127 127 hlen_b4_cell)
128 128 hist = self.rl_hist_entries(ip.readline, 3)
129 129 expected = [u'line0', u'l€ne1\nline2', u'l€ne3\nline4']
130 130 # perform encoding, in case of casting due to ASCII locale
131 131 enc = sys.stdin.encoding or "utf-8"
132 132 expected = [ py3compat.unicode_to_str(e, enc) for e in expected ]
133 self.assertEquals(hist, expected)
133 self.assertEqual(hist, expected)
134 134
135 135
136 136 @skipif(not get_ipython().has_readline, 'no readline')
137 137 @skipif(not hasattr(get_ipython().readline, 'remove_history_item'),
138 138 'no remove_history_item')
139 139 def test_replace_multiline_hist_replaces_empty_line(self):
140 140 """Test that multiline history skips empty line cells"""
141 141 ip = get_ipython()
142 142 ip.multiline_history = True
143 143
144 144 ip.readline.add_history(u'line0')
145 145 #start cell
146 146 hlen_b4_cell = ip.readline.get_current_history_length()
147 147 ip.readline.add_history('l€ne1')
148 148 ip.readline.add_history('line2')
149 149 hlen_b4_cell = ip._replace_rlhist_multiline(u'l€ne1\nline2',
150 150 hlen_b4_cell)
151 151 ip.readline.add_history('')
152 152 hlen_b4_cell = ip._replace_rlhist_multiline(u'', hlen_b4_cell)
153 153 ip.readline.add_history('l€ne3')
154 154 hlen_b4_cell = ip._replace_rlhist_multiline(u'l€ne3', hlen_b4_cell)
155 155 ip.readline.add_history(' ')
156 156 hlen_b4_cell = ip._replace_rlhist_multiline(' ', hlen_b4_cell)
157 157 ip.readline.add_history('\t')
158 158 ip.readline.add_history('\t ')
159 159 hlen_b4_cell = ip._replace_rlhist_multiline('\t', hlen_b4_cell)
160 160 ip.readline.add_history('line4')
161 161 hlen_b4_cell = ip._replace_rlhist_multiline(u'line4', hlen_b4_cell)
162 162
163 self.assertEquals(ip.readline.get_current_history_length(),
163 self.assertEqual(ip.readline.get_current_history_length(),
164 164 hlen_b4_cell)
165 165 hist = self.rl_hist_entries(ip.readline, 4)
166 166 # expect no empty cells in history
167 167 expected = [u'line0', u'l€ne1\nline2', u'l€ne3', u'line4']
168 168 # perform encoding, in case of casting due to ASCII locale
169 169 enc = sys.stdin.encoding or "utf-8"
170 170 expected = [ py3compat.unicode_to_str(e, enc) for e in expected ]
171 self.assertEquals(hist, expected)
171 self.assertEqual(hist, expected)
@@ -1,14 +1,14 b''
1 1 from unittest import TestCase
2 2
3 3 from ..nbjson import reads, writes
4 4 from .nbexamples import nb0
5 5
6 6
7 7 class TestJSON(TestCase):
8 8
9 9 def test_roundtrip(self):
10 10 s = writes(nb0)
11 self.assertEquals(reads(s),nb0)
11 self.assertEqual(reads(s),nb0)
12 12
13 13
14 14
@@ -1,41 +1,41 b''
1 1 from unittest import TestCase
2 2
3 3 from ..nbbase import (
4 4 NotebookNode,
5 5 new_code_cell, new_text_cell, new_notebook
6 6 )
7 7
8 8 class TestCell(TestCase):
9 9
10 10 def test_empty_code_cell(self):
11 11 cc = new_code_cell()
12 self.assertEquals(cc.cell_type,'code')
13 self.assertEquals('code' not in cc, True)
14 self.assertEquals('prompt_number' not in cc, True)
12 self.assertEqual(cc.cell_type,'code')
13 self.assertEqual('code' not in cc, True)
14 self.assertEqual('prompt_number' not in cc, True)
15 15
16 16 def test_code_cell(self):
17 17 cc = new_code_cell(code='a=10', prompt_number=0)
18 self.assertEquals(cc.code, u'a=10')
19 self.assertEquals(cc.prompt_number, 0)
18 self.assertEqual(cc.code, u'a=10')
19 self.assertEqual(cc.prompt_number, 0)
20 20
21 21 def test_empty_text_cell(self):
22 22 tc = new_text_cell()
23 self.assertEquals(tc.cell_type, 'text')
24 self.assertEquals('text' not in tc, True)
23 self.assertEqual(tc.cell_type, 'text')
24 self.assertEqual('text' not in tc, True)
25 25
26 26 def test_text_cell(self):
27 27 tc = new_text_cell('hi')
28 self.assertEquals(tc.text, u'hi')
28 self.assertEqual(tc.text, u'hi')
29 29
30 30
31 31 class TestNotebook(TestCase):
32 32
33 33 def test_empty_notebook(self):
34 34 nb = new_notebook()
35 self.assertEquals(nb.cells, [])
35 self.assertEqual(nb.cells, [])
36 36
37 37 def test_notebooke(self):
38 38 cells = [new_code_cell(),new_text_cell()]
39 39 nb = new_notebook(cells=cells)
40 self.assertEquals(nb.cells,cells)
40 self.assertEqual(nb.cells,cells)
41 41
@@ -1,34 +1,34 b''
1 1 import pprint
2 2 from unittest import TestCase
3 3
4 4 from ..nbjson import reads, writes
5 5 from .nbexamples import nb0
6 6
7 7
8 8 class TestJSON(TestCase):
9 9
10 10 def test_roundtrip(self):
11 11 s = writes(nb0)
12 12 # print
13 13 # print pprint.pformat(nb0,indent=2)
14 14 # print
15 15 # print pprint.pformat(reads(s),indent=2)
16 16 # print
17 17 # print s
18 self.assertEquals(reads(s),nb0)
18 self.assertEqual(reads(s),nb0)
19 19
20 20 def test_roundtrip_nosplit(self):
21 21 """Ensure that multiline blobs are still readable"""
22 22 # ensures that notebooks written prior to splitlines change
23 23 # are still readable.
24 24 s = writes(nb0, split_lines=False)
25 self.assertEquals(reads(s),nb0)
25 self.assertEqual(reads(s),nb0)
26 26
27 27 def test_roundtrip_split(self):
28 28 """Ensure that splitting multiline blocks is safe"""
29 29 # This won't differ from test_roundtrip unless the default changes
30 30 s = writes(nb0, split_lines=True)
31 self.assertEquals(reads(s),nb0)
31 self.assertEqual(reads(s),nb0)
32 32
33 33
34 34
@@ -1,113 +1,113 b''
1 1 from unittest import TestCase
2 2
3 3 from ..nbbase import (
4 4 NotebookNode,
5 5 new_code_cell, new_text_cell, new_worksheet, new_notebook, new_output,
6 6 new_author, new_metadata
7 7 )
8 8
9 9 class TestCell(TestCase):
10 10
11 11 def test_empty_code_cell(self):
12 12 cc = new_code_cell()
13 self.assertEquals(cc.cell_type,u'code')
14 self.assertEquals(u'input' not in cc, True)
15 self.assertEquals(u'prompt_number' not in cc, True)
16 self.assertEquals(cc.outputs, [])
17 self.assertEquals(cc.collapsed, False)
13 self.assertEqual(cc.cell_type,u'code')
14 self.assertEqual(u'input' not in cc, True)
15 self.assertEqual(u'prompt_number' not in cc, True)
16 self.assertEqual(cc.outputs, [])
17 self.assertEqual(cc.collapsed, False)
18 18
19 19 def test_code_cell(self):
20 20 cc = new_code_cell(input='a=10', prompt_number=0, collapsed=True)
21 21 cc.outputs = [new_output(output_type=u'pyout',
22 22 output_svg=u'foo',output_text=u'10',prompt_number=0)]
23 self.assertEquals(cc.input, u'a=10')
24 self.assertEquals(cc.prompt_number, 0)
25 self.assertEquals(cc.language, u'python')
26 self.assertEquals(cc.outputs[0].svg, u'foo')
27 self.assertEquals(cc.outputs[0].text, u'10')
28 self.assertEquals(cc.outputs[0].prompt_number, 0)
29 self.assertEquals(cc.collapsed, True)
23 self.assertEqual(cc.input, u'a=10')
24 self.assertEqual(cc.prompt_number, 0)
25 self.assertEqual(cc.language, u'python')
26 self.assertEqual(cc.outputs[0].svg, u'foo')
27 self.assertEqual(cc.outputs[0].text, u'10')
28 self.assertEqual(cc.outputs[0].prompt_number, 0)
29 self.assertEqual(cc.collapsed, True)
30 30
31 31 def test_pyerr(self):
32 32 o = new_output(output_type=u'pyerr', etype=u'NameError',
33 33 evalue=u'Name not found', traceback=[u'frame 0', u'frame 1', u'frame 2']
34 34 )
35 self.assertEquals(o.output_type, u'pyerr')
36 self.assertEquals(o.etype, u'NameError')
37 self.assertEquals(o.evalue, u'Name not found')
38 self.assertEquals(o.traceback, [u'frame 0', u'frame 1', u'frame 2'])
35 self.assertEqual(o.output_type, u'pyerr')
36 self.assertEqual(o.etype, u'NameError')
37 self.assertEqual(o.evalue, u'Name not found')
38 self.assertEqual(o.traceback, [u'frame 0', u'frame 1', u'frame 2'])
39 39
40 40 def test_empty_html_cell(self):
41 41 tc = new_text_cell(u'html')
42 self.assertEquals(tc.cell_type, u'html')
43 self.assertEquals(u'source' not in tc, True)
44 self.assertEquals(u'rendered' not in tc, True)
42 self.assertEqual(tc.cell_type, u'html')
43 self.assertEqual(u'source' not in tc, True)
44 self.assertEqual(u'rendered' not in tc, True)
45 45
46 46 def test_html_cell(self):
47 47 tc = new_text_cell(u'html', 'hi', 'hi')
48 self.assertEquals(tc.source, u'hi')
49 self.assertEquals(tc.rendered, u'hi')
48 self.assertEqual(tc.source, u'hi')
49 self.assertEqual(tc.rendered, u'hi')
50 50
51 51 def test_empty_markdown_cell(self):
52 52 tc = new_text_cell(u'markdown')
53 self.assertEquals(tc.cell_type, u'markdown')
54 self.assertEquals(u'source' not in tc, True)
55 self.assertEquals(u'rendered' not in tc, True)
53 self.assertEqual(tc.cell_type, u'markdown')
54 self.assertEqual(u'source' not in tc, True)
55 self.assertEqual(u'rendered' not in tc, True)
56 56
57 57 def test_markdown_cell(self):
58 58 tc = new_text_cell(u'markdown', 'hi', 'hi')
59 self.assertEquals(tc.source, u'hi')
60 self.assertEquals(tc.rendered, u'hi')
59 self.assertEqual(tc.source, u'hi')
60 self.assertEqual(tc.rendered, u'hi')
61 61
62 62
63 63 class TestWorksheet(TestCase):
64 64
65 65 def test_empty_worksheet(self):
66 66 ws = new_worksheet()
67 self.assertEquals(ws.cells,[])
68 self.assertEquals(u'name' not in ws, True)
67 self.assertEqual(ws.cells,[])
68 self.assertEqual(u'name' not in ws, True)
69 69
70 70 def test_worksheet(self):
71 71 cells = [new_code_cell(), new_text_cell(u'html')]
72 72 ws = new_worksheet(cells=cells,name=u'foo')
73 self.assertEquals(ws.cells,cells)
74 self.assertEquals(ws.name,u'foo')
73 self.assertEqual(ws.cells,cells)
74 self.assertEqual(ws.name,u'foo')
75 75
76 76 class TestNotebook(TestCase):
77 77
78 78 def test_empty_notebook(self):
79 79 nb = new_notebook()
80 self.assertEquals(nb.worksheets, [])
81 self.assertEquals(nb.metadata, NotebookNode())
82 self.assertEquals(nb.nbformat,2)
80 self.assertEqual(nb.worksheets, [])
81 self.assertEqual(nb.metadata, NotebookNode())
82 self.assertEqual(nb.nbformat,2)
83 83
84 84 def test_notebook(self):
85 85 worksheets = [new_worksheet(),new_worksheet()]
86 86 metadata = new_metadata(name=u'foo')
87 87 nb = new_notebook(metadata=metadata,worksheets=worksheets)
88 self.assertEquals(nb.metadata.name,u'foo')
89 self.assertEquals(nb.worksheets,worksheets)
90 self.assertEquals(nb.nbformat,2)
88 self.assertEqual(nb.metadata.name,u'foo')
89 self.assertEqual(nb.worksheets,worksheets)
90 self.assertEqual(nb.nbformat,2)
91 91
92 92 class TestMetadata(TestCase):
93 93
94 94 def test_empty_metadata(self):
95 95 md = new_metadata()
96 self.assertEquals(u'name' not in md, True)
97 self.assertEquals(u'authors' not in md, True)
98 self.assertEquals(u'license' not in md, True)
99 self.assertEquals(u'saved' not in md, True)
100 self.assertEquals(u'modified' not in md, True)
101 self.assertEquals(u'gistid' not in md, True)
96 self.assertEqual(u'name' not in md, True)
97 self.assertEqual(u'authors' not in md, True)
98 self.assertEqual(u'license' not in md, True)
99 self.assertEqual(u'saved' not in md, True)
100 self.assertEqual(u'modified' not in md, True)
101 self.assertEqual(u'gistid' not in md, True)
102 102
103 103 def test_metadata(self):
104 104 authors = [new_author(name='Bart Simpson',email='bsimpson@fox.com')]
105 105 md = new_metadata(name=u'foo',license=u'BSD',created=u'today',
106 106 modified=u'now',gistid=u'21341231',authors=authors)
107 self.assertEquals(md.name, u'foo')
108 self.assertEquals(md.license, u'BSD')
109 self.assertEquals(md.created, u'today')
110 self.assertEquals(md.modified, u'now')
111 self.assertEquals(md.gistid, u'21341231')
112 self.assertEquals(md.authors, authors)
107 self.assertEqual(md.name, u'foo')
108 self.assertEqual(md.license, u'BSD')
109 self.assertEqual(md.created, u'today')
110 self.assertEqual(md.modified, u'now')
111 self.assertEqual(md.gistid, u'21341231')
112 self.assertEqual(md.authors, authors)
113 113
@@ -1,17 +1,17 b''
1 1 from unittest import TestCase
2 2
3 3 from ..nbbase import (
4 4 NotebookNode,
5 5 new_code_cell, new_text_cell, new_worksheet, new_notebook
6 6 )
7 7
8 8 from ..nbpy import reads, writes
9 9 from .nbexamples import nb0, nb0_py
10 10
11 11
12 12 class TestPy(TestCase):
13 13
14 14 def test_write(self):
15 15 s = writes(nb0)
16 self.assertEquals(s,nb0_py)
16 self.assertEqual(s,nb0_py)
17 17
@@ -1,63 +1,63 b''
1 1 # -*- coding: utf8 -*-
2 2 import io
3 3 import os
4 4 import shutil
5 5 import tempfile
6 6
7 7 pjoin = os.path.join
8 8
9 9 from ..nbbase import (
10 10 NotebookNode,
11 11 new_code_cell, new_text_cell, new_worksheet, new_notebook
12 12 )
13 13
14 14 from ..nbpy import reads, writes, read, write
15 15 from .nbexamples import nb0, nb0_py
16 16
17 17
18 18 def open_utf8(fname, mode):
19 19 return io.open(fname, mode=mode, encoding='utf-8')
20 20
21 21 class NBFormatTest:
22 22 """Mixin for writing notebook format tests"""
23 23
24 24 # override with appropriate values in subclasses
25 25 nb0_ref = None
26 26 ext = None
27 27 mod = None
28 28
29 29 def setUp(self):
30 30 self.wd = tempfile.mkdtemp()
31 31
32 32 def tearDown(self):
33 33 shutil.rmtree(self.wd)
34 34
35 35 def assertNBEquals(self, nba, nbb):
36 self.assertEquals(nba, nbb)
36 self.assertEqual(nba, nbb)
37 37
38 38 def test_writes(self):
39 39 s = self.mod.writes(nb0)
40 40 if self.nb0_ref:
41 self.assertEquals(s, self.nb0_ref)
41 self.assertEqual(s, self.nb0_ref)
42 42
43 43 def test_reads(self):
44 44 s = self.mod.writes(nb0)
45 45 nb = self.mod.reads(s)
46 46
47 47 def test_roundtrip(self):
48 48 s = self.mod.writes(nb0)
49 49 self.assertNBEquals(self.mod.reads(s),nb0)
50 50
51 51 def test_write_file(self):
52 52 with open_utf8(pjoin(self.wd, "nb0.%s" % self.ext), 'w') as f:
53 53 self.mod.write(nb0, f)
54 54
55 55 def test_read_file(self):
56 56 with open_utf8(pjoin(self.wd, "nb0.%s" % self.ext), 'w') as f:
57 57 self.mod.write(nb0, f)
58 58
59 59 with open_utf8(pjoin(self.wd, "nb0.%s" % self.ext), 'r') as f:
60 60 nb = self.mod.read(f)
61 61
62 62
63 63
@@ -1,33 +1,33 b''
1 1 import pprint
2 2 from unittest import TestCase
3 3
4 4 from ..nbjson import reads, writes
5 5 from .. import nbjson
6 6 from .nbexamples import nb0
7 7
8 8 from . import formattest
9 9
10 10 from .nbexamples import nb0
11 11
12 12
13 13 class TestJSON(formattest.NBFormatTest, TestCase):
14 14
15 15 nb0_ref = None
16 16 ext = 'ipynb'
17 17 mod = nbjson
18 18
19 19 def test_roundtrip_nosplit(self):
20 20 """Ensure that multiline blobs are still readable"""
21 21 # ensures that notebooks written prior to splitlines change
22 22 # are still readable.
23 23 s = writes(nb0, split_lines=False)
24 self.assertEquals(nbjson.reads(s),nb0)
24 self.assertEqual(nbjson.reads(s),nb0)
25 25
26 26 def test_roundtrip_split(self):
27 27 """Ensure that splitting multiline blocks is safe"""
28 28 # This won't differ from test_roundtrip unless the default changes
29 29 s = writes(nb0, split_lines=True)
30 self.assertEquals(nbjson.reads(s),nb0)
30 self.assertEqual(nbjson.reads(s),nb0)
31 31
32 32
33 33
@@ -1,143 +1,143 b''
1 1 from unittest import TestCase
2 2
3 3 from ..nbbase import (
4 4 NotebookNode,
5 5 new_code_cell, new_text_cell, new_worksheet, new_notebook, new_output,
6 6 new_author, new_metadata, new_heading_cell, nbformat
7 7 )
8 8
9 9 class TestCell(TestCase):
10 10
11 11 def test_empty_code_cell(self):
12 12 cc = new_code_cell()
13 self.assertEquals(cc.cell_type,u'code')
14 self.assertEquals(u'input' not in cc, True)
15 self.assertEquals(u'prompt_number' not in cc, True)
16 self.assertEquals(cc.outputs, [])
17 self.assertEquals(cc.collapsed, False)
13 self.assertEqual(cc.cell_type,u'code')
14 self.assertEqual(u'input' not in cc, True)
15 self.assertEqual(u'prompt_number' not in cc, True)
16 self.assertEqual(cc.outputs, [])
17 self.assertEqual(cc.collapsed, False)
18 18
19 19 def test_code_cell(self):
20 20 cc = new_code_cell(input='a=10', prompt_number=0, collapsed=True)
21 21 cc.outputs = [new_output(output_type=u'pyout',
22 22 output_svg=u'foo',output_text=u'10',prompt_number=0)]
23 self.assertEquals(cc.input, u'a=10')
24 self.assertEquals(cc.prompt_number, 0)
25 self.assertEquals(cc.language, u'python')
26 self.assertEquals(cc.outputs[0].svg, u'foo')
27 self.assertEquals(cc.outputs[0].text, u'10')
28 self.assertEquals(cc.outputs[0].prompt_number, 0)
29 self.assertEquals(cc.collapsed, True)
23 self.assertEqual(cc.input, u'a=10')
24 self.assertEqual(cc.prompt_number, 0)
25 self.assertEqual(cc.language, u'python')
26 self.assertEqual(cc.outputs[0].svg, u'foo')
27 self.assertEqual(cc.outputs[0].text, u'10')
28 self.assertEqual(cc.outputs[0].prompt_number, 0)
29 self.assertEqual(cc.collapsed, True)
30 30
31 31 def test_pyerr(self):
32 32 o = new_output(output_type=u'pyerr', etype=u'NameError',
33 33 evalue=u'Name not found', traceback=[u'frame 0', u'frame 1', u'frame 2']
34 34 )
35 self.assertEquals(o.output_type, u'pyerr')
36 self.assertEquals(o.etype, u'NameError')
37 self.assertEquals(o.evalue, u'Name not found')
38 self.assertEquals(o.traceback, [u'frame 0', u'frame 1', u'frame 2'])
35 self.assertEqual(o.output_type, u'pyerr')
36 self.assertEqual(o.etype, u'NameError')
37 self.assertEqual(o.evalue, u'Name not found')
38 self.assertEqual(o.traceback, [u'frame 0', u'frame 1', u'frame 2'])
39 39
40 40 def test_empty_html_cell(self):
41 41 tc = new_text_cell(u'html')
42 self.assertEquals(tc.cell_type, u'html')
43 self.assertEquals(u'source' not in tc, True)
44 self.assertEquals(u'rendered' not in tc, True)
42 self.assertEqual(tc.cell_type, u'html')
43 self.assertEqual(u'source' not in tc, True)
44 self.assertEqual(u'rendered' not in tc, True)
45 45
46 46 def test_html_cell(self):
47 47 tc = new_text_cell(u'html', 'hi', 'hi')
48 self.assertEquals(tc.source, u'hi')
49 self.assertEquals(tc.rendered, u'hi')
48 self.assertEqual(tc.source, u'hi')
49 self.assertEqual(tc.rendered, u'hi')
50 50
51 51 def test_empty_markdown_cell(self):
52 52 tc = new_text_cell(u'markdown')
53 self.assertEquals(tc.cell_type, u'markdown')
54 self.assertEquals(u'source' not in tc, True)
55 self.assertEquals(u'rendered' not in tc, True)
53 self.assertEqual(tc.cell_type, u'markdown')
54 self.assertEqual(u'source' not in tc, True)
55 self.assertEqual(u'rendered' not in tc, True)
56 56
57 57 def test_markdown_cell(self):
58 58 tc = new_text_cell(u'markdown', 'hi', 'hi')
59 self.assertEquals(tc.source, u'hi')
60 self.assertEquals(tc.rendered, u'hi')
59 self.assertEqual(tc.source, u'hi')
60 self.assertEqual(tc.rendered, u'hi')
61 61
62 62 def test_empty_raw_cell(self):
63 63 tc = new_text_cell(u'raw')
64 self.assertEquals(tc.cell_type, u'raw')
65 self.assertEquals(u'source' not in tc, True)
66 self.assertEquals(u'rendered' not in tc, True)
64 self.assertEqual(tc.cell_type, u'raw')
65 self.assertEqual(u'source' not in tc, True)
66 self.assertEqual(u'rendered' not in tc, True)
67 67
68 68 def test_raw_cell(self):
69 69 tc = new_text_cell(u'raw', 'hi', 'hi')
70 self.assertEquals(tc.source, u'hi')
71 self.assertEquals(tc.rendered, u'hi')
70 self.assertEqual(tc.source, u'hi')
71 self.assertEqual(tc.rendered, u'hi')
72 72
73 73 def test_empty_heading_cell(self):
74 74 tc = new_heading_cell()
75 self.assertEquals(tc.cell_type, u'heading')
76 self.assertEquals(u'source' not in tc, True)
77 self.assertEquals(u'rendered' not in tc, True)
75 self.assertEqual(tc.cell_type, u'heading')
76 self.assertEqual(u'source' not in tc, True)
77 self.assertEqual(u'rendered' not in tc, True)
78 78
79 79 def test_heading_cell(self):
80 80 tc = new_heading_cell(u'hi', u'hi', level=2)
81 self.assertEquals(tc.source, u'hi')
82 self.assertEquals(tc.rendered, u'hi')
83 self.assertEquals(tc.level, 2)
81 self.assertEqual(tc.source, u'hi')
82 self.assertEqual(tc.rendered, u'hi')
83 self.assertEqual(tc.level, 2)
84 84
85 85
86 86 class TestWorksheet(TestCase):
87 87
88 88 def test_empty_worksheet(self):
89 89 ws = new_worksheet()
90 self.assertEquals(ws.cells,[])
91 self.assertEquals(u'name' not in ws, True)
90 self.assertEqual(ws.cells,[])
91 self.assertEqual(u'name' not in ws, True)
92 92
93 93 def test_worksheet(self):
94 94 cells = [new_code_cell(), new_text_cell(u'html')]
95 95 ws = new_worksheet(cells=cells,name=u'foo')
96 self.assertEquals(ws.cells,cells)
97 self.assertEquals(ws.name,u'foo')
96 self.assertEqual(ws.cells,cells)
97 self.assertEqual(ws.name,u'foo')
98 98
99 99 class TestNotebook(TestCase):
100 100
101 101 def test_empty_notebook(self):
102 102 nb = new_notebook()
103 self.assertEquals(nb.worksheets, [])
104 self.assertEquals(nb.metadata, NotebookNode())
105 self.assertEquals(nb.nbformat,nbformat)
103 self.assertEqual(nb.worksheets, [])
104 self.assertEqual(nb.metadata, NotebookNode())
105 self.assertEqual(nb.nbformat,nbformat)
106 106
107 107 def test_notebook(self):
108 108 worksheets = [new_worksheet(),new_worksheet()]
109 109 metadata = new_metadata(name=u'foo')
110 110 nb = new_notebook(metadata=metadata,worksheets=worksheets)
111 self.assertEquals(nb.metadata.name,u'foo')
112 self.assertEquals(nb.worksheets,worksheets)
113 self.assertEquals(nb.nbformat,nbformat)
111 self.assertEqual(nb.metadata.name,u'foo')
112 self.assertEqual(nb.worksheets,worksheets)
113 self.assertEqual(nb.nbformat,nbformat)
114 114
115 115 def test_notebook_name(self):
116 116 worksheets = [new_worksheet(),new_worksheet()]
117 117 nb = new_notebook(name='foo',worksheets=worksheets)
118 self.assertEquals(nb.metadata.name,u'foo')
119 self.assertEquals(nb.worksheets,worksheets)
120 self.assertEquals(nb.nbformat,nbformat)
118 self.assertEqual(nb.metadata.name,u'foo')
119 self.assertEqual(nb.worksheets,worksheets)
120 self.assertEqual(nb.nbformat,nbformat)
121 121
122 122 class TestMetadata(TestCase):
123 123
124 124 def test_empty_metadata(self):
125 125 md = new_metadata()
126 self.assertEquals(u'name' not in md, True)
127 self.assertEquals(u'authors' not in md, True)
128 self.assertEquals(u'license' not in md, True)
129 self.assertEquals(u'saved' not in md, True)
130 self.assertEquals(u'modified' not in md, True)
131 self.assertEquals(u'gistid' not in md, True)
126 self.assertEqual(u'name' not in md, True)
127 self.assertEqual(u'authors' not in md, True)
128 self.assertEqual(u'license' not in md, True)
129 self.assertEqual(u'saved' not in md, True)
130 self.assertEqual(u'modified' not in md, True)
131 self.assertEqual(u'gistid' not in md, True)
132 132
133 133 def test_metadata(self):
134 134 authors = [new_author(name='Bart Simpson',email='bsimpson@fox.com')]
135 135 md = new_metadata(name=u'foo',license=u'BSD',created=u'today',
136 136 modified=u'now',gistid=u'21341231',authors=authors)
137 self.assertEquals(md.name, u'foo')
138 self.assertEquals(md.license, u'BSD')
139 self.assertEquals(md.created, u'today')
140 self.assertEquals(md.modified, u'now')
141 self.assertEquals(md.gistid, u'21341231')
142 self.assertEquals(md.authors, authors)
137 self.assertEqual(md.name, u'foo')
138 self.assertEqual(md.license, u'BSD')
139 self.assertEqual(md.created, u'today')
140 self.assertEqual(md.modified, u'now')
141 self.assertEqual(md.gistid, u'21341231')
142 self.assertEqual(md.authors, authors)
143 143
@@ -1,46 +1,46 b''
1 1 # -*- coding: utf8 -*-
2 2
3 3 from unittest import TestCase
4 4
5 5 from . import formattest
6 6
7 7 from .. import nbpy
8 8 from .nbexamples import nb0, nb0_py
9 9
10 10
11 11 class TestPy(formattest.NBFormatTest, TestCase):
12 12
13 13 nb0_ref = nb0_py
14 14 ext = 'py'
15 15 mod = nbpy
16 16 ignored_keys = ['collapsed', 'outputs', 'prompt_number', 'metadata']
17 17
18 18 def assertSubset(self, da, db):
19 19 """assert that da is a subset of db, ignoring self.ignored_keys.
20 20
21 21 Called recursively on containers, ultimately comparing individual
22 22 elements.
23 23 """
24 24 if isinstance(da, dict):
25 25 for k,v in da.iteritems():
26 26 if k in self.ignored_keys:
27 27 continue
28 28 self.assertTrue(k in db)
29 29 self.assertSubset(v, db[k])
30 30 elif isinstance(da, list):
31 31 for a,b in zip(da, db):
32 32 self.assertSubset(a,b)
33 33 else:
34 34 if isinstance(da, basestring) and isinstance(db, basestring):
35 35 # pyfile is not sensitive to preserving leading/trailing
36 36 # newlines in blocks through roundtrip
37 37 da = da.strip('\n')
38 38 db = db.strip('\n')
39 self.assertEquals(da, db)
39 self.assertEqual(da, db)
40 40 return True
41 41
42 42 def assertNBEquals(self, nba, nbb):
43 43 # since roundtrip is lossy, only compare keys that are preserved
44 44 # assumes nba is read from my file format
45 45 return self.assertSubset(nba, nbb)
46 46
@@ -1,184 +1,184 b''
1 1 """base class for parallel client tests
2 2
3 3 Authors:
4 4
5 5 * Min RK
6 6 """
7 7
8 8 #-------------------------------------------------------------------------------
9 9 # Copyright (C) 2011 The IPython Development Team
10 10 #
11 11 # Distributed under the terms of the BSD License. The full license is in
12 12 # the file COPYING, distributed as part of this software.
13 13 #-------------------------------------------------------------------------------
14 14 from __future__ import print_function
15 15
16 16 import sys
17 17 import tempfile
18 18 import time
19 19 from StringIO import StringIO
20 20
21 21 from nose import SkipTest
22 22
23 23 import zmq
24 24 from zmq.tests import BaseZMQTestCase
25 25
26 26 from IPython.external.decorator import decorator
27 27
28 28 from IPython.parallel import error
29 29 from IPython.parallel import Client
30 30
31 31 from IPython.parallel.tests import launchers, add_engines
32 32
33 33 # simple tasks for use in apply tests
34 34
35 35 def segfault():
36 36 """this will segfault"""
37 37 import ctypes
38 38 ctypes.memset(-1,0,1)
39 39
40 40 def crash():
41 41 """from stdlib crashers in the test suite"""
42 42 import types
43 43 if sys.platform.startswith('win'):
44 44 import ctypes
45 45 ctypes.windll.kernel32.SetErrorMode(0x0002);
46 46 args = [ 0, 0, 0, 0, b'\x04\x71\x00\x00', (), (), (), '', '', 1, b'']
47 47 if sys.version_info[0] >= 3:
48 48 # Python3 adds 'kwonlyargcount' as the second argument to Code
49 49 args.insert(1, 0)
50 50
51 51 co = types.CodeType(*args)
52 52 exec(co)
53 53
54 54 def wait(n):
55 55 """sleep for a time"""
56 56 import time
57 57 time.sleep(n)
58 58 return n
59 59
60 60 def raiser(eclass):
61 61 """raise an exception"""
62 62 raise eclass()
63 63
64 64 def generate_output():
65 65 """function for testing output
66 66
67 67 publishes two outputs of each type, and returns
68 68 a rich displayable object.
69 69 """
70 70
71 71 import sys
72 72 from IPython.core.display import display, HTML, Math
73 73
74 74 print("stdout")
75 75 print("stderr", file=sys.stderr)
76 76
77 77 display(HTML("<b>HTML</b>"))
78 78
79 79 print("stdout2")
80 80 print("stderr2", file=sys.stderr)
81 81
82 82 display(Math(r"\alpha=\beta"))
83 83
84 84 return Math("42")
85 85
86 86 # test decorator for skipping tests when libraries are unavailable
87 87 def skip_without(*names):
88 88 """skip a test if some names are not importable"""
89 89 @decorator
90 90 def skip_without_names(f, *args, **kwargs):
91 91 """decorator to skip tests in the absence of numpy."""
92 92 for name in names:
93 93 try:
94 94 __import__(name)
95 95 except ImportError:
96 96 raise SkipTest
97 97 return f(*args, **kwargs)
98 98 return skip_without_names
99 99
100 100 #-------------------------------------------------------------------------------
101 101 # Classes
102 102 #-------------------------------------------------------------------------------
103 103
104 104
105 105 class ClusterTestCase(BaseZMQTestCase):
106 106
107 107 def add_engines(self, n=1, block=True):
108 108 """add multiple engines to our cluster"""
109 109 self.engines.extend(add_engines(n))
110 110 if block:
111 111 self.wait_on_engines()
112 112
113 113 def minimum_engines(self, n=1, block=True):
114 114 """add engines until there are at least n connected"""
115 115 self.engines.extend(add_engines(n, total=True))
116 116 if block:
117 117 self.wait_on_engines()
118 118
119 119
120 120 def wait_on_engines(self, timeout=5):
121 121 """wait for our engines to connect."""
122 122 n = len(self.engines)+self.base_engine_count
123 123 tic = time.time()
124 124 while time.time()-tic < timeout and len(self.client.ids) < n:
125 125 time.sleep(0.1)
126 126
127 127 assert not len(self.client.ids) < n, "waiting for engines timed out"
128 128
129 129 def connect_client(self):
130 130 """connect a client with my Context, and track its sockets for cleanup"""
131 131 c = Client(profile='iptest', context=self.context)
132 132 for name in filter(lambda n:n.endswith('socket'), dir(c)):
133 133 s = getattr(c, name)
134 134 s.setsockopt(zmq.LINGER, 0)
135 135 self.sockets.append(s)
136 136 return c
137 137
138 138 def assertRaisesRemote(self, etype, f, *args, **kwargs):
139 139 try:
140 140 try:
141 141 f(*args, **kwargs)
142 142 except error.CompositeError as e:
143 143 e.raise_exception()
144 144 except error.RemoteError as e:
145 self.assertEquals(etype.__name__, e.ename, "Should have raised %r, but raised %r"%(etype.__name__, e.ename))
145 self.assertEqual(etype.__name__, e.ename, "Should have raised %r, but raised %r"%(etype.__name__, e.ename))
146 146 else:
147 147 self.fail("should have raised a RemoteError")
148 148
149 149 def _wait_for(self, f, timeout=10):
150 150 """wait for a condition"""
151 151 tic = time.time()
152 152 while time.time() <= tic + timeout:
153 153 if f():
154 154 return
155 155 time.sleep(0.1)
156 156 self.client.spin()
157 157 if not f():
158 158 print("Warning: Awaited condition never arrived")
159 159
160 160 def setUp(self):
161 161 BaseZMQTestCase.setUp(self)
162 162 self.client = self.connect_client()
163 163 # start every test with clean engine namespaces:
164 164 self.client.clear(block=True)
165 165 self.base_engine_count=len(self.client.ids)
166 166 self.engines=[]
167 167
168 168 def tearDown(self):
169 169 # self.client.clear(block=True)
170 170 # close fds:
171 171 for e in filter(lambda e: e.poll() is not None, launchers):
172 172 launchers.remove(e)
173 173
174 174 # allow flushing of incoming messages to prevent crash on socket close
175 175 self.client.wait(timeout=2)
176 176 # time.sleep(2)
177 177 self.client.spin()
178 178 self.client.close()
179 179 BaseZMQTestCase.tearDown(self)
180 180 # this will be redundant when pyzmq merges PR #88
181 181 # self.context.term()
182 182 # print tempfile.TemporaryFile().fileno(),
183 183 # sys.stdout.flush()
184 184
@@ -1,267 +1,267 b''
1 1 """Tests for asyncresult.py
2 2
3 3 Authors:
4 4
5 5 * Min RK
6 6 """
7 7
8 8 #-------------------------------------------------------------------------------
9 9 # Copyright (C) 2011 The IPython Development Team
10 10 #
11 11 # Distributed under the terms of the BSD License. The full license is in
12 12 # the file COPYING, distributed as part of this software.
13 13 #-------------------------------------------------------------------------------
14 14
15 15 #-------------------------------------------------------------------------------
16 16 # Imports
17 17 #-------------------------------------------------------------------------------
18 18
19 19 import time
20 20
21 21 from IPython.utils.io import capture_output
22 22
23 23 from IPython.parallel.error import TimeoutError
24 24 from IPython.parallel import error, Client
25 25 from IPython.parallel.tests import add_engines
26 26 from .clienttest import ClusterTestCase
27 27
28 28 def setup():
29 29 add_engines(2, total=True)
30 30
31 31 def wait(n):
32 32 import time
33 33 time.sleep(n)
34 34 return n
35 35
36 36 class AsyncResultTest(ClusterTestCase):
37 37
38 38 def test_single_result_view(self):
39 39 """various one-target views get the right value for single_result"""
40 40 eid = self.client.ids[-1]
41 41 ar = self.client[eid].apply_async(lambda : 42)
42 self.assertEquals(ar.get(), 42)
42 self.assertEqual(ar.get(), 42)
43 43 ar = self.client[[eid]].apply_async(lambda : 42)
44 self.assertEquals(ar.get(), [42])
44 self.assertEqual(ar.get(), [42])
45 45 ar = self.client[-1:].apply_async(lambda : 42)
46 self.assertEquals(ar.get(), [42])
46 self.assertEqual(ar.get(), [42])
47 47
48 48 def test_get_after_done(self):
49 49 ar = self.client[-1].apply_async(lambda : 42)
50 50 ar.wait()
51 51 self.assertTrue(ar.ready())
52 self.assertEquals(ar.get(), 42)
53 self.assertEquals(ar.get(), 42)
52 self.assertEqual(ar.get(), 42)
53 self.assertEqual(ar.get(), 42)
54 54
55 55 def test_get_before_done(self):
56 56 ar = self.client[-1].apply_async(wait, 0.1)
57 57 self.assertRaises(TimeoutError, ar.get, 0)
58 58 ar.wait(0)
59 59 self.assertFalse(ar.ready())
60 self.assertEquals(ar.get(), 0.1)
60 self.assertEqual(ar.get(), 0.1)
61 61
62 62 def test_get_after_error(self):
63 63 ar = self.client[-1].apply_async(lambda : 1/0)
64 64 ar.wait(10)
65 65 self.assertRaisesRemote(ZeroDivisionError, ar.get)
66 66 self.assertRaisesRemote(ZeroDivisionError, ar.get)
67 67 self.assertRaisesRemote(ZeroDivisionError, ar.get_dict)
68 68
69 69 def test_get_dict(self):
70 70 n = len(self.client)
71 71 ar = self.client[:].apply_async(lambda : 5)
72 self.assertEquals(ar.get(), [5]*n)
72 self.assertEqual(ar.get(), [5]*n)
73 73 d = ar.get_dict()
74 self.assertEquals(sorted(d.keys()), sorted(self.client.ids))
74 self.assertEqual(sorted(d.keys()), sorted(self.client.ids))
75 75 for eid,r in d.iteritems():
76 self.assertEquals(r, 5)
76 self.assertEqual(r, 5)
77 77
78 78 def test_list_amr(self):
79 79 ar = self.client.load_balanced_view().map_async(wait, [0.1]*5)
80 80 rlist = list(ar)
81 81
82 82 def test_getattr(self):
83 83 ar = self.client[:].apply_async(wait, 0.5)
84 84 self.assertRaises(AttributeError, lambda : ar._foo)
85 85 self.assertRaises(AttributeError, lambda : ar.__length_hint__())
86 86 self.assertRaises(AttributeError, lambda : ar.foo)
87 87 self.assertRaises(AttributeError, lambda : ar.engine_id)
88 88 self.assertFalse(hasattr(ar, '__length_hint__'))
89 89 self.assertFalse(hasattr(ar, 'foo'))
90 90 self.assertFalse(hasattr(ar, 'engine_id'))
91 91 ar.get(5)
92 92 self.assertRaises(AttributeError, lambda : ar._foo)
93 93 self.assertRaises(AttributeError, lambda : ar.__length_hint__())
94 94 self.assertRaises(AttributeError, lambda : ar.foo)
95 95 self.assertTrue(isinstance(ar.engine_id, list))
96 self.assertEquals(ar.engine_id, ar['engine_id'])
96 self.assertEqual(ar.engine_id, ar['engine_id'])
97 97 self.assertFalse(hasattr(ar, '__length_hint__'))
98 98 self.assertFalse(hasattr(ar, 'foo'))
99 99 self.assertTrue(hasattr(ar, 'engine_id'))
100 100
101 101 def test_getitem(self):
102 102 ar = self.client[:].apply_async(wait, 0.5)
103 103 self.assertRaises(TimeoutError, lambda : ar['foo'])
104 104 self.assertRaises(TimeoutError, lambda : ar['engine_id'])
105 105 ar.get(5)
106 106 self.assertRaises(KeyError, lambda : ar['foo'])
107 107 self.assertTrue(isinstance(ar['engine_id'], list))
108 self.assertEquals(ar.engine_id, ar['engine_id'])
108 self.assertEqual(ar.engine_id, ar['engine_id'])
109 109
110 110 def test_single_result(self):
111 111 ar = self.client[-1].apply_async(wait, 0.5)
112 112 self.assertRaises(TimeoutError, lambda : ar['foo'])
113 113 self.assertRaises(TimeoutError, lambda : ar['engine_id'])
114 114 self.assertTrue(ar.get(5) == 0.5)
115 115 self.assertTrue(isinstance(ar['engine_id'], int))
116 116 self.assertTrue(isinstance(ar.engine_id, int))
117 self.assertEquals(ar.engine_id, ar['engine_id'])
117 self.assertEqual(ar.engine_id, ar['engine_id'])
118 118
119 119 def test_abort(self):
120 120 e = self.client[-1]
121 121 ar = e.execute('import time; time.sleep(1)', block=False)
122 122 ar2 = e.apply_async(lambda : 2)
123 123 ar2.abort()
124 124 self.assertRaises(error.TaskAborted, ar2.get)
125 125 ar.get()
126 126
127 127 def test_len(self):
128 128 v = self.client.load_balanced_view()
129 129 ar = v.map_async(lambda x: x, range(10))
130 self.assertEquals(len(ar), 10)
130 self.assertEqual(len(ar), 10)
131 131 ar = v.apply_async(lambda x: x, range(10))
132 self.assertEquals(len(ar), 1)
132 self.assertEqual(len(ar), 1)
133 133 ar = self.client[:].apply_async(lambda x: x, range(10))
134 self.assertEquals(len(ar), len(self.client.ids))
134 self.assertEqual(len(ar), len(self.client.ids))
135 135
136 136 def test_wall_time_single(self):
137 137 v = self.client.load_balanced_view()
138 138 ar = v.apply_async(time.sleep, 0.25)
139 139 self.assertRaises(TimeoutError, getattr, ar, 'wall_time')
140 140 ar.get(2)
141 141 self.assertTrue(ar.wall_time < 1.)
142 142 self.assertTrue(ar.wall_time > 0.2)
143 143
144 144 def test_wall_time_multi(self):
145 145 self.minimum_engines(4)
146 146 v = self.client[:]
147 147 ar = v.apply_async(time.sleep, 0.25)
148 148 self.assertRaises(TimeoutError, getattr, ar, 'wall_time')
149 149 ar.get(2)
150 150 self.assertTrue(ar.wall_time < 1.)
151 151 self.assertTrue(ar.wall_time > 0.2)
152 152
153 153 def test_serial_time_single(self):
154 154 v = self.client.load_balanced_view()
155 155 ar = v.apply_async(time.sleep, 0.25)
156 156 self.assertRaises(TimeoutError, getattr, ar, 'serial_time')
157 157 ar.get(2)
158 158 self.assertTrue(ar.serial_time < 1.)
159 159 self.assertTrue(ar.serial_time > 0.2)
160 160
161 161 def test_serial_time_multi(self):
162 162 self.minimum_engines(4)
163 163 v = self.client[:]
164 164 ar = v.apply_async(time.sleep, 0.25)
165 165 self.assertRaises(TimeoutError, getattr, ar, 'serial_time')
166 166 ar.get(2)
167 167 self.assertTrue(ar.serial_time < 2.)
168 168 self.assertTrue(ar.serial_time > 0.8)
169 169
170 170 def test_elapsed_single(self):
171 171 v = self.client.load_balanced_view()
172 172 ar = v.apply_async(time.sleep, 0.25)
173 173 while not ar.ready():
174 174 time.sleep(0.01)
175 175 self.assertTrue(ar.elapsed < 1)
176 176 self.assertTrue(ar.elapsed < 1)
177 177 ar.get(2)
178 178
179 179 def test_elapsed_multi(self):
180 180 v = self.client[:]
181 181 ar = v.apply_async(time.sleep, 0.25)
182 182 while not ar.ready():
183 183 time.sleep(0.01)
184 184 self.assertTrue(ar.elapsed < 1)
185 185 self.assertTrue(ar.elapsed < 1)
186 186 ar.get(2)
187 187
188 188 def test_hubresult_timestamps(self):
189 189 self.minimum_engines(4)
190 190 v = self.client[:]
191 191 ar = v.apply_async(time.sleep, 0.25)
192 192 ar.get(2)
193 193 rc2 = Client(profile='iptest')
194 194 # must have try/finally to close second Client, otherwise
195 195 # will have dangling sockets causing problems
196 196 try:
197 197 time.sleep(0.25)
198 198 hr = rc2.get_result(ar.msg_ids)
199 199 self.assertTrue(hr.elapsed > 0., "got bad elapsed: %s" % hr.elapsed)
200 200 hr.get(1)
201 201 self.assertTrue(hr.wall_time < ar.wall_time + 0.2, "got bad wall_time: %s > %s" % (hr.wall_time, ar.wall_time))
202 self.assertEquals(hr.serial_time, ar.serial_time)
202 self.assertEqual(hr.serial_time, ar.serial_time)
203 203 finally:
204 204 rc2.close()
205 205
206 206 def test_display_empty_streams_single(self):
207 207 """empty stdout/err are not displayed (single result)"""
208 208 self.minimum_engines(1)
209 209
210 210 v = self.client[-1]
211 211 ar = v.execute("print (5555)")
212 212 ar.get(5)
213 213 with capture_output() as io:
214 214 ar.display_outputs()
215 self.assertEquals(io.stderr, '')
216 self.assertEquals('5555\n', io.stdout)
215 self.assertEqual(io.stderr, '')
216 self.assertEqual('5555\n', io.stdout)
217 217
218 218 ar = v.execute("a=5")
219 219 ar.get(5)
220 220 with capture_output() as io:
221 221 ar.display_outputs()
222 self.assertEquals(io.stderr, '')
223 self.assertEquals(io.stdout, '')
222 self.assertEqual(io.stderr, '')
223 self.assertEqual(io.stdout, '')
224 224
225 225 def test_display_empty_streams_type(self):
226 226 """empty stdout/err are not displayed (groupby type)"""
227 227 self.minimum_engines(1)
228 228
229 229 v = self.client[:]
230 230 ar = v.execute("print (5555)")
231 231 ar.get(5)
232 232 with capture_output() as io:
233 233 ar.display_outputs()
234 self.assertEquals(io.stderr, '')
235 self.assertEquals(io.stdout.count('5555'), len(v), io.stdout)
234 self.assertEqual(io.stderr, '')
235 self.assertEqual(io.stdout.count('5555'), len(v), io.stdout)
236 236 self.assertFalse('\n\n' in io.stdout, io.stdout)
237 self.assertEquals(io.stdout.count('[stdout:'), len(v), io.stdout)
237 self.assertEqual(io.stdout.count('[stdout:'), len(v), io.stdout)
238 238
239 239 ar = v.execute("a=5")
240 240 ar.get(5)
241 241 with capture_output() as io:
242 242 ar.display_outputs()
243 self.assertEquals(io.stderr, '')
244 self.assertEquals(io.stdout, '')
243 self.assertEqual(io.stderr, '')
244 self.assertEqual(io.stdout, '')
245 245
246 246 def test_display_empty_streams_engine(self):
247 247 """empty stdout/err are not displayed (groupby engine)"""
248 248 self.minimum_engines(1)
249 249
250 250 v = self.client[:]
251 251 ar = v.execute("print (5555)")
252 252 ar.get(5)
253 253 with capture_output() as io:
254 254 ar.display_outputs('engine')
255 self.assertEquals(io.stderr, '')
256 self.assertEquals(io.stdout.count('5555'), len(v), io.stdout)
255 self.assertEqual(io.stderr, '')
256 self.assertEqual(io.stdout.count('5555'), len(v), io.stdout)
257 257 self.assertFalse('\n\n' in io.stdout, io.stdout)
258 self.assertEquals(io.stdout.count('[stdout:'), len(v), io.stdout)
258 self.assertEqual(io.stdout.count('[stdout:'), len(v), io.stdout)
259 259
260 260 ar = v.execute("a=5")
261 261 ar.get(5)
262 262 with capture_output() as io:
263 263 ar.display_outputs('engine')
264 self.assertEquals(io.stderr, '')
265 self.assertEquals(io.stdout, '')
264 self.assertEqual(io.stderr, '')
265 self.assertEqual(io.stdout, '')
266 266
267 267
@@ -1,455 +1,455 b''
1 1 """Tests for parallel client.py
2 2
3 3 Authors:
4 4
5 5 * Min RK
6 6 """
7 7
8 8 #-------------------------------------------------------------------------------
9 9 # Copyright (C) 2011 The IPython Development Team
10 10 #
11 11 # Distributed under the terms of the BSD License. The full license is in
12 12 # the file COPYING, distributed as part of this software.
13 13 #-------------------------------------------------------------------------------
14 14
15 15 #-------------------------------------------------------------------------------
16 16 # Imports
17 17 #-------------------------------------------------------------------------------
18 18
19 19 from __future__ import division
20 20
21 21 import time
22 22 from datetime import datetime
23 23 from tempfile import mktemp
24 24
25 25 import zmq
26 26
27 27 from IPython import parallel
28 28 from IPython.parallel.client import client as clientmod
29 29 from IPython.parallel import error
30 30 from IPython.parallel import AsyncResult, AsyncHubResult
31 31 from IPython.parallel import LoadBalancedView, DirectView
32 32
33 33 from clienttest import ClusterTestCase, segfault, wait, add_engines
34 34
35 35 def setup():
36 36 add_engines(4, total=True)
37 37
38 38 class TestClient(ClusterTestCase):
39 39
40 40 def test_ids(self):
41 41 n = len(self.client.ids)
42 42 self.add_engines(2)
43 self.assertEquals(len(self.client.ids), n+2)
43 self.assertEqual(len(self.client.ids), n+2)
44 44
45 45 def test_view_indexing(self):
46 46 """test index access for views"""
47 47 self.minimum_engines(4)
48 48 targets = self.client._build_targets('all')[-1]
49 49 v = self.client[:]
50 self.assertEquals(v.targets, targets)
50 self.assertEqual(v.targets, targets)
51 51 t = self.client.ids[2]
52 52 v = self.client[t]
53 53 self.assert_(isinstance(v, DirectView))
54 self.assertEquals(v.targets, t)
54 self.assertEqual(v.targets, t)
55 55 t = self.client.ids[2:4]
56 56 v = self.client[t]
57 57 self.assert_(isinstance(v, DirectView))
58 self.assertEquals(v.targets, t)
58 self.assertEqual(v.targets, t)
59 59 v = self.client[::2]
60 60 self.assert_(isinstance(v, DirectView))
61 self.assertEquals(v.targets, targets[::2])
61 self.assertEqual(v.targets, targets[::2])
62 62 v = self.client[1::3]
63 63 self.assert_(isinstance(v, DirectView))
64 self.assertEquals(v.targets, targets[1::3])
64 self.assertEqual(v.targets, targets[1::3])
65 65 v = self.client[:-3]
66 66 self.assert_(isinstance(v, DirectView))
67 self.assertEquals(v.targets, targets[:-3])
67 self.assertEqual(v.targets, targets[:-3])
68 68 v = self.client[-1]
69 69 self.assert_(isinstance(v, DirectView))
70 self.assertEquals(v.targets, targets[-1])
70 self.assertEqual(v.targets, targets[-1])
71 71 self.assertRaises(TypeError, lambda : self.client[None])
72 72
73 73 def test_lbview_targets(self):
74 74 """test load_balanced_view targets"""
75 75 v = self.client.load_balanced_view()
76 self.assertEquals(v.targets, None)
76 self.assertEqual(v.targets, None)
77 77 v = self.client.load_balanced_view(-1)
78 self.assertEquals(v.targets, [self.client.ids[-1]])
78 self.assertEqual(v.targets, [self.client.ids[-1]])
79 79 v = self.client.load_balanced_view('all')
80 self.assertEquals(v.targets, None)
80 self.assertEqual(v.targets, None)
81 81
82 82 def test_dview_targets(self):
83 83 """test direct_view targets"""
84 84 v = self.client.direct_view()
85 self.assertEquals(v.targets, 'all')
85 self.assertEqual(v.targets, 'all')
86 86 v = self.client.direct_view('all')
87 self.assertEquals(v.targets, 'all')
87 self.assertEqual(v.targets, 'all')
88 88 v = self.client.direct_view(-1)
89 self.assertEquals(v.targets, self.client.ids[-1])
89 self.assertEqual(v.targets, self.client.ids[-1])
90 90
91 91 def test_lazy_all_targets(self):
92 92 """test lazy evaluation of rc.direct_view('all')"""
93 93 v = self.client.direct_view()
94 self.assertEquals(v.targets, 'all')
94 self.assertEqual(v.targets, 'all')
95 95
96 96 def double(x):
97 97 return x*2
98 98 seq = range(100)
99 99 ref = [ double(x) for x in seq ]
100 100
101 101 # add some engines, which should be used
102 102 self.add_engines(1)
103 103 n1 = len(self.client.ids)
104 104
105 105 # simple apply
106 106 r = v.apply_sync(lambda : 1)
107 self.assertEquals(r, [1] * n1)
107 self.assertEqual(r, [1] * n1)
108 108
109 109 # map goes through remotefunction
110 110 r = v.map_sync(double, seq)
111 self.assertEquals(r, ref)
111 self.assertEqual(r, ref)
112 112
113 113 # add a couple more engines, and try again
114 114 self.add_engines(2)
115 115 n2 = len(self.client.ids)
116 116 self.assertNotEquals(n2, n1)
117 117
118 118 # apply
119 119 r = v.apply_sync(lambda : 1)
120 self.assertEquals(r, [1] * n2)
120 self.assertEqual(r, [1] * n2)
121 121
122 122 # map
123 123 r = v.map_sync(double, seq)
124 self.assertEquals(r, ref)
124 self.assertEqual(r, ref)
125 125
126 126 def test_targets(self):
127 127 """test various valid targets arguments"""
128 128 build = self.client._build_targets
129 129 ids = self.client.ids
130 130 idents,targets = build(None)
131 self.assertEquals(ids, targets)
131 self.assertEqual(ids, targets)
132 132
133 133 def test_clear(self):
134 134 """test clear behavior"""
135 135 self.minimum_engines(2)
136 136 v = self.client[:]
137 137 v.block=True
138 138 v.push(dict(a=5))
139 139 v.pull('a')
140 140 id0 = self.client.ids[-1]
141 141 self.client.clear(targets=id0, block=True)
142 142 a = self.client[:-1].get('a')
143 143 self.assertRaisesRemote(NameError, self.client[id0].get, 'a')
144 144 self.client.clear(block=True)
145 145 for i in self.client.ids:
146 146 self.assertRaisesRemote(NameError, self.client[i].get, 'a')
147 147
148 148 def test_get_result(self):
149 149 """test getting results from the Hub."""
150 150 c = clientmod.Client(profile='iptest')
151 151 t = c.ids[-1]
152 152 ar = c[t].apply_async(wait, 1)
153 153 # give the monitor time to notice the message
154 154 time.sleep(.25)
155 155 ahr = self.client.get_result(ar.msg_ids)
156 156 self.assertTrue(isinstance(ahr, AsyncHubResult))
157 self.assertEquals(ahr.get(), ar.get())
157 self.assertEqual(ahr.get(), ar.get())
158 158 ar2 = self.client.get_result(ar.msg_ids)
159 159 self.assertFalse(isinstance(ar2, AsyncHubResult))
160 160 c.close()
161 161
162 162 def test_get_execute_result(self):
163 163 """test getting execute results from the Hub."""
164 164 c = clientmod.Client(profile='iptest')
165 165 t = c.ids[-1]
166 166 cell = '\n'.join([
167 167 'import time',
168 168 'time.sleep(0.25)',
169 169 '5'
170 170 ])
171 171 ar = c[t].execute("import time; time.sleep(1)", silent=False)
172 172 # give the monitor time to notice the message
173 173 time.sleep(.25)
174 174 ahr = self.client.get_result(ar.msg_ids)
175 175 self.assertTrue(isinstance(ahr, AsyncHubResult))
176 self.assertEquals(ahr.get().pyout, ar.get().pyout)
176 self.assertEqual(ahr.get().pyout, ar.get().pyout)
177 177 ar2 = self.client.get_result(ar.msg_ids)
178 178 self.assertFalse(isinstance(ar2, AsyncHubResult))
179 179 c.close()
180 180
181 181 def test_ids_list(self):
182 182 """test client.ids"""
183 183 ids = self.client.ids
184 self.assertEquals(ids, self.client._ids)
184 self.assertEqual(ids, self.client._ids)
185 185 self.assertFalse(ids is self.client._ids)
186 186 ids.remove(ids[-1])
187 187 self.assertNotEquals(ids, self.client._ids)
188 188
189 189 def test_queue_status(self):
190 190 ids = self.client.ids
191 191 id0 = ids[0]
192 192 qs = self.client.queue_status(targets=id0)
193 193 self.assertTrue(isinstance(qs, dict))
194 self.assertEquals(sorted(qs.keys()), ['completed', 'queue', 'tasks'])
194 self.assertEqual(sorted(qs.keys()), ['completed', 'queue', 'tasks'])
195 195 allqs = self.client.queue_status()
196 196 self.assertTrue(isinstance(allqs, dict))
197 197 intkeys = list(allqs.keys())
198 198 intkeys.remove('unassigned')
199 self.assertEquals(sorted(intkeys), sorted(self.client.ids))
199 self.assertEqual(sorted(intkeys), sorted(self.client.ids))
200 200 unassigned = allqs.pop('unassigned')
201 201 for eid,qs in allqs.items():
202 202 self.assertTrue(isinstance(qs, dict))
203 self.assertEquals(sorted(qs.keys()), ['completed', 'queue', 'tasks'])
203 self.assertEqual(sorted(qs.keys()), ['completed', 'queue', 'tasks'])
204 204
205 205 def test_shutdown(self):
206 206 ids = self.client.ids
207 207 id0 = ids[0]
208 208 self.client.shutdown(id0, block=True)
209 209 while id0 in self.client.ids:
210 210 time.sleep(0.1)
211 211 self.client.spin()
212 212
213 213 self.assertRaises(IndexError, lambda : self.client[id0])
214 214
215 215 def test_result_status(self):
216 216 pass
217 217 # to be written
218 218
219 219 def test_db_query_dt(self):
220 220 """test db query by date"""
221 221 hist = self.client.hub_history()
222 222 middle = self.client.db_query({'msg_id' : hist[len(hist)//2]})[0]
223 223 tic = middle['submitted']
224 224 before = self.client.db_query({'submitted' : {'$lt' : tic}})
225 225 after = self.client.db_query({'submitted' : {'$gte' : tic}})
226 self.assertEquals(len(before)+len(after),len(hist))
226 self.assertEqual(len(before)+len(after),len(hist))
227 227 for b in before:
228 228 self.assertTrue(b['submitted'] < tic)
229 229 for a in after:
230 230 self.assertTrue(a['submitted'] >= tic)
231 231 same = self.client.db_query({'submitted' : tic})
232 232 for s in same:
233 233 self.assertTrue(s['submitted'] == tic)
234 234
235 235 def test_db_query_keys(self):
236 236 """test extracting subset of record keys"""
237 237 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
238 238 for rec in found:
239 self.assertEquals(set(rec.keys()), set(['msg_id', 'submitted', 'completed']))
239 self.assertEqual(set(rec.keys()), set(['msg_id', 'submitted', 'completed']))
240 240
241 241 def test_db_query_default_keys(self):
242 242 """default db_query excludes buffers"""
243 243 found = self.client.db_query({'msg_id': {'$ne' : ''}})
244 244 for rec in found:
245 245 keys = set(rec.keys())
246 246 self.assertFalse('buffers' in keys, "'buffers' should not be in: %s" % keys)
247 247 self.assertFalse('result_buffers' in keys, "'result_buffers' should not be in: %s" % keys)
248 248
249 249 def test_db_query_msg_id(self):
250 250 """ensure msg_id is always in db queries"""
251 251 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
252 252 for rec in found:
253 253 self.assertTrue('msg_id' in rec.keys())
254 254 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted'])
255 255 for rec in found:
256 256 self.assertTrue('msg_id' in rec.keys())
257 257 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['msg_id'])
258 258 for rec in found:
259 259 self.assertTrue('msg_id' in rec.keys())
260 260
261 261 def test_db_query_get_result(self):
262 262 """pop in db_query shouldn't pop from result itself"""
263 263 self.client[:].apply_sync(lambda : 1)
264 264 found = self.client.db_query({'msg_id': {'$ne' : ''}})
265 265 rc2 = clientmod.Client(profile='iptest')
266 266 # If this bug is not fixed, this call will hang:
267 267 ar = rc2.get_result(self.client.history[-1])
268 268 ar.wait(2)
269 269 self.assertTrue(ar.ready())
270 270 ar.get()
271 271 rc2.close()
272 272
273 273 def test_db_query_in(self):
274 274 """test db query with '$in','$nin' operators"""
275 275 hist = self.client.hub_history()
276 276 even = hist[::2]
277 277 odd = hist[1::2]
278 278 recs = self.client.db_query({ 'msg_id' : {'$in' : even}})
279 279 found = [ r['msg_id'] for r in recs ]
280 self.assertEquals(set(even), set(found))
280 self.assertEqual(set(even), set(found))
281 281 recs = self.client.db_query({ 'msg_id' : {'$nin' : even}})
282 282 found = [ r['msg_id'] for r in recs ]
283 self.assertEquals(set(odd), set(found))
283 self.assertEqual(set(odd), set(found))
284 284
285 285 def test_hub_history(self):
286 286 hist = self.client.hub_history()
287 287 recs = self.client.db_query({ 'msg_id' : {"$ne":''}})
288 288 recdict = {}
289 289 for rec in recs:
290 290 recdict[rec['msg_id']] = rec
291 291
292 292 latest = datetime(1984,1,1)
293 293 for msg_id in hist:
294 294 rec = recdict[msg_id]
295 295 newt = rec['submitted']
296 296 self.assertTrue(newt >= latest)
297 297 latest = newt
298 298 ar = self.client[-1].apply_async(lambda : 1)
299 299 ar.get()
300 300 time.sleep(0.25)
301 self.assertEquals(self.client.hub_history()[-1:],ar.msg_ids)
301 self.assertEqual(self.client.hub_history()[-1:],ar.msg_ids)
302 302
303 303 def _wait_for_idle(self):
304 304 """wait for an engine to become idle, according to the Hub"""
305 305 rc = self.client
306 306
307 307 # timeout 5s, polling every 100ms
308 308 qs = rc.queue_status()
309 309 for i in range(50):
310 310 if qs['unassigned'] or any(qs[eid]['tasks'] for eid in rc.ids):
311 311 time.sleep(0.1)
312 312 qs = rc.queue_status()
313 313 else:
314 314 break
315 315
316 316 # ensure Hub up to date:
317 self.assertEquals(qs['unassigned'], 0)
317 self.assertEqual(qs['unassigned'], 0)
318 318 for eid in rc.ids:
319 self.assertEquals(qs[eid]['tasks'], 0)
319 self.assertEqual(qs[eid]['tasks'], 0)
320 320
321 321
322 322 def test_resubmit(self):
323 323 def f():
324 324 import random
325 325 return random.random()
326 326 v = self.client.load_balanced_view()
327 327 ar = v.apply_async(f)
328 328 r1 = ar.get(1)
329 329 # give the Hub a chance to notice:
330 330 self._wait_for_idle()
331 331 ahr = self.client.resubmit(ar.msg_ids)
332 332 r2 = ahr.get(1)
333 333 self.assertFalse(r1 == r2)
334 334
335 335 def test_resubmit_chain(self):
336 336 """resubmit resubmitted tasks"""
337 337 v = self.client.load_balanced_view()
338 338 ar = v.apply_async(lambda x: x, 'x'*1024)
339 339 ar.get()
340 340 self._wait_for_idle()
341 341 ars = [ar]
342 342
343 343 for i in range(10):
344 344 ar = ars[-1]
345 345 ar2 = self.client.resubmit(ar.msg_ids)
346 346
347 347 [ ar.get() for ar in ars ]
348 348
349 349 def test_resubmit_header(self):
350 350 """resubmit shouldn't clobber the whole header"""
351 351 def f():
352 352 import random
353 353 return random.random()
354 354 v = self.client.load_balanced_view()
355 355 v.retries = 1
356 356 ar = v.apply_async(f)
357 357 r1 = ar.get(1)
358 358 # give the Hub a chance to notice:
359 359 self._wait_for_idle()
360 360 ahr = self.client.resubmit(ar.msg_ids)
361 361 ahr.get(1)
362 362 time.sleep(0.5)
363 363 records = self.client.db_query({'msg_id': {'$in': ar.msg_ids + ahr.msg_ids}}, keys='header')
364 364 h1,h2 = [ r['header'] for r in records ]
365 365 for key in set(h1.keys()).union(set(h2.keys())):
366 366 if key in ('msg_id', 'date'):
367 367 self.assertNotEquals(h1[key], h2[key])
368 368 else:
369 self.assertEquals(h1[key], h2[key])
369 self.assertEqual(h1[key], h2[key])
370 370
371 371 def test_resubmit_aborted(self):
372 372 def f():
373 373 import random
374 374 return random.random()
375 375 v = self.client.load_balanced_view()
376 376 # restrict to one engine, so we can put a sleep
377 377 # ahead of the task, so it will get aborted
378 378 eid = self.client.ids[-1]
379 379 v.targets = [eid]
380 380 sleep = v.apply_async(time.sleep, 0.5)
381 381 ar = v.apply_async(f)
382 382 ar.abort()
383 383 self.assertRaises(error.TaskAborted, ar.get)
384 384 # Give the Hub a chance to get up to date:
385 385 self._wait_for_idle()
386 386 ahr = self.client.resubmit(ar.msg_ids)
387 387 r2 = ahr.get(1)
388 388
389 389 def test_resubmit_inflight(self):
390 390 """resubmit of inflight task"""
391 391 v = self.client.load_balanced_view()
392 392 ar = v.apply_async(time.sleep,1)
393 393 # give the message a chance to arrive
394 394 time.sleep(0.2)
395 395 ahr = self.client.resubmit(ar.msg_ids)
396 396 ar.get(2)
397 397 ahr.get(2)
398 398
399 399 def test_resubmit_badkey(self):
400 400 """ensure KeyError on resubmit of nonexistant task"""
401 401 self.assertRaisesRemote(KeyError, self.client.resubmit, ['invalid'])
402 402
403 403 def test_purge_results(self):
404 404 # ensure there are some tasks
405 405 for i in range(5):
406 406 self.client[:].apply_sync(lambda : 1)
407 407 # Wait for the Hub to realise the result is done:
408 408 # This prevents a race condition, where we
409 409 # might purge a result the Hub still thinks is pending.
410 410 time.sleep(0.1)
411 411 rc2 = clientmod.Client(profile='iptest')
412 412 hist = self.client.hub_history()
413 413 ahr = rc2.get_result([hist[-1]])
414 414 ahr.wait(10)
415 415 self.client.purge_results(hist[-1])
416 416 newhist = self.client.hub_history()
417 self.assertEquals(len(newhist)+1,len(hist))
417 self.assertEqual(len(newhist)+1,len(hist))
418 418 rc2.spin()
419 419 rc2.close()
420 420
421 421 def test_purge_all_results(self):
422 422 self.client.purge_results('all')
423 423 hist = self.client.hub_history()
424 self.assertEquals(len(hist), 0)
424 self.assertEqual(len(hist), 0)
425 425
426 426 def test_spin_thread(self):
427 427 self.client.spin_thread(0.01)
428 428 ar = self.client[-1].apply_async(lambda : 1)
429 429 time.sleep(0.1)
430 430 self.assertTrue(ar.wall_time < 0.1,
431 431 "spin should have kept wall_time < 0.1, but got %f" % ar.wall_time
432 432 )
433 433
434 434 def test_stop_spin_thread(self):
435 435 self.client.spin_thread(0.01)
436 436 self.client.stop_spin_thread()
437 437 ar = self.client[-1].apply_async(lambda : 1)
438 438 time.sleep(0.15)
439 439 self.assertTrue(ar.wall_time > 0.1,
440 440 "Shouldn't be spinning, but got wall_time=%f" % ar.wall_time
441 441 )
442 442
443 443 def test_activate(self):
444 444 ip = get_ipython()
445 445 magics = ip.magics_manager.magics
446 446 self.assertTrue('px' in magics['line'])
447 447 self.assertTrue('px' in magics['cell'])
448 448 v0 = self.client.activate(-1, '0')
449 449 self.assertTrue('px0' in magics['line'])
450 450 self.assertTrue('px0' in magics['cell'])
451 self.assertEquals(v0.targets, self.client.ids[-1])
451 self.assertEqual(v0.targets, self.client.ids[-1])
452 452 v0 = self.client.activate('all', 'all')
453 453 self.assertTrue('pxall' in magics['line'])
454 454 self.assertTrue('pxall' in magics['cell'])
455 self.assertEquals(v0.targets, 'all')
455 self.assertEqual(v0.targets, 'all')
@@ -1,249 +1,249 b''
1 1 """Tests for db backends
2 2
3 3 Authors:
4 4
5 5 * Min RK
6 6 """
7 7
8 8 #-------------------------------------------------------------------------------
9 9 # Copyright (C) 2011 The IPython Development Team
10 10 #
11 11 # Distributed under the terms of the BSD License. The full license is in
12 12 # the file COPYING, distributed as part of this software.
13 13 #-------------------------------------------------------------------------------
14 14
15 15 #-------------------------------------------------------------------------------
16 16 # Imports
17 17 #-------------------------------------------------------------------------------
18 18
19 19 from __future__ import division
20 20
21 21 import logging
22 22 import os
23 23 import tempfile
24 24 import time
25 25
26 26 from datetime import datetime, timedelta
27 27 from unittest import TestCase
28 28
29 29 from IPython.parallel import error
30 30 from IPython.parallel.controller.dictdb import DictDB
31 31 from IPython.parallel.controller.sqlitedb import SQLiteDB
32 32 from IPython.parallel.controller.hub import init_record, empty_record
33 33
34 34 from IPython.testing import decorators as dec
35 35 from IPython.zmq.session import Session
36 36
37 37
38 38 #-------------------------------------------------------------------------------
39 39 # TestCases
40 40 #-------------------------------------------------------------------------------
41 41
42 42
43 43 def setup():
44 44 global temp_db
45 45 temp_db = tempfile.NamedTemporaryFile(suffix='.db').name
46 46
47 47
48 48 class TestDictBackend(TestCase):
49 49 def setUp(self):
50 50 self.session = Session()
51 51 self.db = self.create_db()
52 52 self.load_records(16)
53 53
54 54 def create_db(self):
55 55 return DictDB()
56 56
57 57 def load_records(self, n=1):
58 58 """load n records for testing"""
59 59 #sleep 1/10 s, to ensure timestamp is different to previous calls
60 60 time.sleep(0.1)
61 61 msg_ids = []
62 62 for i in range(n):
63 63 msg = self.session.msg('apply_request', content=dict(a=5))
64 64 msg['buffers'] = []
65 65 rec = init_record(msg)
66 66 msg_id = msg['header']['msg_id']
67 67 msg_ids.append(msg_id)
68 68 self.db.add_record(msg_id, rec)
69 69 return msg_ids
70 70
71 71 def test_add_record(self):
72 72 before = self.db.get_history()
73 73 self.load_records(5)
74 74 after = self.db.get_history()
75 self.assertEquals(len(after), len(before)+5)
76 self.assertEquals(after[:-5],before)
75 self.assertEqual(len(after), len(before)+5)
76 self.assertEqual(after[:-5],before)
77 77
78 78 def test_drop_record(self):
79 79 msg_id = self.load_records()[-1]
80 80 rec = self.db.get_record(msg_id)
81 81 self.db.drop_record(msg_id)
82 82 self.assertRaises(KeyError,self.db.get_record, msg_id)
83 83
84 84 def _round_to_millisecond(self, dt):
85 85 """necessary because mongodb rounds microseconds"""
86 86 micro = dt.microsecond
87 87 extra = int(str(micro)[-3:])
88 88 return dt - timedelta(microseconds=extra)
89 89
90 90 def test_update_record(self):
91 91 now = self._round_to_millisecond(datetime.now())
92 92 #
93 93 msg_id = self.db.get_history()[-1]
94 94 rec1 = self.db.get_record(msg_id)
95 95 data = {'stdout': 'hello there', 'completed' : now}
96 96 self.db.update_record(msg_id, data)
97 97 rec2 = self.db.get_record(msg_id)
98 self.assertEquals(rec2['stdout'], 'hello there')
99 self.assertEquals(rec2['completed'], now)
98 self.assertEqual(rec2['stdout'], 'hello there')
99 self.assertEqual(rec2['completed'], now)
100 100 rec1.update(data)
101 self.assertEquals(rec1, rec2)
101 self.assertEqual(rec1, rec2)
102 102
103 103 # def test_update_record_bad(self):
104 104 # """test updating nonexistant records"""
105 105 # msg_id = str(uuid.uuid4())
106 106 # data = {'stdout': 'hello there'}
107 107 # self.assertRaises(KeyError, self.db.update_record, msg_id, data)
108 108
109 109 def test_find_records_dt(self):
110 110 """test finding records by date"""
111 111 hist = self.db.get_history()
112 112 middle = self.db.get_record(hist[len(hist)//2])
113 113 tic = middle['submitted']
114 114 before = self.db.find_records({'submitted' : {'$lt' : tic}})
115 115 after = self.db.find_records({'submitted' : {'$gte' : tic}})
116 self.assertEquals(len(before)+len(after),len(hist))
116 self.assertEqual(len(before)+len(after),len(hist))
117 117 for b in before:
118 118 self.assertTrue(b['submitted'] < tic)
119 119 for a in after:
120 120 self.assertTrue(a['submitted'] >= tic)
121 121 same = self.db.find_records({'submitted' : tic})
122 122 for s in same:
123 123 self.assertTrue(s['submitted'] == tic)
124 124
125 125 def test_find_records_keys(self):
126 126 """test extracting subset of record keys"""
127 127 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
128 128 for rec in found:
129 self.assertEquals(set(rec.keys()), set(['msg_id', 'submitted', 'completed']))
129 self.assertEqual(set(rec.keys()), set(['msg_id', 'submitted', 'completed']))
130 130
131 131 def test_find_records_msg_id(self):
132 132 """ensure msg_id is always in found records"""
133 133 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
134 134 for rec in found:
135 135 self.assertTrue('msg_id' in rec.keys())
136 136 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['submitted'])
137 137 for rec in found:
138 138 self.assertTrue('msg_id' in rec.keys())
139 139 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['msg_id'])
140 140 for rec in found:
141 141 self.assertTrue('msg_id' in rec.keys())
142 142
143 143 def test_find_records_in(self):
144 144 """test finding records with '$in','$nin' operators"""
145 145 hist = self.db.get_history()
146 146 even = hist[::2]
147 147 odd = hist[1::2]
148 148 recs = self.db.find_records({ 'msg_id' : {'$in' : even}})
149 149 found = [ r['msg_id'] for r in recs ]
150 self.assertEquals(set(even), set(found))
150 self.assertEqual(set(even), set(found))
151 151 recs = self.db.find_records({ 'msg_id' : {'$nin' : even}})
152 152 found = [ r['msg_id'] for r in recs ]
153 self.assertEquals(set(odd), set(found))
153 self.assertEqual(set(odd), set(found))
154 154
155 155 def test_get_history(self):
156 156 msg_ids = self.db.get_history()
157 157 latest = datetime(1984,1,1)
158 158 for msg_id in msg_ids:
159 159 rec = self.db.get_record(msg_id)
160 160 newt = rec['submitted']
161 161 self.assertTrue(newt >= latest)
162 162 latest = newt
163 163 msg_id = self.load_records(1)[-1]
164 self.assertEquals(self.db.get_history()[-1],msg_id)
164 self.assertEqual(self.db.get_history()[-1],msg_id)
165 165
166 166 def test_datetime(self):
167 167 """get/set timestamps with datetime objects"""
168 168 msg_id = self.db.get_history()[-1]
169 169 rec = self.db.get_record(msg_id)
170 170 self.assertTrue(isinstance(rec['submitted'], datetime))
171 171 self.db.update_record(msg_id, dict(completed=datetime.now()))
172 172 rec = self.db.get_record(msg_id)
173 173 self.assertTrue(isinstance(rec['completed'], datetime))
174 174
175 175 def test_drop_matching(self):
176 176 msg_ids = self.load_records(10)
177 177 query = {'msg_id' : {'$in':msg_ids}}
178 178 self.db.drop_matching_records(query)
179 179 recs = self.db.find_records(query)
180 self.assertEquals(len(recs), 0)
180 self.assertEqual(len(recs), 0)
181 181
182 182 def test_null(self):
183 183 """test None comparison queries"""
184 184 msg_ids = self.load_records(10)
185 185
186 186 query = {'msg_id' : None}
187 187 recs = self.db.find_records(query)
188 self.assertEquals(len(recs), 0)
188 self.assertEqual(len(recs), 0)
189 189
190 190 query = {'msg_id' : {'$ne' : None}}
191 191 recs = self.db.find_records(query)
192 192 self.assertTrue(len(recs) >= 10)
193 193
194 194 def test_pop_safe_get(self):
195 195 """editing query results shouldn't affect record [get]"""
196 196 msg_id = self.db.get_history()[-1]
197 197 rec = self.db.get_record(msg_id)
198 198 rec.pop('buffers')
199 199 rec['garbage'] = 'hello'
200 200 rec['header']['msg_id'] = 'fubar'
201 201 rec2 = self.db.get_record(msg_id)
202 202 self.assertTrue('buffers' in rec2)
203 203 self.assertFalse('garbage' in rec2)
204 self.assertEquals(rec2['header']['msg_id'], msg_id)
204 self.assertEqual(rec2['header']['msg_id'], msg_id)
205 205
206 206 def test_pop_safe_find(self):
207 207 """editing query results shouldn't affect record [find]"""
208 208 msg_id = self.db.get_history()[-1]
209 209 rec = self.db.find_records({'msg_id' : msg_id})[0]
210 210 rec.pop('buffers')
211 211 rec['garbage'] = 'hello'
212 212 rec['header']['msg_id'] = 'fubar'
213 213 rec2 = self.db.find_records({'msg_id' : msg_id})[0]
214 214 self.assertTrue('buffers' in rec2)
215 215 self.assertFalse('garbage' in rec2)
216 self.assertEquals(rec2['header']['msg_id'], msg_id)
216 self.assertEqual(rec2['header']['msg_id'], msg_id)
217 217
218 218 def test_pop_safe_find_keys(self):
219 219 """editing query results shouldn't affect record [find+keys]"""
220 220 msg_id = self.db.get_history()[-1]
221 221 rec = self.db.find_records({'msg_id' : msg_id}, keys=['buffers', 'header'])[0]
222 222 rec.pop('buffers')
223 223 rec['garbage'] = 'hello'
224 224 rec['header']['msg_id'] = 'fubar'
225 225 rec2 = self.db.find_records({'msg_id' : msg_id})[0]
226 226 self.assertTrue('buffers' in rec2)
227 227 self.assertFalse('garbage' in rec2)
228 self.assertEquals(rec2['header']['msg_id'], msg_id)
228 self.assertEqual(rec2['header']['msg_id'], msg_id)
229 229
230 230
231 231 class TestSQLiteBackend(TestDictBackend):
232 232
233 233 @dec.skip_without('sqlite3')
234 234 def create_db(self):
235 235 location, fname = os.path.split(temp_db)
236 236 log = logging.getLogger('test')
237 237 log.setLevel(logging.CRITICAL)
238 238 return SQLiteDB(location=location, fname=fname, log=log)
239 239
240 240 def tearDown(self):
241 241 self.db._db.close()
242 242
243 243
244 244 def teardown():
245 245 """cleanup task db file after all tests have run"""
246 246 try:
247 247 os.remove(temp_db)
248 248 except:
249 249 pass
@@ -1,106 +1,106 b''
1 1 """Tests for dependency.py
2 2
3 3 Authors:
4 4
5 5 * Min RK
6 6 """
7 7
8 8 __docformat__ = "restructuredtext en"
9 9
10 10 #-------------------------------------------------------------------------------
11 11 # Copyright (C) 2011 The IPython Development Team
12 12 #
13 13 # Distributed under the terms of the BSD License. The full license is in
14 14 # the file COPYING, distributed as part of this software.
15 15 #-------------------------------------------------------------------------------
16 16
17 17 #-------------------------------------------------------------------------------
18 18 # Imports
19 19 #-------------------------------------------------------------------------------
20 20
21 21 # import
22 22 import os
23 23
24 24 from IPython.utils.pickleutil import can, uncan
25 25
26 26 import IPython.parallel as pmod
27 27 from IPython.parallel.util import interactive
28 28
29 29 from IPython.parallel.tests import add_engines
30 30 from .clienttest import ClusterTestCase
31 31
32 32 def setup():
33 33 add_engines(1, total=True)
34 34
35 35 @pmod.require('time')
36 36 def wait(n):
37 37 time.sleep(n)
38 38 return n
39 39
40 40 mixed = map(str, range(10))
41 41 completed = map(str, range(0,10,2))
42 42 failed = map(str, range(1,10,2))
43 43
44 44 class DependencyTest(ClusterTestCase):
45 45
46 46 def setUp(self):
47 47 ClusterTestCase.setUp(self)
48 48 self.user_ns = {'__builtins__' : __builtins__}
49 49 self.view = self.client.load_balanced_view()
50 50 self.dview = self.client[-1]
51 51 self.succeeded = set(map(str, range(0,25,2)))
52 52 self.failed = set(map(str, range(1,25,2)))
53 53
54 54 def assertMet(self, dep):
55 55 self.assertTrue(dep.check(self.succeeded, self.failed), "Dependency should be met")
56 56
57 57 def assertUnmet(self, dep):
58 58 self.assertFalse(dep.check(self.succeeded, self.failed), "Dependency should not be met")
59 59
60 60 def assertUnreachable(self, dep):
61 61 self.assertTrue(dep.unreachable(self.succeeded, self.failed), "Dependency should be unreachable")
62 62
63 63 def assertReachable(self, dep):
64 64 self.assertFalse(dep.unreachable(self.succeeded, self.failed), "Dependency should be reachable")
65 65
66 66 def cancan(self, f):
67 67 """decorator to pass through canning into self.user_ns"""
68 68 return uncan(can(f), self.user_ns)
69 69
70 70 def test_require_imports(self):
71 71 """test that @require imports names"""
72 72 @self.cancan
73 73 @pmod.require('urllib')
74 74 @interactive
75 75 def encode(dikt):
76 76 return urllib.urlencode(dikt)
77 77 # must pass through canning to properly connect namespaces
78 self.assertEquals(encode(dict(a=5)), 'a=5')
78 self.assertEqual(encode(dict(a=5)), 'a=5')
79 79
80 80 def test_success_only(self):
81 81 dep = pmod.Dependency(mixed, success=True, failure=False)
82 82 self.assertUnmet(dep)
83 83 self.assertUnreachable(dep)
84 84 dep.all=False
85 85 self.assertMet(dep)
86 86 self.assertReachable(dep)
87 87 dep = pmod.Dependency(completed, success=True, failure=False)
88 88 self.assertMet(dep)
89 89 self.assertReachable(dep)
90 90 dep.all=False
91 91 self.assertMet(dep)
92 92 self.assertReachable(dep)
93 93
94 94 def test_failure_only(self):
95 95 dep = pmod.Dependency(mixed, success=False, failure=True)
96 96 self.assertUnmet(dep)
97 97 self.assertUnreachable(dep)
98 98 dep.all=False
99 99 self.assertMet(dep)
100 100 self.assertReachable(dep)
101 101 dep = pmod.Dependency(completed, success=False, failure=True)
102 102 self.assertUnmet(dep)
103 103 self.assertUnreachable(dep)
104 104 dep.all=False
105 105 self.assertUnmet(dep)
106 106 self.assertUnreachable(dep)
@@ -1,176 +1,176 b''
1 1 # -*- coding: utf-8 -*-
2 2 """test LoadBalancedView objects
3 3
4 4 Authors:
5 5
6 6 * Min RK
7 7 """
8 8 #-------------------------------------------------------------------------------
9 9 # Copyright (C) 2011 The IPython Development Team
10 10 #
11 11 # Distributed under the terms of the BSD License. The full license is in
12 12 # the file COPYING, distributed as part of this software.
13 13 #-------------------------------------------------------------------------------
14 14
15 15 #-------------------------------------------------------------------------------
16 16 # Imports
17 17 #-------------------------------------------------------------------------------
18 18
19 19 import sys
20 20 import time
21 21
22 22 import zmq
23 23 from nose import SkipTest
24 24
25 25 from IPython import parallel as pmod
26 26 from IPython.parallel import error
27 27
28 28 from IPython.parallel.tests import add_engines
29 29
30 30 from .clienttest import ClusterTestCase, crash, wait, skip_without
31 31
32 32 def setup():
33 33 add_engines(3, total=True)
34 34
35 35 class TestLoadBalancedView(ClusterTestCase):
36 36
37 37 def setUp(self):
38 38 ClusterTestCase.setUp(self)
39 39 self.view = self.client.load_balanced_view()
40 40
41 41 def test_z_crash_task(self):
42 42 """test graceful handling of engine death (balanced)"""
43 43 raise SkipTest("crash tests disabled, due to undesirable crash reports")
44 44 # self.add_engines(1)
45 45 ar = self.view.apply_async(crash)
46 46 self.assertRaisesRemote(error.EngineError, ar.get, 10)
47 47 eid = ar.engine_id
48 48 tic = time.time()
49 49 while eid in self.client.ids and time.time()-tic < 5:
50 50 time.sleep(.01)
51 51 self.client.spin()
52 52 self.assertFalse(eid in self.client.ids, "Engine should have died")
53 53
54 54 def test_map(self):
55 55 def f(x):
56 56 return x**2
57 57 data = range(16)
58 58 r = self.view.map_sync(f, data)
59 self.assertEquals(r, map(f, data))
59 self.assertEqual(r, map(f, data))
60 60
61 61 def test_map_unordered(self):
62 62 def f(x):
63 63 return x**2
64 64 def slow_f(x):
65 65 import time
66 66 time.sleep(0.05*x)
67 67 return x**2
68 68 data = range(16,0,-1)
69 69 reference = map(f, data)
70 70
71 71 amr = self.view.map_async(slow_f, data, ordered=False)
72 72 self.assertTrue(isinstance(amr, pmod.AsyncMapResult))
73 73 # check individual elements, retrieved as they come
74 74 # list comprehension uses __iter__
75 75 astheycame = [ r for r in amr ]
76 76 # Ensure that at least one result came out of order:
77 77 self.assertNotEquals(astheycame, reference, "should not have preserved order")
78 self.assertEquals(sorted(astheycame, reverse=True), reference, "result corrupted")
78 self.assertEqual(sorted(astheycame, reverse=True), reference, "result corrupted")
79 79
80 80 def test_map_ordered(self):
81 81 def f(x):
82 82 return x**2
83 83 def slow_f(x):
84 84 import time
85 85 time.sleep(0.05*x)
86 86 return x**2
87 87 data = range(16,0,-1)
88 88 reference = map(f, data)
89 89
90 90 amr = self.view.map_async(slow_f, data)
91 91 self.assertTrue(isinstance(amr, pmod.AsyncMapResult))
92 92 # check individual elements, retrieved as they come
93 93 # list(amr) uses __iter__
94 94 astheycame = list(amr)
95 95 # Ensure that results came in order
96 self.assertEquals(astheycame, reference)
97 self.assertEquals(amr.result, reference)
96 self.assertEqual(astheycame, reference)
97 self.assertEqual(amr.result, reference)
98 98
99 99 def test_map_iterable(self):
100 100 """test map on iterables (balanced)"""
101 101 view = self.view
102 102 # 101 is prime, so it won't be evenly distributed
103 103 arr = range(101)
104 104 # so that it will be an iterator, even in Python 3
105 105 it = iter(arr)
106 106 r = view.map_sync(lambda x:x, arr)
107 self.assertEquals(r, list(arr))
107 self.assertEqual(r, list(arr))
108 108
109 109
110 110 def test_abort(self):
111 111 view = self.view
112 112 ar = self.client[:].apply_async(time.sleep, .5)
113 113 ar = self.client[:].apply_async(time.sleep, .5)
114 114 time.sleep(0.2)
115 115 ar2 = view.apply_async(lambda : 2)
116 116 ar3 = view.apply_async(lambda : 3)
117 117 view.abort(ar2)
118 118 view.abort(ar3.msg_ids)
119 119 self.assertRaises(error.TaskAborted, ar2.get)
120 120 self.assertRaises(error.TaskAborted, ar3.get)
121 121
122 122 def test_retries(self):
123 123 view = self.view
124 124 view.timeout = 1 # prevent hang if this doesn't behave
125 125 def fail():
126 126 assert False
127 127 for r in range(len(self.client)-1):
128 128 with view.temp_flags(retries=r):
129 129 self.assertRaisesRemote(AssertionError, view.apply_sync, fail)
130 130
131 131 with view.temp_flags(retries=len(self.client), timeout=0.25):
132 132 self.assertRaisesRemote(error.TaskTimeout, view.apply_sync, fail)
133 133
134 134 def test_invalid_dependency(self):
135 135 view = self.view
136 136 with view.temp_flags(after='12345'):
137 137 self.assertRaisesRemote(error.InvalidDependency, view.apply_sync, lambda : 1)
138 138
139 139 def test_impossible_dependency(self):
140 140 self.minimum_engines(2)
141 141 view = self.client.load_balanced_view()
142 142 ar1 = view.apply_async(lambda : 1)
143 143 ar1.get()
144 144 e1 = ar1.engine_id
145 145 e2 = e1
146 146 while e2 == e1:
147 147 ar2 = view.apply_async(lambda : 1)
148 148 ar2.get()
149 149 e2 = ar2.engine_id
150 150
151 151 with view.temp_flags(follow=[ar1, ar2]):
152 152 self.assertRaisesRemote(error.ImpossibleDependency, view.apply_sync, lambda : 1)
153 153
154 154
155 155 def test_follow(self):
156 156 ar = self.view.apply_async(lambda : 1)
157 157 ar.get()
158 158 ars = []
159 159 first_id = ar.engine_id
160 160
161 161 self.view.follow = ar
162 162 for i in range(5):
163 163 ars.append(self.view.apply_async(lambda : 1))
164 164 self.view.wait(ars)
165 165 for ar in ars:
166 self.assertEquals(ar.engine_id, first_id)
166 self.assertEqual(ar.engine_id, first_id)
167 167
168 168 def test_after(self):
169 169 view = self.view
170 170 ar = view.apply_async(time.sleep, 0.5)
171 171 with view.temp_flags(after=ar):
172 172 ar2 = view.apply_async(lambda : 1)
173 173
174 174 ar.wait()
175 175 ar2.wait()
176 176 self.assertTrue(ar2.started >= ar.completed, "%s not >= %s"%(ar.started, ar.completed))
@@ -1,386 +1,386 b''
1 1 # -*- coding: utf-8 -*-
2 2 """Test Parallel magics
3 3
4 4 Authors:
5 5
6 6 * Min RK
7 7 """
8 8 #-------------------------------------------------------------------------------
9 9 # Copyright (C) 2011 The IPython Development Team
10 10 #
11 11 # Distributed under the terms of the BSD License. The full license is in
12 12 # the file COPYING, distributed as part of this software.
13 13 #-------------------------------------------------------------------------------
14 14
15 15 #-------------------------------------------------------------------------------
16 16 # Imports
17 17 #-------------------------------------------------------------------------------
18 18
19 19 import re
20 20 import sys
21 21 import time
22 22
23 23 import zmq
24 24 from nose import SkipTest
25 25
26 26 from IPython.testing import decorators as dec
27 27 from IPython.testing.ipunittest import ParametricTestCase
28 28 from IPython.utils.io import capture_output
29 29
30 30 from IPython import parallel as pmod
31 31 from IPython.parallel import error
32 32 from IPython.parallel import AsyncResult
33 33 from IPython.parallel.util import interactive
34 34
35 35 from IPython.parallel.tests import add_engines
36 36
37 37 from .clienttest import ClusterTestCase, generate_output
38 38
39 39 def setup():
40 40 add_engines(3, total=True)
41 41
42 42 class TestParallelMagics(ClusterTestCase, ParametricTestCase):
43 43
44 44 def test_px_blocking(self):
45 45 ip = get_ipython()
46 46 v = self.client[-1:]
47 47 v.activate()
48 48 v.block=True
49 49
50 50 ip.magic('px a=5')
51 self.assertEquals(v['a'], [5])
51 self.assertEqual(v['a'], [5])
52 52 ip.magic('px a=10')
53 self.assertEquals(v['a'], [10])
53 self.assertEqual(v['a'], [10])
54 54 # just 'print a' works ~99% of the time, but this ensures that
55 55 # the stdout message has arrived when the result is finished:
56 56 with capture_output() as io:
57 57 ip.magic(
58 58 'px import sys,time;print(a);sys.stdout.flush();time.sleep(0.2)'
59 59 )
60 60 out = io.stdout
61 61 self.assertTrue('[stdout:' in out, out)
62 62 self.assertFalse('\n\n' in out)
63 63 self.assertTrue(out.rstrip().endswith('10'))
64 64 self.assertRaisesRemote(ZeroDivisionError, ip.magic, 'px 1/0')
65 65
66 66 def _check_generated_stderr(self, stderr, n):
67 67 expected = [
68 68 r'\[stderr:\d+\]',
69 69 '^stderr$',
70 70 '^stderr2$',
71 71 ] * n
72 72
73 73 self.assertFalse('\n\n' in stderr, stderr)
74 74 lines = stderr.splitlines()
75 self.assertEquals(len(lines), len(expected), stderr)
75 self.assertEqual(len(lines), len(expected), stderr)
76 76 for line,expect in zip(lines, expected):
77 77 if isinstance(expect, str):
78 78 expect = [expect]
79 79 for ex in expect:
80 80 self.assertTrue(re.search(ex, line) is not None, "Expected %r in %r" % (ex, line))
81 81
82 82 def test_cellpx_block_args(self):
83 83 """%%px --[no]block flags work"""
84 84 ip = get_ipython()
85 85 v = self.client[-1:]
86 86 v.activate()
87 87 v.block=False
88 88
89 89 for block in (True, False):
90 90 v.block = block
91 91 ip.magic("pxconfig --verbose")
92 92 with capture_output() as io:
93 93 ip.run_cell_magic("px", "", "1")
94 94 if block:
95 95 self.assertTrue(io.stdout.startswith("Parallel"), io.stdout)
96 96 else:
97 97 self.assertTrue(io.stdout.startswith("Async"), io.stdout)
98 98
99 99 with capture_output() as io:
100 100 ip.run_cell_magic("px", "--block", "1")
101 101 self.assertTrue(io.stdout.startswith("Parallel"), io.stdout)
102 102
103 103 with capture_output() as io:
104 104 ip.run_cell_magic("px", "--noblock", "1")
105 105 self.assertTrue(io.stdout.startswith("Async"), io.stdout)
106 106
107 107 def test_cellpx_groupby_engine(self):
108 108 """%%px --group-outputs=engine"""
109 109 ip = get_ipython()
110 110 v = self.client[:]
111 111 v.block = True
112 112 v.activate()
113 113
114 114 v['generate_output'] = generate_output
115 115
116 116 with capture_output() as io:
117 117 ip.run_cell_magic('px', '--group-outputs=engine', 'generate_output()')
118 118
119 119 self.assertFalse('\n\n' in io.stdout)
120 120 lines = io.stdout.splitlines()
121 121 expected = [
122 122 r'\[stdout:\d+\]',
123 123 'stdout',
124 124 'stdout2',
125 125 r'\[output:\d+\]',
126 126 r'IPython\.core\.display\.HTML',
127 127 r'IPython\.core\.display\.Math',
128 128 r'Out\[\d+:\d+\]:.*IPython\.core\.display\.Math',
129 129 ] * len(v)
130 130
131 self.assertEquals(len(lines), len(expected), io.stdout)
131 self.assertEqual(len(lines), len(expected), io.stdout)
132 132 for line,expect in zip(lines, expected):
133 133 if isinstance(expect, str):
134 134 expect = [expect]
135 135 for ex in expect:
136 136 self.assertTrue(re.search(ex, line) is not None, "Expected %r in %r" % (ex, line))
137 137
138 138 self._check_generated_stderr(io.stderr, len(v))
139 139
140 140
141 141 def test_cellpx_groupby_order(self):
142 142 """%%px --group-outputs=order"""
143 143 ip = get_ipython()
144 144 v = self.client[:]
145 145 v.block = True
146 146 v.activate()
147 147
148 148 v['generate_output'] = generate_output
149 149
150 150 with capture_output() as io:
151 151 ip.run_cell_magic('px', '--group-outputs=order', 'generate_output()')
152 152
153 153 self.assertFalse('\n\n' in io.stdout)
154 154 lines = io.stdout.splitlines()
155 155 expected = []
156 156 expected.extend([
157 157 r'\[stdout:\d+\]',
158 158 'stdout',
159 159 'stdout2',
160 160 ] * len(v))
161 161 expected.extend([
162 162 r'\[output:\d+\]',
163 163 'IPython.core.display.HTML',
164 164 ] * len(v))
165 165 expected.extend([
166 166 r'\[output:\d+\]',
167 167 'IPython.core.display.Math',
168 168 ] * len(v))
169 169 expected.extend([
170 170 r'Out\[\d+:\d+\]:.*IPython\.core\.display\.Math'
171 171 ] * len(v))
172 172
173 self.assertEquals(len(lines), len(expected), io.stdout)
173 self.assertEqual(len(lines), len(expected), io.stdout)
174 174 for line,expect in zip(lines, expected):
175 175 if isinstance(expect, str):
176 176 expect = [expect]
177 177 for ex in expect:
178 178 self.assertTrue(re.search(ex, line) is not None, "Expected %r in %r" % (ex, line))
179 179
180 180 self._check_generated_stderr(io.stderr, len(v))
181 181
182 182 def test_cellpx_groupby_type(self):
183 183 """%%px --group-outputs=type"""
184 184 ip = get_ipython()
185 185 v = self.client[:]
186 186 v.block = True
187 187 v.activate()
188 188
189 189 v['generate_output'] = generate_output
190 190
191 191 with capture_output() as io:
192 192 ip.run_cell_magic('px', '--group-outputs=type', 'generate_output()')
193 193
194 194 self.assertFalse('\n\n' in io.stdout)
195 195 lines = io.stdout.splitlines()
196 196
197 197 expected = []
198 198 expected.extend([
199 199 r'\[stdout:\d+\]',
200 200 'stdout',
201 201 'stdout2',
202 202 ] * len(v))
203 203 expected.extend([
204 204 r'\[output:\d+\]',
205 205 r'IPython\.core\.display\.HTML',
206 206 r'IPython\.core\.display\.Math',
207 207 ] * len(v))
208 208 expected.extend([
209 209 (r'Out\[\d+:\d+\]', r'IPython\.core\.display\.Math')
210 210 ] * len(v))
211 211
212 self.assertEquals(len(lines), len(expected), io.stdout)
212 self.assertEqual(len(lines), len(expected), io.stdout)
213 213 for line,expect in zip(lines, expected):
214 214 if isinstance(expect, str):
215 215 expect = [expect]
216 216 for ex in expect:
217 217 self.assertTrue(re.search(ex, line) is not None, "Expected %r in %r" % (ex, line))
218 218
219 219 self._check_generated_stderr(io.stderr, len(v))
220 220
221 221
222 222 def test_px_nonblocking(self):
223 223 ip = get_ipython()
224 224 v = self.client[-1:]
225 225 v.activate()
226 226 v.block=False
227 227
228 228 ip.magic('px a=5')
229 self.assertEquals(v['a'], [5])
229 self.assertEqual(v['a'], [5])
230 230 ip.magic('px a=10')
231 self.assertEquals(v['a'], [10])
231 self.assertEqual(v['a'], [10])
232 232 ip.magic('pxconfig --verbose')
233 233 with capture_output() as io:
234 234 ar = ip.magic('px print (a)')
235 235 self.assertTrue(isinstance(ar, AsyncResult))
236 236 self.assertTrue('Async' in io.stdout)
237 237 self.assertFalse('[stdout:' in io.stdout)
238 238 self.assertFalse('\n\n' in io.stdout)
239 239
240 240 ar = ip.magic('px 1/0')
241 241 self.assertRaisesRemote(ZeroDivisionError, ar.get)
242 242
243 243 def test_autopx_blocking(self):
244 244 ip = get_ipython()
245 245 v = self.client[-1]
246 246 v.activate()
247 247 v.block=True
248 248
249 249 with capture_output() as io:
250 250 ip.magic('autopx')
251 251 ip.run_cell('\n'.join(('a=5','b=12345','c=0')))
252 252 ip.run_cell('b*=2')
253 253 ip.run_cell('print (b)')
254 254 ip.run_cell('b')
255 255 ip.run_cell("b/c")
256 256 ip.magic('autopx')
257 257
258 258 output = io.stdout
259 259
260 260 self.assertTrue(output.startswith('%autopx enabled'), output)
261 261 self.assertTrue(output.rstrip().endswith('%autopx disabled'), output)
262 262 self.assertTrue('ZeroDivisionError' in output, output)
263 263 self.assertTrue('\nOut[' in output, output)
264 264 self.assertTrue(': 24690' in output, output)
265 265 ar = v.get_result(-1)
266 self.assertEquals(v['a'], 5)
267 self.assertEquals(v['b'], 24690)
266 self.assertEqual(v['a'], 5)
267 self.assertEqual(v['b'], 24690)
268 268 self.assertRaisesRemote(ZeroDivisionError, ar.get)
269 269
270 270 def test_autopx_nonblocking(self):
271 271 ip = get_ipython()
272 272 v = self.client[-1]
273 273 v.activate()
274 274 v.block=False
275 275
276 276 with capture_output() as io:
277 277 ip.magic('autopx')
278 278 ip.run_cell('\n'.join(('a=5','b=10','c=0')))
279 279 ip.run_cell('print (b)')
280 280 ip.run_cell('import time; time.sleep(0.1)')
281 281 ip.run_cell("b/c")
282 282 ip.run_cell('b*=2')
283 283 ip.magic('autopx')
284 284
285 285 output = io.stdout.rstrip()
286 286
287 287 self.assertTrue(output.startswith('%autopx enabled'))
288 288 self.assertTrue(output.endswith('%autopx disabled'))
289 289 self.assertFalse('ZeroDivisionError' in output)
290 290 ar = v.get_result(-2)
291 291 self.assertRaisesRemote(ZeroDivisionError, ar.get)
292 292 # prevent TaskAborted on pulls, due to ZeroDivisionError
293 293 time.sleep(0.5)
294 self.assertEquals(v['a'], 5)
294 self.assertEqual(v['a'], 5)
295 295 # b*=2 will not fire, due to abort
296 self.assertEquals(v['b'], 10)
296 self.assertEqual(v['b'], 10)
297 297
298 298 def test_result(self):
299 299 ip = get_ipython()
300 300 v = self.client[-1]
301 301 v.activate()
302 302 data = dict(a=111,b=222)
303 303 v.push(data, block=True)
304 304
305 305 for name in ('a', 'b'):
306 306 ip.magic('px ' + name)
307 307 with capture_output() as io:
308 308 ip.magic('pxresult')
309 309 output = io.stdout
310 310 msg = "expected %s output to include %s, but got: %s" % \
311 311 ('%pxresult', str(data[name]), output)
312 312 self.assertTrue(str(data[name]) in output, msg)
313 313
314 314 @dec.skipif_not_matplotlib
315 315 def test_px_pylab(self):
316 316 """%pylab works on engines"""
317 317 ip = get_ipython()
318 318 v = self.client[-1]
319 319 v.block = True
320 320 v.activate()
321 321
322 322 with capture_output() as io:
323 323 ip.magic("px %pylab inline")
324 324
325 325 self.assertTrue("Welcome to pylab" in io.stdout, io.stdout)
326 326 self.assertTrue("backend_inline" in io.stdout, io.stdout)
327 327
328 328 with capture_output() as io:
329 329 ip.magic("px plot(rand(100))")
330 330
331 331 self.assertTrue('Out[' in io.stdout, io.stdout)
332 332 self.assertTrue('matplotlib.lines' in io.stdout, io.stdout)
333 333
334 334 def test_pxconfig(self):
335 335 ip = get_ipython()
336 336 rc = self.client
337 337 v = rc.activate(-1, '_tst')
338 self.assertEquals(v.targets, rc.ids[-1])
338 self.assertEqual(v.targets, rc.ids[-1])
339 339 ip.magic("%pxconfig_tst -t :")
340 self.assertEquals(v.targets, rc.ids)
340 self.assertEqual(v.targets, rc.ids)
341 341 ip.magic("%pxconfig_tst -t ::2")
342 self.assertEquals(v.targets, rc.ids[::2])
342 self.assertEqual(v.targets, rc.ids[::2])
343 343 ip.magic("%pxconfig_tst -t 1::2")
344 self.assertEquals(v.targets, rc.ids[1::2])
344 self.assertEqual(v.targets, rc.ids[1::2])
345 345 ip.magic("%pxconfig_tst -t 1")
346 self.assertEquals(v.targets, 1)
346 self.assertEqual(v.targets, 1)
347 347 ip.magic("%pxconfig_tst --block")
348 self.assertEquals(v.block, True)
348 self.assertEqual(v.block, True)
349 349 ip.magic("%pxconfig_tst --noblock")
350 self.assertEquals(v.block, False)
350 self.assertEqual(v.block, False)
351 351
352 352 def test_cellpx_targets(self):
353 353 """%%px --targets doesn't change defaults"""
354 354 ip = get_ipython()
355 355 rc = self.client
356 356 view = rc.activate(rc.ids)
357 self.assertEquals(view.targets, rc.ids)
357 self.assertEqual(view.targets, rc.ids)
358 358 ip.magic('pxconfig --verbose')
359 359 for cell in ("pass", "1/0"):
360 360 with capture_output() as io:
361 361 try:
362 362 ip.run_cell_magic("px", "--targets all", cell)
363 363 except pmod.RemoteError:
364 364 pass
365 365 self.assertTrue('engine(s): all' in io.stdout)
366 self.assertEquals(view.targets, rc.ids)
366 self.assertEqual(view.targets, rc.ids)
367 367
368 368
369 369 def test_cellpx_block(self):
370 370 """%%px --block doesn't change default"""
371 371 ip = get_ipython()
372 372 rc = self.client
373 373 view = rc.activate(rc.ids)
374 374 view.block = False
375 self.assertEquals(view.targets, rc.ids)
375 self.assertEqual(view.targets, rc.ids)
376 376 ip.magic('pxconfig --verbose')
377 377 for cell in ("pass", "1/0"):
378 378 with capture_output() as io:
379 379 try:
380 380 ip.run_cell_magic("px", "--block", cell)
381 381 except pmod.RemoteError:
382 382 pass
383 383 self.assertFalse('Async' in io.stdout)
384 384 self.assertFalse(view.block)
385 385
386 386
@@ -1,117 +1,117 b''
1 1 """test serialization with newserialized
2 2
3 3 Authors:
4 4
5 5 * Min RK
6 6 """
7 7
8 8 #-------------------------------------------------------------------------------
9 9 # Copyright (C) 2011 The IPython Development Team
10 10 #
11 11 # Distributed under the terms of the BSD License. The full license is in
12 12 # the file COPYING, distributed as part of this software.
13 13 #-------------------------------------------------------------------------------
14 14
15 15 #-------------------------------------------------------------------------------
16 16 # Imports
17 17 #-------------------------------------------------------------------------------
18 18
19 19 import sys
20 20
21 21 from unittest import TestCase
22 22
23 23 from IPython.testing.decorators import parametric
24 24 from IPython.utils import newserialized as ns
25 25 from IPython.utils.pickleutil import can, uncan, CannedObject, CannedFunction
26 26 from IPython.parallel.tests.clienttest import skip_without
27 27
28 28 if sys.version_info[0] >= 3:
29 29 buffer = memoryview
30 30
31 31 class CanningTestCase(TestCase):
32 32 def test_canning(self):
33 33 d = dict(a=5,b=6)
34 34 cd = can(d)
35 35 self.assertTrue(isinstance(cd, dict))
36 36
37 37 def test_canned_function(self):
38 38 f = lambda : 7
39 39 cf = can(f)
40 40 self.assertTrue(isinstance(cf, CannedFunction))
41 41
42 42 @parametric
43 43 def test_can_roundtrip(cls):
44 44 objs = [
45 45 dict(),
46 46 set(),
47 47 list(),
48 48 ['a',1,['a',1],u'e'],
49 49 ]
50 50 return map(cls.run_roundtrip, objs)
51 51
52 52 @classmethod
53 53 def run_roundtrip(self, obj):
54 54 o = uncan(can(obj))
55 55 assert o == obj, "failed assertion: %r == %r"%(o,obj)
56 56
57 57 def test_serialized_interfaces(self):
58 58
59 59 us = {'a':10, 'b':range(10)}
60 60 s = ns.serialize(us)
61 61 uus = ns.unserialize(s)
62 62 self.assertTrue(isinstance(s, ns.SerializeIt))
63 self.assertEquals(uus, us)
63 self.assertEqual(uus, us)
64 64
65 65 def test_pickle_serialized(self):
66 66 obj = {'a':1.45345, 'b':'asdfsdf', 'c':10000L}
67 67 original = ns.UnSerialized(obj)
68 68 originalSer = ns.SerializeIt(original)
69 69 firstData = originalSer.getData()
70 70 firstTD = originalSer.getTypeDescriptor()
71 71 firstMD = originalSer.getMetadata()
72 self.assertEquals(firstTD, 'pickle')
73 self.assertEquals(firstMD, {})
72 self.assertEqual(firstTD, 'pickle')
73 self.assertEqual(firstMD, {})
74 74 unSerialized = ns.UnSerializeIt(originalSer)
75 75 secondObj = unSerialized.getObject()
76 76 for k, v in secondObj.iteritems():
77 self.assertEquals(obj[k], v)
77 self.assertEqual(obj[k], v)
78 78 secondSer = ns.SerializeIt(ns.UnSerialized(secondObj))
79 self.assertEquals(firstData, secondSer.getData())
80 self.assertEquals(firstTD, secondSer.getTypeDescriptor() )
81 self.assertEquals(firstMD, secondSer.getMetadata())
79 self.assertEqual(firstData, secondSer.getData())
80 self.assertEqual(firstTD, secondSer.getTypeDescriptor() )
81 self.assertEqual(firstMD, secondSer.getMetadata())
82 82
83 83 @skip_without('numpy')
84 84 def test_ndarray_serialized(self):
85 85 import numpy
86 86 a = numpy.linspace(0.0, 1.0, 1000)
87 87 unSer1 = ns.UnSerialized(a)
88 88 ser1 = ns.SerializeIt(unSer1)
89 89 td = ser1.getTypeDescriptor()
90 self.assertEquals(td, 'ndarray')
90 self.assertEqual(td, 'ndarray')
91 91 md = ser1.getMetadata()
92 self.assertEquals(md['shape'], a.shape)
93 self.assertEquals(md['dtype'], a.dtype)
92 self.assertEqual(md['shape'], a.shape)
93 self.assertEqual(md['dtype'], a.dtype)
94 94 buff = ser1.getData()
95 self.assertEquals(buff, buffer(a))
95 self.assertEqual(buff, buffer(a))
96 96 s = ns.Serialized(buff, td, md)
97 97 final = ns.unserialize(s)
98 self.assertEquals(buffer(a), buffer(final))
98 self.assertEqual(buffer(a), buffer(final))
99 99 self.assertTrue((a==final).all())
100 self.assertEquals(a.dtype, final.dtype)
101 self.assertEquals(a.shape, final.shape)
100 self.assertEqual(a.dtype, final.dtype)
101 self.assertEqual(a.shape, final.shape)
102 102 # test non-copying:
103 103 a[2] = 1e9
104 104 self.assertTrue((a==final).all())
105 105
106 106 def test_uncan_function_globals(self):
107 107 """test that uncanning a module function restores it into its module"""
108 108 from re import search
109 109 cf = can(search)
110 110 csearch = uncan(cf)
111 111 self.assertEqual(csearch.__module__, search.__module__)
112 112 self.assertNotEqual(csearch('asd', 'asdf'), None)
113 113 csearch = uncan(cf, dict(a=5))
114 114 self.assertEqual(csearch.__module__, search.__module__)
115 115 self.assertNotEqual(csearch('asd', 'asdf'), None)
116 116
117 117 No newline at end of file
@@ -1,597 +1,597 b''
1 1 # -*- coding: utf-8 -*-
2 2 """test View objects
3 3
4 4 Authors:
5 5
6 6 * Min RK
7 7 """
8 8 #-------------------------------------------------------------------------------
9 9 # Copyright (C) 2011 The IPython Development Team
10 10 #
11 11 # Distributed under the terms of the BSD License. The full license is in
12 12 # the file COPYING, distributed as part of this software.
13 13 #-------------------------------------------------------------------------------
14 14
15 15 #-------------------------------------------------------------------------------
16 16 # Imports
17 17 #-------------------------------------------------------------------------------
18 18
19 19 import sys
20 20 import platform
21 21 import time
22 22 from tempfile import mktemp
23 23 from StringIO import StringIO
24 24
25 25 import zmq
26 26 from nose import SkipTest
27 27
28 28 from IPython.testing import decorators as dec
29 29 from IPython.testing.ipunittest import ParametricTestCase
30 30
31 31 from IPython import parallel as pmod
32 32 from IPython.parallel import error
33 33 from IPython.parallel import AsyncResult, AsyncHubResult, AsyncMapResult
34 34 from IPython.parallel import DirectView
35 35 from IPython.parallel.util import interactive
36 36
37 37 from IPython.parallel.tests import add_engines
38 38
39 39 from .clienttest import ClusterTestCase, crash, wait, skip_without
40 40
41 41 def setup():
42 42 add_engines(3, total=True)
43 43
44 44 class TestView(ClusterTestCase, ParametricTestCase):
45 45
46 46 def setUp(self):
47 47 # On Win XP, wait for resource cleanup, else parallel test group fails
48 48 if platform.system() == "Windows" and platform.win32_ver()[0] == "XP":
49 49 # 1 sec fails. 1.5 sec seems ok. Using 2 sec for margin of safety
50 50 time.sleep(2)
51 51 super(TestView, self).setUp()
52 52
53 53 def test_z_crash_mux(self):
54 54 """test graceful handling of engine death (direct)"""
55 55 raise SkipTest("crash tests disabled, due to undesirable crash reports")
56 56 # self.add_engines(1)
57 57 eid = self.client.ids[-1]
58 58 ar = self.client[eid].apply_async(crash)
59 59 self.assertRaisesRemote(error.EngineError, ar.get, 10)
60 60 eid = ar.engine_id
61 61 tic = time.time()
62 62 while eid in self.client.ids and time.time()-tic < 5:
63 63 time.sleep(.01)
64 64 self.client.spin()
65 65 self.assertFalse(eid in self.client.ids, "Engine should have died")
66 66
67 67 def test_push_pull(self):
68 68 """test pushing and pulling"""
69 69 data = dict(a=10, b=1.05, c=range(10), d={'e':(1,2),'f':'hi'})
70 70 t = self.client.ids[-1]
71 71 v = self.client[t]
72 72 push = v.push
73 73 pull = v.pull
74 74 v.block=True
75 75 nengines = len(self.client)
76 76 push({'data':data})
77 77 d = pull('data')
78 self.assertEquals(d, data)
78 self.assertEqual(d, data)
79 79 self.client[:].push({'data':data})
80 80 d = self.client[:].pull('data', block=True)
81 self.assertEquals(d, nengines*[data])
81 self.assertEqual(d, nengines*[data])
82 82 ar = push({'data':data}, block=False)
83 83 self.assertTrue(isinstance(ar, AsyncResult))
84 84 r = ar.get()
85 85 ar = self.client[:].pull('data', block=False)
86 86 self.assertTrue(isinstance(ar, AsyncResult))
87 87 r = ar.get()
88 self.assertEquals(r, nengines*[data])
88 self.assertEqual(r, nengines*[data])
89 89 self.client[:].push(dict(a=10,b=20))
90 90 r = self.client[:].pull(('a','b'), block=True)
91 self.assertEquals(r, nengines*[[10,20]])
91 self.assertEqual(r, nengines*[[10,20]])
92 92
93 93 def test_push_pull_function(self):
94 94 "test pushing and pulling functions"
95 95 def testf(x):
96 96 return 2.0*x
97 97
98 98 t = self.client.ids[-1]
99 99 v = self.client[t]
100 100 v.block=True
101 101 push = v.push
102 102 pull = v.pull
103 103 execute = v.execute
104 104 push({'testf':testf})
105 105 r = pull('testf')
106 106 self.assertEqual(r(1.0), testf(1.0))
107 107 execute('r = testf(10)')
108 108 r = pull('r')
109 self.assertEquals(r, testf(10))
109 self.assertEqual(r, testf(10))
110 110 ar = self.client[:].push({'testf':testf}, block=False)
111 111 ar.get()
112 112 ar = self.client[:].pull('testf', block=False)
113 113 rlist = ar.get()
114 114 for r in rlist:
115 115 self.assertEqual(r(1.0), testf(1.0))
116 116 execute("def g(x): return x*x")
117 117 r = pull(('testf','g'))
118 self.assertEquals((r[0](10),r[1](10)), (testf(10), 100))
118 self.assertEqual((r[0](10),r[1](10)), (testf(10), 100))
119 119
120 120 def test_push_function_globals(self):
121 121 """test that pushed functions have access to globals"""
122 122 @interactive
123 123 def geta():
124 124 return a
125 125 # self.add_engines(1)
126 126 v = self.client[-1]
127 127 v.block=True
128 128 v['f'] = geta
129 129 self.assertRaisesRemote(NameError, v.execute, 'b=f()')
130 130 v.execute('a=5')
131 131 v.execute('b=f()')
132 self.assertEquals(v['b'], 5)
132 self.assertEqual(v['b'], 5)
133 133
134 134 def test_push_function_defaults(self):
135 135 """test that pushed functions preserve default args"""
136 136 def echo(a=10):
137 137 return a
138 138 v = self.client[-1]
139 139 v.block=True
140 140 v['f'] = echo
141 141 v.execute('b=f()')
142 self.assertEquals(v['b'], 10)
142 self.assertEqual(v['b'], 10)
143 143
144 144 def test_get_result(self):
145 145 """test getting results from the Hub."""
146 146 c = pmod.Client(profile='iptest')
147 147 # self.add_engines(1)
148 148 t = c.ids[-1]
149 149 v = c[t]
150 150 v2 = self.client[t]
151 151 ar = v.apply_async(wait, 1)
152 152 # give the monitor time to notice the message
153 153 time.sleep(.25)
154 154 ahr = v2.get_result(ar.msg_ids)
155 155 self.assertTrue(isinstance(ahr, AsyncHubResult))
156 self.assertEquals(ahr.get(), ar.get())
156 self.assertEqual(ahr.get(), ar.get())
157 157 ar2 = v2.get_result(ar.msg_ids)
158 158 self.assertFalse(isinstance(ar2, AsyncHubResult))
159 159 c.spin()
160 160 c.close()
161 161
162 162 def test_run_newline(self):
163 163 """test that run appends newline to files"""
164 164 tmpfile = mktemp()
165 165 with open(tmpfile, 'w') as f:
166 166 f.write("""def g():
167 167 return 5
168 168 """)
169 169 v = self.client[-1]
170 170 v.run(tmpfile, block=True)
171 self.assertEquals(v.apply_sync(lambda f: f(), pmod.Reference('g')), 5)
171 self.assertEqual(v.apply_sync(lambda f: f(), pmod.Reference('g')), 5)
172 172
173 173 def test_apply_tracked(self):
174 174 """test tracking for apply"""
175 175 # self.add_engines(1)
176 176 t = self.client.ids[-1]
177 177 v = self.client[t]
178 178 v.block=False
179 179 def echo(n=1024*1024, **kwargs):
180 180 with v.temp_flags(**kwargs):
181 181 return v.apply(lambda x: x, 'x'*n)
182 182 ar = echo(1, track=False)
183 183 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
184 184 self.assertTrue(ar.sent)
185 185 ar = echo(track=True)
186 186 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
187 self.assertEquals(ar.sent, ar._tracker.done)
187 self.assertEqual(ar.sent, ar._tracker.done)
188 188 ar._tracker.wait()
189 189 self.assertTrue(ar.sent)
190 190
191 191 def test_push_tracked(self):
192 192 t = self.client.ids[-1]
193 193 ns = dict(x='x'*1024*1024)
194 194 v = self.client[t]
195 195 ar = v.push(ns, block=False, track=False)
196 196 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
197 197 self.assertTrue(ar.sent)
198 198
199 199 ar = v.push(ns, block=False, track=True)
200 200 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
201 201 ar._tracker.wait()
202 self.assertEquals(ar.sent, ar._tracker.done)
202 self.assertEqual(ar.sent, ar._tracker.done)
203 203 self.assertTrue(ar.sent)
204 204 ar.get()
205 205
206 206 def test_scatter_tracked(self):
207 207 t = self.client.ids
208 208 x='x'*1024*1024
209 209 ar = self.client[t].scatter('x', x, block=False, track=False)
210 210 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
211 211 self.assertTrue(ar.sent)
212 212
213 213 ar = self.client[t].scatter('x', x, block=False, track=True)
214 214 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
215 self.assertEquals(ar.sent, ar._tracker.done)
215 self.assertEqual(ar.sent, ar._tracker.done)
216 216 ar._tracker.wait()
217 217 self.assertTrue(ar.sent)
218 218 ar.get()
219 219
220 220 def test_remote_reference(self):
221 221 v = self.client[-1]
222 222 v['a'] = 123
223 223 ra = pmod.Reference('a')
224 224 b = v.apply_sync(lambda x: x, ra)
225 self.assertEquals(b, 123)
225 self.assertEqual(b, 123)
226 226
227 227
228 228 def test_scatter_gather(self):
229 229 view = self.client[:]
230 230 seq1 = range(16)
231 231 view.scatter('a', seq1)
232 232 seq2 = view.gather('a', block=True)
233 self.assertEquals(seq2, seq1)
233 self.assertEqual(seq2, seq1)
234 234 self.assertRaisesRemote(NameError, view.gather, 'asdf', block=True)
235 235
236 236 @skip_without('numpy')
237 237 def test_scatter_gather_numpy(self):
238 238 import numpy
239 239 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
240 240 view = self.client[:]
241 241 a = numpy.arange(64)
242 242 view.scatter('a', a)
243 243 b = view.gather('a', block=True)
244 244 assert_array_equal(b, a)
245 245
246 246 def test_scatter_gather_lazy(self):
247 247 """scatter/gather with targets='all'"""
248 248 view = self.client.direct_view(targets='all')
249 249 x = range(64)
250 250 view.scatter('x', x)
251 251 gathered = view.gather('x', block=True)
252 self.assertEquals(gathered, x)
252 self.assertEqual(gathered, x)
253 253
254 254
255 255 @dec.known_failure_py3
256 256 @skip_without('numpy')
257 257 def test_push_numpy_nocopy(self):
258 258 import numpy
259 259 view = self.client[:]
260 260 a = numpy.arange(64)
261 261 view['A'] = a
262 262 @interactive
263 263 def check_writeable(x):
264 264 return x.flags.writeable
265 265
266 266 for flag in view.apply_sync(check_writeable, pmod.Reference('A')):
267 267 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
268 268
269 269 view.push(dict(B=a))
270 270 for flag in view.apply_sync(check_writeable, pmod.Reference('B')):
271 271 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
272 272
273 273 @skip_without('numpy')
274 274 def test_apply_numpy(self):
275 275 """view.apply(f, ndarray)"""
276 276 import numpy
277 277 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
278 278
279 279 A = numpy.random.random((100,100))
280 280 view = self.client[-1]
281 281 for dt in [ 'int32', 'uint8', 'float32', 'float64' ]:
282 282 B = A.astype(dt)
283 283 C = view.apply_sync(lambda x:x, B)
284 284 assert_array_equal(B,C)
285 285
286 286 @skip_without('numpy')
287 287 def test_push_pull_recarray(self):
288 288 """push/pull recarrays"""
289 289 import numpy
290 290 from numpy.testing.utils import assert_array_equal
291 291
292 292 view = self.client[-1]
293 293
294 294 R = numpy.array([
295 295 (1, 'hi', 0.),
296 296 (2**30, 'there', 2.5),
297 297 (-99999, 'world', -12345.6789),
298 298 ], [('n', int), ('s', '|S10'), ('f', float)])
299 299
300 300 view['RR'] = R
301 301 R2 = view['RR']
302 302
303 303 r_dtype, r_shape = view.apply_sync(interactive(lambda : (RR.dtype, RR.shape)))
304 self.assertEquals(r_dtype, R.dtype)
305 self.assertEquals(r_shape, R.shape)
306 self.assertEquals(R2.dtype, R.dtype)
307 self.assertEquals(R2.shape, R.shape)
304 self.assertEqual(r_dtype, R.dtype)
305 self.assertEqual(r_shape, R.shape)
306 self.assertEqual(R2.dtype, R.dtype)
307 self.assertEqual(R2.shape, R.shape)
308 308 assert_array_equal(R2, R)
309 309
310 310 def test_map(self):
311 311 view = self.client[:]
312 312 def f(x):
313 313 return x**2
314 314 data = range(16)
315 315 r = view.map_sync(f, data)
316 self.assertEquals(r, map(f, data))
316 self.assertEqual(r, map(f, data))
317 317
318 318 def test_map_iterable(self):
319 319 """test map on iterables (direct)"""
320 320 view = self.client[:]
321 321 # 101 is prime, so it won't be evenly distributed
322 322 arr = range(101)
323 323 # ensure it will be an iterator, even in Python 3
324 324 it = iter(arr)
325 325 r = view.map_sync(lambda x:x, arr)
326 self.assertEquals(r, list(arr))
326 self.assertEqual(r, list(arr))
327 327
328 328 def test_scatterGatherNonblocking(self):
329 329 data = range(16)
330 330 view = self.client[:]
331 331 view.scatter('a', data, block=False)
332 332 ar = view.gather('a', block=False)
333 self.assertEquals(ar.get(), data)
333 self.assertEqual(ar.get(), data)
334 334
335 335 @skip_without('numpy')
336 336 def test_scatter_gather_numpy_nonblocking(self):
337 337 import numpy
338 338 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
339 339 a = numpy.arange(64)
340 340 view = self.client[:]
341 341 ar = view.scatter('a', a, block=False)
342 342 self.assertTrue(isinstance(ar, AsyncResult))
343 343 amr = view.gather('a', block=False)
344 344 self.assertTrue(isinstance(amr, AsyncMapResult))
345 345 assert_array_equal(amr.get(), a)
346 346
347 347 def test_execute(self):
348 348 view = self.client[:]
349 349 # self.client.debug=True
350 350 execute = view.execute
351 351 ar = execute('c=30', block=False)
352 352 self.assertTrue(isinstance(ar, AsyncResult))
353 353 ar = execute('d=[0,1,2]', block=False)
354 354 self.client.wait(ar, 1)
355 self.assertEquals(len(ar.get()), len(self.client))
355 self.assertEqual(len(ar.get()), len(self.client))
356 356 for c in view['c']:
357 self.assertEquals(c, 30)
357 self.assertEqual(c, 30)
358 358
359 359 def test_abort(self):
360 360 view = self.client[-1]
361 361 ar = view.execute('import time; time.sleep(1)', block=False)
362 362 ar2 = view.apply_async(lambda : 2)
363 363 ar3 = view.apply_async(lambda : 3)
364 364 view.abort(ar2)
365 365 view.abort(ar3.msg_ids)
366 366 self.assertRaises(error.TaskAborted, ar2.get)
367 367 self.assertRaises(error.TaskAborted, ar3.get)
368 368
369 369 def test_abort_all(self):
370 370 """view.abort() aborts all outstanding tasks"""
371 371 view = self.client[-1]
372 372 ars = [ view.apply_async(time.sleep, 0.25) for i in range(10) ]
373 373 view.abort()
374 374 view.wait(timeout=5)
375 375 for ar in ars[5:]:
376 376 self.assertRaises(error.TaskAborted, ar.get)
377 377
378 378 def test_temp_flags(self):
379 379 view = self.client[-1]
380 380 view.block=True
381 381 with view.temp_flags(block=False):
382 382 self.assertFalse(view.block)
383 383 self.assertTrue(view.block)
384 384
385 385 @dec.known_failure_py3
386 386 def test_importer(self):
387 387 view = self.client[-1]
388 388 view.clear(block=True)
389 389 with view.importer:
390 390 import re
391 391
392 392 @interactive
393 393 def findall(pat, s):
394 394 # this globals() step isn't necessary in real code
395 395 # only to prevent a closure in the test
396 396 re = globals()['re']
397 397 return re.findall(pat, s)
398 398
399 self.assertEquals(view.apply_sync(findall, '\w+', 'hello world'), 'hello world'.split())
399 self.assertEqual(view.apply_sync(findall, '\w+', 'hello world'), 'hello world'.split())
400 400
401 401 def test_unicode_execute(self):
402 402 """test executing unicode strings"""
403 403 v = self.client[-1]
404 404 v.block=True
405 405 if sys.version_info[0] >= 3:
406 406 code="a='é'"
407 407 else:
408 408 code=u"a=u'é'"
409 409 v.execute(code)
410 self.assertEquals(v['a'], u'é')
410 self.assertEqual(v['a'], u'é')
411 411
412 412 def test_unicode_apply_result(self):
413 413 """test unicode apply results"""
414 414 v = self.client[-1]
415 415 r = v.apply_sync(lambda : u'é')
416 self.assertEquals(r, u'é')
416 self.assertEqual(r, u'é')
417 417
418 418 def test_unicode_apply_arg(self):
419 419 """test passing unicode arguments to apply"""
420 420 v = self.client[-1]
421 421
422 422 @interactive
423 423 def check_unicode(a, check):
424 424 assert isinstance(a, unicode), "%r is not unicode"%a
425 425 assert isinstance(check, bytes), "%r is not bytes"%check
426 426 assert a.encode('utf8') == check, "%s != %s"%(a,check)
427 427
428 428 for s in [ u'é', u'ßø®∫',u'asdf' ]:
429 429 try:
430 430 v.apply_sync(check_unicode, s, s.encode('utf8'))
431 431 except error.RemoteError as e:
432 432 if e.ename == 'AssertionError':
433 433 self.fail(e.evalue)
434 434 else:
435 435 raise e
436 436
437 437 def test_map_reference(self):
438 438 """view.map(<Reference>, *seqs) should work"""
439 439 v = self.client[:]
440 440 v.scatter('n', self.client.ids, flatten=True)
441 441 v.execute("f = lambda x,y: x*y")
442 442 rf = pmod.Reference('f')
443 443 nlist = list(range(10))
444 444 mlist = nlist[::-1]
445 445 expected = [ m*n for m,n in zip(mlist, nlist) ]
446 446 result = v.map_sync(rf, mlist, nlist)
447 self.assertEquals(result, expected)
447 self.assertEqual(result, expected)
448 448
449 449 def test_apply_reference(self):
450 450 """view.apply(<Reference>, *args) should work"""
451 451 v = self.client[:]
452 452 v.scatter('n', self.client.ids, flatten=True)
453 453 v.execute("f = lambda x: n*x")
454 454 rf = pmod.Reference('f')
455 455 result = v.apply_sync(rf, 5)
456 456 expected = [ 5*id for id in self.client.ids ]
457 self.assertEquals(result, expected)
457 self.assertEqual(result, expected)
458 458
459 459 def test_eval_reference(self):
460 460 v = self.client[self.client.ids[0]]
461 461 v['g'] = range(5)
462 462 rg = pmod.Reference('g[0]')
463 463 echo = lambda x:x
464 self.assertEquals(v.apply_sync(echo, rg), 0)
464 self.assertEqual(v.apply_sync(echo, rg), 0)
465 465
466 466 def test_reference_nameerror(self):
467 467 v = self.client[self.client.ids[0]]
468 468 r = pmod.Reference('elvis_has_left')
469 469 echo = lambda x:x
470 470 self.assertRaisesRemote(NameError, v.apply_sync, echo, r)
471 471
472 472 def test_single_engine_map(self):
473 473 e0 = self.client[self.client.ids[0]]
474 474 r = range(5)
475 475 check = [ -1*i for i in r ]
476 476 result = e0.map_sync(lambda x: -1*x, r)
477 self.assertEquals(result, check)
477 self.assertEqual(result, check)
478 478
479 479 def test_len(self):
480 480 """len(view) makes sense"""
481 481 e0 = self.client[self.client.ids[0]]
482 yield self.assertEquals(len(e0), 1)
482 yield self.assertEqual(len(e0), 1)
483 483 v = self.client[:]
484 yield self.assertEquals(len(v), len(self.client.ids))
484 yield self.assertEqual(len(v), len(self.client.ids))
485 485 v = self.client.direct_view('all')
486 yield self.assertEquals(len(v), len(self.client.ids))
486 yield self.assertEqual(len(v), len(self.client.ids))
487 487 v = self.client[:2]
488 yield self.assertEquals(len(v), 2)
488 yield self.assertEqual(len(v), 2)
489 489 v = self.client[:1]
490 yield self.assertEquals(len(v), 1)
490 yield self.assertEqual(len(v), 1)
491 491 v = self.client.load_balanced_view()
492 yield self.assertEquals(len(v), len(self.client.ids))
492 yield self.assertEqual(len(v), len(self.client.ids))
493 493 # parametric tests seem to require manual closing?
494 494 self.client.close()
495 495
496 496
497 497 # begin execute tests
498 498
499 499 def test_execute_reply(self):
500 500 e0 = self.client[self.client.ids[0]]
501 501 e0.block = True
502 502 ar = e0.execute("5", silent=False)
503 503 er = ar.get()
504 self.assertEquals(str(er), "<ExecuteReply[%i]: 5>" % er.execution_count)
505 self.assertEquals(er.pyout['data']['text/plain'], '5')
504 self.assertEqual(str(er), "<ExecuteReply[%i]: 5>" % er.execution_count)
505 self.assertEqual(er.pyout['data']['text/plain'], '5')
506 506
507 507 def test_execute_reply_stdout(self):
508 508 e0 = self.client[self.client.ids[0]]
509 509 e0.block = True
510 510 ar = e0.execute("print (5)", silent=False)
511 511 er = ar.get()
512 self.assertEquals(er.stdout.strip(), '5')
512 self.assertEqual(er.stdout.strip(), '5')
513 513
514 514 def test_execute_pyout(self):
515 515 """execute triggers pyout with silent=False"""
516 516 view = self.client[:]
517 517 ar = view.execute("5", silent=False, block=True)
518 518
519 519 expected = [{'text/plain' : '5'}] * len(view)
520 520 mimes = [ out['data'] for out in ar.pyout ]
521 self.assertEquals(mimes, expected)
521 self.assertEqual(mimes, expected)
522 522
523 523 def test_execute_silent(self):
524 524 """execute does not trigger pyout with silent=True"""
525 525 view = self.client[:]
526 526 ar = view.execute("5", block=True)
527 527 expected = [None] * len(view)
528 self.assertEquals(ar.pyout, expected)
528 self.assertEqual(ar.pyout, expected)
529 529
530 530 def test_execute_magic(self):
531 531 """execute accepts IPython commands"""
532 532 view = self.client[:]
533 533 view.execute("a = 5")
534 534 ar = view.execute("%whos", block=True)
535 535 # this will raise, if that failed
536 536 ar.get(5)
537 537 for stdout in ar.stdout:
538 538 lines = stdout.splitlines()
539 self.assertEquals(lines[0].split(), ['Variable', 'Type', 'Data/Info'])
539 self.assertEqual(lines[0].split(), ['Variable', 'Type', 'Data/Info'])
540 540 found = False
541 541 for line in lines[2:]:
542 542 split = line.split()
543 543 if split == ['a', 'int', '5']:
544 544 found = True
545 545 break
546 546 self.assertTrue(found, "whos output wrong: %s" % stdout)
547 547
548 548 def test_execute_displaypub(self):
549 549 """execute tracks display_pub output"""
550 550 view = self.client[:]
551 551 view.execute("from IPython.core.display import *")
552 552 ar = view.execute("[ display(i) for i in range(5) ]", block=True)
553 553
554 554 expected = [ {u'text/plain' : unicode(j)} for j in range(5) ]
555 555 for outputs in ar.outputs:
556 556 mimes = [ out['data'] for out in outputs ]
557 self.assertEquals(mimes, expected)
557 self.assertEqual(mimes, expected)
558 558
559 559 def test_apply_displaypub(self):
560 560 """apply tracks display_pub output"""
561 561 view = self.client[:]
562 562 view.execute("from IPython.core.display import *")
563 563
564 564 @interactive
565 565 def publish():
566 566 [ display(i) for i in range(5) ]
567 567
568 568 ar = view.apply_async(publish)
569 569 ar.get(5)
570 570 expected = [ {u'text/plain' : unicode(j)} for j in range(5) ]
571 571 for outputs in ar.outputs:
572 572 mimes = [ out['data'] for out in outputs ]
573 self.assertEquals(mimes, expected)
573 self.assertEqual(mimes, expected)
574 574
575 575 def test_execute_raises(self):
576 576 """exceptions in execute requests raise appropriately"""
577 577 view = self.client[-1]
578 578 ar = view.execute("1/0")
579 579 self.assertRaisesRemote(ZeroDivisionError, ar.get, 2)
580 580
581 581 @dec.skipif_not_matplotlib
582 582 def test_magic_pylab(self):
583 583 """%pylab works on engines"""
584 584 view = self.client[-1]
585 585 ar = view.execute("%pylab inline")
586 586 # at least check if this raised:
587 587 reply = ar.get(5)
588 588 # include imports, in case user config
589 589 ar = view.execute("plot(rand(100))", silent=False)
590 590 reply = ar.get(5)
591 self.assertEquals(len(reply.outputs), 1)
591 self.assertEqual(len(reply.outputs), 1)
592 592 output = reply.outputs[0]
593 593 self.assertTrue("data" in output)
594 594 data = output['data']
595 595 self.assertTrue("image/png" in data)
596 596
597 597
@@ -1,131 +1,131 b''
1 1 # encoding: utf-8
2 2 """
3 3 Tests for platutils.py
4 4 """
5 5
6 6 #-----------------------------------------------------------------------------
7 7 # Copyright (C) 2008-2011 The IPython Development Team
8 8 #
9 9 # Distributed under the terms of the BSD License. The full license is in
10 10 # the file COPYING, distributed as part of this software.
11 11 #-----------------------------------------------------------------------------
12 12
13 13 #-----------------------------------------------------------------------------
14 14 # Imports
15 15 #-----------------------------------------------------------------------------
16 16
17 17 import sys
18 18 from unittest import TestCase
19 19
20 20 import nose.tools as nt
21 21
22 22 from IPython.utils.process import (find_cmd, FindCmdError, arg_split,
23 23 system, getoutput, getoutputerror)
24 24 from IPython.testing import decorators as dec
25 25 from IPython.testing import tools as tt
26 26
27 27 #-----------------------------------------------------------------------------
28 28 # Tests
29 29 #-----------------------------------------------------------------------------
30 30
31 31 def test_find_cmd_python():
32 32 """Make sure we find sys.exectable for python."""
33 33 nt.assert_equals(find_cmd('python'), sys.executable)
34 34
35 35
36 36 @dec.skip_win32
37 37 def test_find_cmd_ls():
38 38 """Make sure we can find the full path to ls."""
39 39 path = find_cmd('ls')
40 40 nt.assert_true(path.endswith('ls'))
41 41
42 42
43 43 def has_pywin32():
44 44 try:
45 45 import win32api
46 46 except ImportError:
47 47 return False
48 48 return True
49 49
50 50
51 51 @dec.onlyif(has_pywin32, "This test requires win32api to run")
52 52 def test_find_cmd_pythonw():
53 53 """Try to find pythonw on Windows."""
54 54 path = find_cmd('pythonw')
55 55 nt.assert_true(path.endswith('pythonw.exe'))
56 56
57 57
58 58 @dec.onlyif(lambda : sys.platform != 'win32' or has_pywin32(),
59 59 "This test runs on posix or in win32 with win32api installed")
60 60 def test_find_cmd_fail():
61 61 """Make sure that FindCmdError is raised if we can't find the cmd."""
62 62 nt.assert_raises(FindCmdError,find_cmd,'asdfasdf')
63 63
64 64
65 65 @dec.skip_win32
66 66 def test_arg_split():
67 67 """Ensure that argument lines are correctly split like in a shell."""
68 68 tests = [['hi', ['hi']],
69 69 [u'hi', [u'hi']],
70 70 ['hello there', ['hello', 'there']],
71 71 # \u01ce == \N{LATIN SMALL LETTER A WITH CARON}
72 72 # Do not use \N because the tests crash with syntax error in
73 73 # some cases, for example windows python2.6.
74 74 [u'h\u01cello', [u'h\u01cello']],
75 75 ['something "with quotes"', ['something', '"with quotes"']],
76 76 ]
77 77 for argstr, argv in tests:
78 78 nt.assert_equal(arg_split(argstr), argv)
79 79
80 80 @dec.skip_if_not_win32
81 81 def test_arg_split_win32():
82 82 """Ensure that argument lines are correctly split like in a shell."""
83 83 tests = [['hi', ['hi']],
84 84 [u'hi', [u'hi']],
85 85 ['hello there', ['hello', 'there']],
86 86 [u'h\u01cello', [u'h\u01cello']],
87 87 ['something "with quotes"', ['something', 'with quotes']],
88 88 ]
89 89 for argstr, argv in tests:
90 90 nt.assert_equal(arg_split(argstr), argv)
91 91
92 92
93 93 class SubProcessTestCase(TestCase, tt.TempFileMixin):
94 94 def setUp(self):
95 95 """Make a valid python temp file."""
96 96 lines = ["from __future__ import print_function",
97 97 "import sys",
98 98 "print('on stdout', end='', file=sys.stdout)",
99 99 "print('on stderr', end='', file=sys.stderr)",
100 100 "sys.stdout.flush()",
101 101 "sys.stderr.flush()"]
102 102 self.mktmp('\n'.join(lines))
103 103
104 104 def test_system(self):
105 105 status = system('python "%s"' % self.fname)
106 self.assertEquals(status, 0)
106 self.assertEqual(status, 0)
107 107
108 108 def test_system_quotes(self):
109 109 status = system('python -c "import sys"')
110 self.assertEquals(status, 0)
110 self.assertEqual(status, 0)
111 111
112 112 def test_getoutput(self):
113 113 out = getoutput('python "%s"' % self.fname)
114 self.assertEquals(out, 'on stdout')
114 self.assertEqual(out, 'on stdout')
115 115
116 116 def test_getoutput_quoted(self):
117 117 out = getoutput('python -c "print (1)"')
118 self.assertEquals(out.strip(), '1')
118 self.assertEqual(out.strip(), '1')
119 119
120 120 #Invalid quoting on windows
121 121 @dec.skip_win32
122 122 def test_getoutput_quoted2(self):
123 123 out = getoutput("python -c 'print (1)'")
124 self.assertEquals(out.strip(), '1')
124 self.assertEqual(out.strip(), '1')
125 125 out = getoutput("python -c 'print (\"1\")'")
126 self.assertEquals(out.strip(), '1')
126 self.assertEqual(out.strip(), '1')
127 127
128 128 def test_getoutput(self):
129 129 out, err = getoutputerror('python "%s"' % self.fname)
130 self.assertEquals(out, 'on stdout')
131 self.assertEquals(err, 'on stderr')
130 self.assertEqual(out, 'on stdout')
131 self.assertEqual(err, 'on stderr')
@@ -1,908 +1,908 b''
1 1 # encoding: utf-8
2 2 """
3 3 Tests for IPython.utils.traitlets.
4 4
5 5 Authors:
6 6
7 7 * Brian Granger
8 8 * Enthought, Inc. Some of the code in this file comes from enthought.traits
9 9 and is licensed under the BSD license. Also, many of the ideas also come
10 10 from enthought.traits even though our implementation is very different.
11 11 """
12 12
13 13 #-----------------------------------------------------------------------------
14 14 # Copyright (C) 2008-2011 The IPython Development Team
15 15 #
16 16 # Distributed under the terms of the BSD License. The full license is in
17 17 # the file COPYING, distributed as part of this software.
18 18 #-----------------------------------------------------------------------------
19 19
20 20 #-----------------------------------------------------------------------------
21 21 # Imports
22 22 #-----------------------------------------------------------------------------
23 23
24 24 import re
25 25 import sys
26 26 from unittest import TestCase
27 27
28 28 from nose import SkipTest
29 29
30 30 from IPython.utils.traitlets import (
31 31 HasTraits, MetaHasTraits, TraitType, Any, CBytes,
32 32 Int, Long, Integer, Float, Complex, Bytes, Unicode, TraitError,
33 33 Undefined, Type, This, Instance, TCPAddress, List, Tuple,
34 34 ObjectName, DottedObjectName, CRegExp
35 35 )
36 36 from IPython.utils import py3compat
37 37 from IPython.testing.decorators import skipif
38 38
39 39 #-----------------------------------------------------------------------------
40 40 # Helper classes for testing
41 41 #-----------------------------------------------------------------------------
42 42
43 43
44 44 class HasTraitsStub(HasTraits):
45 45
46 46 def _notify_trait(self, name, old, new):
47 47 self._notify_name = name
48 48 self._notify_old = old
49 49 self._notify_new = new
50 50
51 51
52 52 #-----------------------------------------------------------------------------
53 53 # Test classes
54 54 #-----------------------------------------------------------------------------
55 55
56 56
57 57 class TestTraitType(TestCase):
58 58
59 59 def test_get_undefined(self):
60 60 class A(HasTraits):
61 61 a = TraitType
62 62 a = A()
63 self.assertEquals(a.a, Undefined)
63 self.assertEqual(a.a, Undefined)
64 64
65 65 def test_set(self):
66 66 class A(HasTraitsStub):
67 67 a = TraitType
68 68
69 69 a = A()
70 70 a.a = 10
71 self.assertEquals(a.a, 10)
72 self.assertEquals(a._notify_name, 'a')
73 self.assertEquals(a._notify_old, Undefined)
74 self.assertEquals(a._notify_new, 10)
71 self.assertEqual(a.a, 10)
72 self.assertEqual(a._notify_name, 'a')
73 self.assertEqual(a._notify_old, Undefined)
74 self.assertEqual(a._notify_new, 10)
75 75
76 76 def test_validate(self):
77 77 class MyTT(TraitType):
78 78 def validate(self, inst, value):
79 79 return -1
80 80 class A(HasTraitsStub):
81 81 tt = MyTT
82 82
83 83 a = A()
84 84 a.tt = 10
85 self.assertEquals(a.tt, -1)
85 self.assertEqual(a.tt, -1)
86 86
87 87 def test_default_validate(self):
88 88 class MyIntTT(TraitType):
89 89 def validate(self, obj, value):
90 90 if isinstance(value, int):
91 91 return value
92 92 self.error(obj, value)
93 93 class A(HasTraits):
94 94 tt = MyIntTT(10)
95 95 a = A()
96 self.assertEquals(a.tt, 10)
96 self.assertEqual(a.tt, 10)
97 97
98 98 # Defaults are validated when the HasTraits is instantiated
99 99 class B(HasTraits):
100 100 tt = MyIntTT('bad default')
101 101 self.assertRaises(TraitError, B)
102 102
103 103 def test_is_valid_for(self):
104 104 class MyTT(TraitType):
105 105 def is_valid_for(self, value):
106 106 return True
107 107 class A(HasTraits):
108 108 tt = MyTT
109 109
110 110 a = A()
111 111 a.tt = 10
112 self.assertEquals(a.tt, 10)
112 self.assertEqual(a.tt, 10)
113 113
114 114 def test_value_for(self):
115 115 class MyTT(TraitType):
116 116 def value_for(self, value):
117 117 return 20
118 118 class A(HasTraits):
119 119 tt = MyTT
120 120
121 121 a = A()
122 122 a.tt = 10
123 self.assertEquals(a.tt, 20)
123 self.assertEqual(a.tt, 20)
124 124
125 125 def test_info(self):
126 126 class A(HasTraits):
127 127 tt = TraitType
128 128 a = A()
129 self.assertEquals(A.tt.info(), 'any value')
129 self.assertEqual(A.tt.info(), 'any value')
130 130
131 131 def test_error(self):
132 132 class A(HasTraits):
133 133 tt = TraitType
134 134 a = A()
135 135 self.assertRaises(TraitError, A.tt.error, a, 10)
136 136
137 137 def test_dynamic_initializer(self):
138 138 class A(HasTraits):
139 139 x = Int(10)
140 140 def _x_default(self):
141 141 return 11
142 142 class B(A):
143 143 x = Int(20)
144 144 class C(A):
145 145 def _x_default(self):
146 146 return 21
147 147
148 148 a = A()
149 self.assertEquals(a._trait_values, {})
150 self.assertEquals(a._trait_dyn_inits.keys(), ['x'])
151 self.assertEquals(a.x, 11)
152 self.assertEquals(a._trait_values, {'x': 11})
149 self.assertEqual(a._trait_values, {})
150 self.assertEqual(a._trait_dyn_inits.keys(), ['x'])
151 self.assertEqual(a.x, 11)
152 self.assertEqual(a._trait_values, {'x': 11})
153 153 b = B()
154 self.assertEquals(b._trait_values, {'x': 20})
155 self.assertEquals(a._trait_dyn_inits.keys(), ['x'])
156 self.assertEquals(b.x, 20)
154 self.assertEqual(b._trait_values, {'x': 20})
155 self.assertEqual(a._trait_dyn_inits.keys(), ['x'])
156 self.assertEqual(b.x, 20)
157 157 c = C()
158 self.assertEquals(c._trait_values, {})
159 self.assertEquals(a._trait_dyn_inits.keys(), ['x'])
160 self.assertEquals(c.x, 21)
161 self.assertEquals(c._trait_values, {'x': 21})
158 self.assertEqual(c._trait_values, {})
159 self.assertEqual(a._trait_dyn_inits.keys(), ['x'])
160 self.assertEqual(c.x, 21)
161 self.assertEqual(c._trait_values, {'x': 21})
162 162 # Ensure that the base class remains unmolested when the _default
163 163 # initializer gets overridden in a subclass.
164 164 a = A()
165 165 c = C()
166 self.assertEquals(a._trait_values, {})
167 self.assertEquals(a._trait_dyn_inits.keys(), ['x'])
168 self.assertEquals(a.x, 11)
169 self.assertEquals(a._trait_values, {'x': 11})
166 self.assertEqual(a._trait_values, {})
167 self.assertEqual(a._trait_dyn_inits.keys(), ['x'])
168 self.assertEqual(a.x, 11)
169 self.assertEqual(a._trait_values, {'x': 11})
170 170
171 171
172 172
173 173 class TestHasTraitsMeta(TestCase):
174 174
175 175 def test_metaclass(self):
176 self.assertEquals(type(HasTraits), MetaHasTraits)
176 self.assertEqual(type(HasTraits), MetaHasTraits)
177 177
178 178 class A(HasTraits):
179 179 a = Int
180 180
181 181 a = A()
182 self.assertEquals(type(a.__class__), MetaHasTraits)
183 self.assertEquals(a.a,0)
182 self.assertEqual(type(a.__class__), MetaHasTraits)
183 self.assertEqual(a.a,0)
184 184 a.a = 10
185 self.assertEquals(a.a,10)
185 self.assertEqual(a.a,10)
186 186
187 187 class B(HasTraits):
188 188 b = Int()
189 189
190 190 b = B()
191 self.assertEquals(b.b,0)
191 self.assertEqual(b.b,0)
192 192 b.b = 10
193 self.assertEquals(b.b,10)
193 self.assertEqual(b.b,10)
194 194
195 195 class C(HasTraits):
196 196 c = Int(30)
197 197
198 198 c = C()
199 self.assertEquals(c.c,30)
199 self.assertEqual(c.c,30)
200 200 c.c = 10
201 self.assertEquals(c.c,10)
201 self.assertEqual(c.c,10)
202 202
203 203 def test_this_class(self):
204 204 class A(HasTraits):
205 205 t = This()
206 206 tt = This()
207 207 class B(A):
208 208 tt = This()
209 209 ttt = This()
210 self.assertEquals(A.t.this_class, A)
211 self.assertEquals(B.t.this_class, A)
212 self.assertEquals(B.tt.this_class, B)
213 self.assertEquals(B.ttt.this_class, B)
210 self.assertEqual(A.t.this_class, A)
211 self.assertEqual(B.t.this_class, A)
212 self.assertEqual(B.tt.this_class, B)
213 self.assertEqual(B.ttt.this_class, B)
214 214
215 215 class TestHasTraitsNotify(TestCase):
216 216
217 217 def setUp(self):
218 218 self._notify1 = []
219 219 self._notify2 = []
220 220
221 221 def notify1(self, name, old, new):
222 222 self._notify1.append((name, old, new))
223 223
224 224 def notify2(self, name, old, new):
225 225 self._notify2.append((name, old, new))
226 226
227 227 def test_notify_all(self):
228 228
229 229 class A(HasTraits):
230 230 a = Int
231 231 b = Float
232 232
233 233 a = A()
234 234 a.on_trait_change(self.notify1)
235 235 a.a = 0
236 self.assertEquals(len(self._notify1),0)
236 self.assertEqual(len(self._notify1),0)
237 237 a.b = 0.0
238 self.assertEquals(len(self._notify1),0)
238 self.assertEqual(len(self._notify1),0)
239 239 a.a = 10
240 240 self.assert_(('a',0,10) in self._notify1)
241 241 a.b = 10.0
242 242 self.assert_(('b',0.0,10.0) in self._notify1)
243 243 self.assertRaises(TraitError,setattr,a,'a','bad string')
244 244 self.assertRaises(TraitError,setattr,a,'b','bad string')
245 245 self._notify1 = []
246 246 a.on_trait_change(self.notify1,remove=True)
247 247 a.a = 20
248 248 a.b = 20.0
249 self.assertEquals(len(self._notify1),0)
249 self.assertEqual(len(self._notify1),0)
250 250
251 251 def test_notify_one(self):
252 252
253 253 class A(HasTraits):
254 254 a = Int
255 255 b = Float
256 256
257 257 a = A()
258 258 a.on_trait_change(self.notify1, 'a')
259 259 a.a = 0
260 self.assertEquals(len(self._notify1),0)
260 self.assertEqual(len(self._notify1),0)
261 261 a.a = 10
262 262 self.assert_(('a',0,10) in self._notify1)
263 263 self.assertRaises(TraitError,setattr,a,'a','bad string')
264 264
265 265 def test_subclass(self):
266 266
267 267 class A(HasTraits):
268 268 a = Int
269 269
270 270 class B(A):
271 271 b = Float
272 272
273 273 b = B()
274 self.assertEquals(b.a,0)
275 self.assertEquals(b.b,0.0)
274 self.assertEqual(b.a,0)
275 self.assertEqual(b.b,0.0)
276 276 b.a = 100
277 277 b.b = 100.0
278 self.assertEquals(b.a,100)
279 self.assertEquals(b.b,100.0)
278 self.assertEqual(b.a,100)
279 self.assertEqual(b.b,100.0)
280 280
281 281 def test_notify_subclass(self):
282 282
283 283 class A(HasTraits):
284 284 a = Int
285 285
286 286 class B(A):
287 287 b = Float
288 288
289 289 b = B()
290 290 b.on_trait_change(self.notify1, 'a')
291 291 b.on_trait_change(self.notify2, 'b')
292 292 b.a = 0
293 293 b.b = 0.0
294 self.assertEquals(len(self._notify1),0)
295 self.assertEquals(len(self._notify2),0)
294 self.assertEqual(len(self._notify1),0)
295 self.assertEqual(len(self._notify2),0)
296 296 b.a = 10
297 297 b.b = 10.0
298 298 self.assert_(('a',0,10) in self._notify1)
299 299 self.assert_(('b',0.0,10.0) in self._notify2)
300 300
301 301 def test_static_notify(self):
302 302
303 303 class A(HasTraits):
304 304 a = Int
305 305 _notify1 = []
306 306 def _a_changed(self, name, old, new):
307 307 self._notify1.append((name, old, new))
308 308
309 309 a = A()
310 310 a.a = 0
311 311 # This is broken!!!
312 self.assertEquals(len(a._notify1),0)
312 self.assertEqual(len(a._notify1),0)
313 313 a.a = 10
314 314 self.assert_(('a',0,10) in a._notify1)
315 315
316 316 class B(A):
317 317 b = Float
318 318 _notify2 = []
319 319 def _b_changed(self, name, old, new):
320 320 self._notify2.append((name, old, new))
321 321
322 322 b = B()
323 323 b.a = 10
324 324 b.b = 10.0
325 325 self.assert_(('a',0,10) in b._notify1)
326 326 self.assert_(('b',0.0,10.0) in b._notify2)
327 327
328 328 def test_notify_args(self):
329 329
330 330 def callback0():
331 331 self.cb = ()
332 332 def callback1(name):
333 333 self.cb = (name,)
334 334 def callback2(name, new):
335 335 self.cb = (name, new)
336 336 def callback3(name, old, new):
337 337 self.cb = (name, old, new)
338 338
339 339 class A(HasTraits):
340 340 a = Int
341 341
342 342 a = A()
343 343 a.on_trait_change(callback0, 'a')
344 344 a.a = 10
345 self.assertEquals(self.cb,())
345 self.assertEqual(self.cb,())
346 346 a.on_trait_change(callback0, 'a', remove=True)
347 347
348 348 a.on_trait_change(callback1, 'a')
349 349 a.a = 100
350 self.assertEquals(self.cb,('a',))
350 self.assertEqual(self.cb,('a',))
351 351 a.on_trait_change(callback1, 'a', remove=True)
352 352
353 353 a.on_trait_change(callback2, 'a')
354 354 a.a = 1000
355 self.assertEquals(self.cb,('a',1000))
355 self.assertEqual(self.cb,('a',1000))
356 356 a.on_trait_change(callback2, 'a', remove=True)
357 357
358 358 a.on_trait_change(callback3, 'a')
359 359 a.a = 10000
360 self.assertEquals(self.cb,('a',1000,10000))
360 self.assertEqual(self.cb,('a',1000,10000))
361 361 a.on_trait_change(callback3, 'a', remove=True)
362 362
363 self.assertEquals(len(a._trait_notifiers['a']),0)
363 self.assertEqual(len(a._trait_notifiers['a']),0)
364 364
365 365
366 366 class TestHasTraits(TestCase):
367 367
368 368 def test_trait_names(self):
369 369 class A(HasTraits):
370 370 i = Int
371 371 f = Float
372 372 a = A()
373 self.assertEquals(sorted(a.trait_names()),['f','i'])
374 self.assertEquals(sorted(A.class_trait_names()),['f','i'])
373 self.assertEqual(sorted(a.trait_names()),['f','i'])
374 self.assertEqual(sorted(A.class_trait_names()),['f','i'])
375 375
376 376 def test_trait_metadata(self):
377 377 class A(HasTraits):
378 378 i = Int(config_key='MY_VALUE')
379 379 a = A()
380 self.assertEquals(a.trait_metadata('i','config_key'), 'MY_VALUE')
380 self.assertEqual(a.trait_metadata('i','config_key'), 'MY_VALUE')
381 381
382 382 def test_traits(self):
383 383 class A(HasTraits):
384 384 i = Int
385 385 f = Float
386 386 a = A()
387 self.assertEquals(a.traits(), dict(i=A.i, f=A.f))
388 self.assertEquals(A.class_traits(), dict(i=A.i, f=A.f))
387 self.assertEqual(a.traits(), dict(i=A.i, f=A.f))
388 self.assertEqual(A.class_traits(), dict(i=A.i, f=A.f))
389 389
390 390 def test_traits_metadata(self):
391 391 class A(HasTraits):
392 392 i = Int(config_key='VALUE1', other_thing='VALUE2')
393 393 f = Float(config_key='VALUE3', other_thing='VALUE2')
394 394 j = Int(0)
395 395 a = A()
396 self.assertEquals(a.traits(), dict(i=A.i, f=A.f, j=A.j))
396 self.assertEqual(a.traits(), dict(i=A.i, f=A.f, j=A.j))
397 397 traits = a.traits(config_key='VALUE1', other_thing='VALUE2')
398 self.assertEquals(traits, dict(i=A.i))
398 self.assertEqual(traits, dict(i=A.i))
399 399
400 400 # This passes, but it shouldn't because I am replicating a bug in
401 401 # traits.
402 402 traits = a.traits(config_key=lambda v: True)
403 self.assertEquals(traits, dict(i=A.i, f=A.f, j=A.j))
403 self.assertEqual(traits, dict(i=A.i, f=A.f, j=A.j))
404 404
405 405 def test_init(self):
406 406 class A(HasTraits):
407 407 i = Int()
408 408 x = Float()
409 409 a = A(i=1, x=10.0)
410 self.assertEquals(a.i, 1)
411 self.assertEquals(a.x, 10.0)
410 self.assertEqual(a.i, 1)
411 self.assertEqual(a.x, 10.0)
412 412
413 413 #-----------------------------------------------------------------------------
414 414 # Tests for specific trait types
415 415 #-----------------------------------------------------------------------------
416 416
417 417
418 418 class TestType(TestCase):
419 419
420 420 def test_default(self):
421 421
422 422 class B(object): pass
423 423 class A(HasTraits):
424 424 klass = Type
425 425
426 426 a = A()
427 self.assertEquals(a.klass, None)
427 self.assertEqual(a.klass, None)
428 428
429 429 a.klass = B
430 self.assertEquals(a.klass, B)
430 self.assertEqual(a.klass, B)
431 431 self.assertRaises(TraitError, setattr, a, 'klass', 10)
432 432
433 433 def test_value(self):
434 434
435 435 class B(object): pass
436 436 class C(object): pass
437 437 class A(HasTraits):
438 438 klass = Type(B)
439 439
440 440 a = A()
441 self.assertEquals(a.klass, B)
441 self.assertEqual(a.klass, B)
442 442 self.assertRaises(TraitError, setattr, a, 'klass', C)
443 443 self.assertRaises(TraitError, setattr, a, 'klass', object)
444 444 a.klass = B
445 445
446 446 def test_allow_none(self):
447 447
448 448 class B(object): pass
449 449 class C(B): pass
450 450 class A(HasTraits):
451 451 klass = Type(B, allow_none=False)
452 452
453 453 a = A()
454 self.assertEquals(a.klass, B)
454 self.assertEqual(a.klass, B)
455 455 self.assertRaises(TraitError, setattr, a, 'klass', None)
456 456 a.klass = C
457 self.assertEquals(a.klass, C)
457 self.assertEqual(a.klass, C)
458 458
459 459 def test_validate_klass(self):
460 460
461 461 class A(HasTraits):
462 462 klass = Type('no strings allowed')
463 463
464 464 self.assertRaises(ImportError, A)
465 465
466 466 class A(HasTraits):
467 467 klass = Type('rub.adub.Duck')
468 468
469 469 self.assertRaises(ImportError, A)
470 470
471 471 def test_validate_default(self):
472 472
473 473 class B(object): pass
474 474 class A(HasTraits):
475 475 klass = Type('bad default', B)
476 476
477 477 self.assertRaises(ImportError, A)
478 478
479 479 class C(HasTraits):
480 480 klass = Type(None, B, allow_none=False)
481 481
482 482 self.assertRaises(TraitError, C)
483 483
484 484 def test_str_klass(self):
485 485
486 486 class A(HasTraits):
487 487 klass = Type('IPython.utils.ipstruct.Struct')
488 488
489 489 from IPython.utils.ipstruct import Struct
490 490 a = A()
491 491 a.klass = Struct
492 self.assertEquals(a.klass, Struct)
492 self.assertEqual(a.klass, Struct)
493 493
494 494 self.assertRaises(TraitError, setattr, a, 'klass', 10)
495 495
496 496 class TestInstance(TestCase):
497 497
498 498 def test_basic(self):
499 499 class Foo(object): pass
500 500 class Bar(Foo): pass
501 501 class Bah(object): pass
502 502
503 503 class A(HasTraits):
504 504 inst = Instance(Foo)
505 505
506 506 a = A()
507 507 self.assert_(a.inst is None)
508 508 a.inst = Foo()
509 509 self.assert_(isinstance(a.inst, Foo))
510 510 a.inst = Bar()
511 511 self.assert_(isinstance(a.inst, Foo))
512 512 self.assertRaises(TraitError, setattr, a, 'inst', Foo)
513 513 self.assertRaises(TraitError, setattr, a, 'inst', Bar)
514 514 self.assertRaises(TraitError, setattr, a, 'inst', Bah())
515 515
516 516 def test_unique_default_value(self):
517 517 class Foo(object): pass
518 518 class A(HasTraits):
519 519 inst = Instance(Foo,(),{})
520 520
521 521 a = A()
522 522 b = A()
523 523 self.assert_(a.inst is not b.inst)
524 524
525 525 def test_args_kw(self):
526 526 class Foo(object):
527 527 def __init__(self, c): self.c = c
528 528 class Bar(object): pass
529 529 class Bah(object):
530 530 def __init__(self, c, d):
531 531 self.c = c; self.d = d
532 532
533 533 class A(HasTraits):
534 534 inst = Instance(Foo, (10,))
535 535 a = A()
536 self.assertEquals(a.inst.c, 10)
536 self.assertEqual(a.inst.c, 10)
537 537
538 538 class B(HasTraits):
539 539 inst = Instance(Bah, args=(10,), kw=dict(d=20))
540 540 b = B()
541 self.assertEquals(b.inst.c, 10)
542 self.assertEquals(b.inst.d, 20)
541 self.assertEqual(b.inst.c, 10)
542 self.assertEqual(b.inst.d, 20)
543 543
544 544 class C(HasTraits):
545 545 inst = Instance(Foo)
546 546 c = C()
547 547 self.assert_(c.inst is None)
548 548
549 549 def test_bad_default(self):
550 550 class Foo(object): pass
551 551
552 552 class A(HasTraits):
553 553 inst = Instance(Foo, allow_none=False)
554 554
555 555 self.assertRaises(TraitError, A)
556 556
557 557 def test_instance(self):
558 558 class Foo(object): pass
559 559
560 560 def inner():
561 561 class A(HasTraits):
562 562 inst = Instance(Foo())
563 563
564 564 self.assertRaises(TraitError, inner)
565 565
566 566
567 567 class TestThis(TestCase):
568 568
569 569 def test_this_class(self):
570 570 class Foo(HasTraits):
571 571 this = This
572 572
573 573 f = Foo()
574 self.assertEquals(f.this, None)
574 self.assertEqual(f.this, None)
575 575 g = Foo()
576 576 f.this = g
577 self.assertEquals(f.this, g)
577 self.assertEqual(f.this, g)
578 578 self.assertRaises(TraitError, setattr, f, 'this', 10)
579 579
580 580 def test_this_inst(self):
581 581 class Foo(HasTraits):
582 582 this = This()
583 583
584 584 f = Foo()
585 585 f.this = Foo()
586 586 self.assert_(isinstance(f.this, Foo))
587 587
588 588 def test_subclass(self):
589 589 class Foo(HasTraits):
590 590 t = This()
591 591 class Bar(Foo):
592 592 pass
593 593 f = Foo()
594 594 b = Bar()
595 595 f.t = b
596 596 b.t = f
597 self.assertEquals(f.t, b)
598 self.assertEquals(b.t, f)
597 self.assertEqual(f.t, b)
598 self.assertEqual(b.t, f)
599 599
600 600 def test_subclass_override(self):
601 601 class Foo(HasTraits):
602 602 t = This()
603 603 class Bar(Foo):
604 604 t = This()
605 605 f = Foo()
606 606 b = Bar()
607 607 f.t = b
608 self.assertEquals(f.t, b)
608 self.assertEqual(f.t, b)
609 609 self.assertRaises(TraitError, setattr, b, 't', f)
610 610
611 611 class TraitTestBase(TestCase):
612 612 """A best testing class for basic trait types."""
613 613
614 614 def assign(self, value):
615 615 self.obj.value = value
616 616
617 617 def coerce(self, value):
618 618 return value
619 619
620 620 def test_good_values(self):
621 621 if hasattr(self, '_good_values'):
622 622 for value in self._good_values:
623 623 self.assign(value)
624 self.assertEquals(self.obj.value, self.coerce(value))
624 self.assertEqual(self.obj.value, self.coerce(value))
625 625
626 626 def test_bad_values(self):
627 627 if hasattr(self, '_bad_values'):
628 628 for value in self._bad_values:
629 629 try:
630 630 self.assertRaises(TraitError, self.assign, value)
631 631 except AssertionError:
632 632 assert False, value
633 633
634 634 def test_default_value(self):
635 635 if hasattr(self, '_default_value'):
636 self.assertEquals(self._default_value, self.obj.value)
636 self.assertEqual(self._default_value, self.obj.value)
637 637
638 638 def tearDown(self):
639 639 # restore default value after tests, if set
640 640 if hasattr(self, '_default_value'):
641 641 self.obj.value = self._default_value
642 642
643 643
644 644 class AnyTrait(HasTraits):
645 645
646 646 value = Any
647 647
648 648 class AnyTraitTest(TraitTestBase):
649 649
650 650 obj = AnyTrait()
651 651
652 652 _default_value = None
653 653 _good_values = [10.0, 'ten', u'ten', [10], {'ten': 10},(10,), None, 1j]
654 654 _bad_values = []
655 655
656 656
657 657 class IntTrait(HasTraits):
658 658
659 659 value = Int(99)
660 660
661 661 class TestInt(TraitTestBase):
662 662
663 663 obj = IntTrait()
664 664 _default_value = 99
665 665 _good_values = [10, -10]
666 666 _bad_values = ['ten', u'ten', [10], {'ten': 10},(10,), None, 1j,
667 667 10.1, -10.1, '10L', '-10L', '10.1', '-10.1', u'10L',
668 668 u'-10L', u'10.1', u'-10.1', '10', '-10', u'10', u'-10']
669 669 if not py3compat.PY3:
670 670 _bad_values.extend([10L, -10L, 10*sys.maxint, -10*sys.maxint])
671 671
672 672
673 673 class LongTrait(HasTraits):
674 674
675 675 value = Long(99L)
676 676
677 677 class TestLong(TraitTestBase):
678 678
679 679 obj = LongTrait()
680 680
681 681 _default_value = 99L
682 682 _good_values = [10, -10, 10L, -10L]
683 683 _bad_values = ['ten', u'ten', [10], [10l], {'ten': 10},(10,),(10L,),
684 684 None, 1j, 10.1, -10.1, '10', '-10', '10L', '-10L', '10.1',
685 685 '-10.1', u'10', u'-10', u'10L', u'-10L', u'10.1',
686 686 u'-10.1']
687 687 if not py3compat.PY3:
688 688 # maxint undefined on py3, because int == long
689 689 _good_values.extend([10*sys.maxint, -10*sys.maxint])
690 690
691 691 @skipif(py3compat.PY3, "not relevant on py3")
692 692 def test_cast_small(self):
693 693 """Long casts ints to long"""
694 694 self.obj.value = 10
695 self.assertEquals(type(self.obj.value), long)
695 self.assertEqual(type(self.obj.value), long)
696 696
697 697
698 698 class IntegerTrait(HasTraits):
699 699 value = Integer(1)
700 700
701 701 class TestInteger(TestLong):
702 702 obj = IntegerTrait()
703 703 _default_value = 1
704 704
705 705 def coerce(self, n):
706 706 return int(n)
707 707
708 708 @skipif(py3compat.PY3, "not relevant on py3")
709 709 def test_cast_small(self):
710 710 """Integer casts small longs to int"""
711 711 if py3compat.PY3:
712 712 raise SkipTest("not relevant on py3")
713 713
714 714 self.obj.value = 100L
715 self.assertEquals(type(self.obj.value), int)
715 self.assertEqual(type(self.obj.value), int)
716 716
717 717
718 718 class FloatTrait(HasTraits):
719 719
720 720 value = Float(99.0)
721 721
722 722 class TestFloat(TraitTestBase):
723 723
724 724 obj = FloatTrait()
725 725
726 726 _default_value = 99.0
727 727 _good_values = [10, -10, 10.1, -10.1]
728 728 _bad_values = ['ten', u'ten', [10], {'ten': 10},(10,), None,
729 729 1j, '10', '-10', '10L', '-10L', '10.1', '-10.1', u'10',
730 730 u'-10', u'10L', u'-10L', u'10.1', u'-10.1']
731 731 if not py3compat.PY3:
732 732 _bad_values.extend([10L, -10L])
733 733
734 734
735 735 class ComplexTrait(HasTraits):
736 736
737 737 value = Complex(99.0-99.0j)
738 738
739 739 class TestComplex(TraitTestBase):
740 740
741 741 obj = ComplexTrait()
742 742
743 743 _default_value = 99.0-99.0j
744 744 _good_values = [10, -10, 10.1, -10.1, 10j, 10+10j, 10-10j,
745 745 10.1j, 10.1+10.1j, 10.1-10.1j]
746 746 _bad_values = [u'10L', u'-10L', 'ten', [10], {'ten': 10},(10,), None]
747 747 if not py3compat.PY3:
748 748 _bad_values.extend([10L, -10L])
749 749
750 750
751 751 class BytesTrait(HasTraits):
752 752
753 753 value = Bytes(b'string')
754 754
755 755 class TestBytes(TraitTestBase):
756 756
757 757 obj = BytesTrait()
758 758
759 759 _default_value = b'string'
760 760 _good_values = [b'10', b'-10', b'10L',
761 761 b'-10L', b'10.1', b'-10.1', b'string']
762 762 _bad_values = [10, -10, 10L, -10L, 10.1, -10.1, 1j, [10],
763 763 ['ten'],{'ten': 10},(10,), None, u'string']
764 764
765 765
766 766 class UnicodeTrait(HasTraits):
767 767
768 768 value = Unicode(u'unicode')
769 769
770 770 class TestUnicode(TraitTestBase):
771 771
772 772 obj = UnicodeTrait()
773 773
774 774 _default_value = u'unicode'
775 775 _good_values = ['10', '-10', '10L', '-10L', '10.1',
776 776 '-10.1', '', u'', 'string', u'string', u"€"]
777 777 _bad_values = [10, -10, 10L, -10L, 10.1, -10.1, 1j,
778 778 [10], ['ten'], [u'ten'], {'ten': 10},(10,), None]
779 779
780 780
781 781 class ObjectNameTrait(HasTraits):
782 782 value = ObjectName("abc")
783 783
784 784 class TestObjectName(TraitTestBase):
785 785 obj = ObjectNameTrait()
786 786
787 787 _default_value = "abc"
788 788 _good_values = ["a", "gh", "g9", "g_", "_G", u"a345_"]
789 789 _bad_values = [1, "", u"€", "9g", "!", "#abc", "aj@", "a.b", "a()", "a[0]",
790 790 object(), object]
791 791 if sys.version_info[0] < 3:
792 792 _bad_values.append(u"þ")
793 793 else:
794 794 _good_values.append(u"þ") # þ=1 is valid in Python 3 (PEP 3131).
795 795
796 796
797 797 class DottedObjectNameTrait(HasTraits):
798 798 value = DottedObjectName("a.b")
799 799
800 800 class TestDottedObjectName(TraitTestBase):
801 801 obj = DottedObjectNameTrait()
802 802
803 803 _default_value = "a.b"
804 804 _good_values = ["A", "y.t", "y765.__repr__", "os.path.join", u"os.path.join"]
805 805 _bad_values = [1, u"abc.€", "_.@", ".", ".abc", "abc.", ".abc."]
806 806 if sys.version_info[0] < 3:
807 807 _bad_values.append(u"t.þ")
808 808 else:
809 809 _good_values.append(u"t.þ")
810 810
811 811
812 812 class TCPAddressTrait(HasTraits):
813 813
814 814 value = TCPAddress()
815 815
816 816 class TestTCPAddress(TraitTestBase):
817 817
818 818 obj = TCPAddressTrait()
819 819
820 820 _default_value = ('127.0.0.1',0)
821 821 _good_values = [('localhost',0),('192.168.0.1',1000),('www.google.com',80)]
822 822 _bad_values = [(0,0),('localhost',10.0),('localhost',-1)]
823 823
824 824 class ListTrait(HasTraits):
825 825
826 826 value = List(Int)
827 827
828 828 class TestList(TraitTestBase):
829 829
830 830 obj = ListTrait()
831 831
832 832 _default_value = []
833 833 _good_values = [[], [1], range(10)]
834 834 _bad_values = [10, [1,'a'], 'a', (1,2)]
835 835
836 836 class LenListTrait(HasTraits):
837 837
838 838 value = List(Int, [0], minlen=1, maxlen=2)
839 839
840 840 class TestLenList(TraitTestBase):
841 841
842 842 obj = LenListTrait()
843 843
844 844 _default_value = [0]
845 845 _good_values = [[1], range(2)]
846 846 _bad_values = [10, [1,'a'], 'a', (1,2), [], range(3)]
847 847
848 848 class TupleTrait(HasTraits):
849 849
850 850 value = Tuple(Int)
851 851
852 852 class TestTupleTrait(TraitTestBase):
853 853
854 854 obj = TupleTrait()
855 855
856 856 _default_value = None
857 857 _good_values = [(1,), None,(0,)]
858 858 _bad_values = [10, (1,2), [1],('a'), ()]
859 859
860 860 def test_invalid_args(self):
861 861 self.assertRaises(TypeError, Tuple, 5)
862 862 self.assertRaises(TypeError, Tuple, default_value='hello')
863 863 t = Tuple(Int, CBytes, default_value=(1,5))
864 864
865 865 class LooseTupleTrait(HasTraits):
866 866
867 867 value = Tuple((1,2,3))
868 868
869 869 class TestLooseTupleTrait(TraitTestBase):
870 870
871 871 obj = LooseTupleTrait()
872 872
873 873 _default_value = (1,2,3)
874 874 _good_values = [(1,), None, (0,), tuple(range(5)), tuple('hello'), ('a',5), ()]
875 875 _bad_values = [10, 'hello', [1], []]
876 876
877 877 def test_invalid_args(self):
878 878 self.assertRaises(TypeError, Tuple, 5)
879 879 self.assertRaises(TypeError, Tuple, default_value='hello')
880 880 t = Tuple(Int, CBytes, default_value=(1,5))
881 881
882 882
883 883 class MultiTupleTrait(HasTraits):
884 884
885 885 value = Tuple(Int, Bytes, default_value=[99,b'bottles'])
886 886
887 887 class TestMultiTuple(TraitTestBase):
888 888
889 889 obj = MultiTupleTrait()
890 890
891 891 _default_value = (99,b'bottles')
892 892 _good_values = [(1,b'a'), (2,b'b')]
893 893 _bad_values = ((),10, b'a', (1,b'a',3), (b'a',1), (1, u'a'))
894 894
895 895 class CRegExpTrait(HasTraits):
896 896
897 897 value = CRegExp(r'')
898 898
899 899 class TestCRegExp(TraitTestBase):
900 900
901 901 def coerce(self, value):
902 902 return re.compile(value)
903 903
904 904 obj = CRegExpTrait()
905 905
906 906 _default_value = re.compile(r'')
907 907 _good_values = [r'\d+', re.compile(r'\d+')]
908 908 _bad_values = [r'(', None, ()]
@@ -1,212 +1,212 b''
1 1 """test building messages with streamsession"""
2 2
3 3 #-------------------------------------------------------------------------------
4 4 # Copyright (C) 2011 The IPython Development Team
5 5 #
6 6 # Distributed under the terms of the BSD License. The full license is in
7 7 # the file COPYING, distributed as part of this software.
8 8 #-------------------------------------------------------------------------------
9 9
10 10 #-------------------------------------------------------------------------------
11 11 # Imports
12 12 #-------------------------------------------------------------------------------
13 13
14 14 import os
15 15 import uuid
16 16 import zmq
17 17
18 18 from zmq.tests import BaseZMQTestCase
19 19 from zmq.eventloop.zmqstream import ZMQStream
20 20
21 21 from IPython.zmq import session as ss
22 22
23 23 class SessionTestCase(BaseZMQTestCase):
24 24
25 25 def setUp(self):
26 26 BaseZMQTestCase.setUp(self)
27 27 self.session = ss.Session()
28 28
29 29
30 30 class MockSocket(zmq.Socket):
31 31
32 32 def __init__(self, *args, **kwargs):
33 33 super(MockSocket,self).__init__(*args,**kwargs)
34 34 self.data = []
35 35
36 36 def send_multipart(self, msgparts, *args, **kwargs):
37 37 self.data.extend(msgparts)
38 38
39 39 def send(self, part, *args, **kwargs):
40 40 self.data.append(part)
41 41
42 42 def recv_multipart(self, *args, **kwargs):
43 43 return self.data
44 44
45 45 class TestSession(SessionTestCase):
46 46
47 47 def test_msg(self):
48 48 """message format"""
49 49 msg = self.session.msg('execute')
50 50 thekeys = set('header parent_header content msg_type msg_id'.split())
51 51 s = set(msg.keys())
52 self.assertEquals(s, thekeys)
52 self.assertEqual(s, thekeys)
53 53 self.assertTrue(isinstance(msg['content'],dict))
54 54 self.assertTrue(isinstance(msg['header'],dict))
55 55 self.assertTrue(isinstance(msg['parent_header'],dict))
56 56 self.assertTrue(isinstance(msg['msg_id'],str))
57 57 self.assertTrue(isinstance(msg['msg_type'],str))
58 self.assertEquals(msg['header']['msg_type'], 'execute')
59 self.assertEquals(msg['msg_type'], 'execute')
58 self.assertEqual(msg['header']['msg_type'], 'execute')
59 self.assertEqual(msg['msg_type'], 'execute')
60 60
61 61 def test_serialize(self):
62 62 msg = self.session.msg('execute', content=dict(a=10, b=1.1))
63 63 msg_list = self.session.serialize(msg, ident=b'foo')
64 64 ident, msg_list = self.session.feed_identities(msg_list)
65 65 new_msg = self.session.unserialize(msg_list)
66 self.assertEquals(ident[0], b'foo')
67 self.assertEquals(new_msg['msg_id'],msg['msg_id'])
68 self.assertEquals(new_msg['msg_type'],msg['msg_type'])
69 self.assertEquals(new_msg['header'],msg['header'])
70 self.assertEquals(new_msg['content'],msg['content'])
71 self.assertEquals(new_msg['parent_header'],msg['parent_header'])
66 self.assertEqual(ident[0], b'foo')
67 self.assertEqual(new_msg['msg_id'],msg['msg_id'])
68 self.assertEqual(new_msg['msg_type'],msg['msg_type'])
69 self.assertEqual(new_msg['header'],msg['header'])
70 self.assertEqual(new_msg['content'],msg['content'])
71 self.assertEqual(new_msg['parent_header'],msg['parent_header'])
72 72 # ensure floats don't come out as Decimal:
73 self.assertEquals(type(new_msg['content']['b']),type(new_msg['content']['b']))
73 self.assertEqual(type(new_msg['content']['b']),type(new_msg['content']['b']))
74 74
75 75 def test_send(self):
76 76 socket = MockSocket(zmq.Context.instance(),zmq.PAIR)
77 77
78 78 msg = self.session.msg('execute', content=dict(a=10))
79 79 self.session.send(socket, msg, ident=b'foo', buffers=[b'bar'])
80 80 ident, msg_list = self.session.feed_identities(socket.data)
81 81 new_msg = self.session.unserialize(msg_list)
82 self.assertEquals(ident[0], b'foo')
83 self.assertEquals(new_msg['msg_id'],msg['msg_id'])
84 self.assertEquals(new_msg['msg_type'],msg['msg_type'])
85 self.assertEquals(new_msg['header'],msg['header'])
86 self.assertEquals(new_msg['content'],msg['content'])
87 self.assertEquals(new_msg['parent_header'],msg['parent_header'])
88 self.assertEquals(new_msg['buffers'],[b'bar'])
82 self.assertEqual(ident[0], b'foo')
83 self.assertEqual(new_msg['msg_id'],msg['msg_id'])
84 self.assertEqual(new_msg['msg_type'],msg['msg_type'])
85 self.assertEqual(new_msg['header'],msg['header'])
86 self.assertEqual(new_msg['content'],msg['content'])
87 self.assertEqual(new_msg['parent_header'],msg['parent_header'])
88 self.assertEqual(new_msg['buffers'],[b'bar'])
89 89
90 90 socket.data = []
91 91
92 92 content = msg['content']
93 93 header = msg['header']
94 94 parent = msg['parent_header']
95 95 msg_type = header['msg_type']
96 96 self.session.send(socket, None, content=content, parent=parent,
97 97 header=header, ident=b'foo', buffers=[b'bar'])
98 98 ident, msg_list = self.session.feed_identities(socket.data)
99 99 new_msg = self.session.unserialize(msg_list)
100 self.assertEquals(ident[0], b'foo')
101 self.assertEquals(new_msg['msg_id'],msg['msg_id'])
102 self.assertEquals(new_msg['msg_type'],msg['msg_type'])
103 self.assertEquals(new_msg['header'],msg['header'])
104 self.assertEquals(new_msg['content'],msg['content'])
105 self.assertEquals(new_msg['parent_header'],msg['parent_header'])
106 self.assertEquals(new_msg['buffers'],[b'bar'])
100 self.assertEqual(ident[0], b'foo')
101 self.assertEqual(new_msg['msg_id'],msg['msg_id'])
102 self.assertEqual(new_msg['msg_type'],msg['msg_type'])
103 self.assertEqual(new_msg['header'],msg['header'])
104 self.assertEqual(new_msg['content'],msg['content'])
105 self.assertEqual(new_msg['parent_header'],msg['parent_header'])
106 self.assertEqual(new_msg['buffers'],[b'bar'])
107 107
108 108 socket.data = []
109 109
110 110 self.session.send(socket, msg, ident=b'foo', buffers=[b'bar'])
111 111 ident, new_msg = self.session.recv(socket)
112 self.assertEquals(ident[0], b'foo')
113 self.assertEquals(new_msg['msg_id'],msg['msg_id'])
114 self.assertEquals(new_msg['msg_type'],msg['msg_type'])
115 self.assertEquals(new_msg['header'],msg['header'])
116 self.assertEquals(new_msg['content'],msg['content'])
117 self.assertEquals(new_msg['parent_header'],msg['parent_header'])
118 self.assertEquals(new_msg['buffers'],[b'bar'])
112 self.assertEqual(ident[0], b'foo')
113 self.assertEqual(new_msg['msg_id'],msg['msg_id'])
114 self.assertEqual(new_msg['msg_type'],msg['msg_type'])
115 self.assertEqual(new_msg['header'],msg['header'])
116 self.assertEqual(new_msg['content'],msg['content'])
117 self.assertEqual(new_msg['parent_header'],msg['parent_header'])
118 self.assertEqual(new_msg['buffers'],[b'bar'])
119 119
120 120 socket.close()
121 121
122 122 def test_args(self):
123 123 """initialization arguments for Session"""
124 124 s = self.session
125 125 self.assertTrue(s.pack is ss.default_packer)
126 126 self.assertTrue(s.unpack is ss.default_unpacker)
127 self.assertEquals(s.username, os.environ.get('USER', u'username'))
127 self.assertEqual(s.username, os.environ.get('USER', u'username'))
128 128
129 129 s = ss.Session()
130 self.assertEquals(s.username, os.environ.get('USER', u'username'))
130 self.assertEqual(s.username, os.environ.get('USER', u'username'))
131 131
132 132 self.assertRaises(TypeError, ss.Session, pack='hi')
133 133 self.assertRaises(TypeError, ss.Session, unpack='hi')
134 134 u = str(uuid.uuid4())
135 135 s = ss.Session(username=u'carrot', session=u)
136 self.assertEquals(s.session, u)
137 self.assertEquals(s.username, u'carrot')
136 self.assertEqual(s.session, u)
137 self.assertEqual(s.username, u'carrot')
138 138
139 139 def test_tracking(self):
140 140 """test tracking messages"""
141 141 a,b = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
142 142 s = self.session
143 143 stream = ZMQStream(a)
144 144 msg = s.send(a, 'hello', track=False)
145 145 self.assertTrue(msg['tracker'] is None)
146 146 msg = s.send(a, 'hello', track=True)
147 147 self.assertTrue(isinstance(msg['tracker'], zmq.MessageTracker))
148 148 M = zmq.Message(b'hi there', track=True)
149 149 msg = s.send(a, 'hello', buffers=[M], track=True)
150 150 t = msg['tracker']
151 151 self.assertTrue(isinstance(t, zmq.MessageTracker))
152 152 self.assertRaises(zmq.NotDone, t.wait, .1)
153 153 del M
154 154 t.wait(1) # this will raise
155 155
156 156
157 157 # def test_rekey(self):
158 158 # """rekeying dict around json str keys"""
159 159 # d = {'0': uuid.uuid4(), 0:uuid.uuid4()}
160 160 # self.assertRaises(KeyError, ss.rekey, d)
161 161 #
162 162 # d = {'0': uuid.uuid4(), 1:uuid.uuid4(), 'asdf':uuid.uuid4()}
163 163 # d2 = {0:d['0'],1:d[1],'asdf':d['asdf']}
164 164 # rd = ss.rekey(d)
165 # self.assertEquals(d2,rd)
165 # self.assertEqual(d2,rd)
166 166 #
167 167 # d = {'1.5':uuid.uuid4(),'1':uuid.uuid4()}
168 168 # d2 = {1.5:d['1.5'],1:d['1']}
169 169 # rd = ss.rekey(d)
170 # self.assertEquals(d2,rd)
170 # self.assertEqual(d2,rd)
171 171 #
172 172 # d = {'1.0':uuid.uuid4(),'1':uuid.uuid4()}
173 173 # self.assertRaises(KeyError, ss.rekey, d)
174 174 #
175 175 def test_unique_msg_ids(self):
176 176 """test that messages receive unique ids"""
177 177 ids = set()
178 178 for i in range(2**12):
179 179 h = self.session.msg_header('test')
180 180 msg_id = h['msg_id']
181 181 self.assertTrue(msg_id not in ids)
182 182 ids.add(msg_id)
183 183
184 184 def test_feed_identities(self):
185 185 """scrub the front for zmq IDENTITIES"""
186 186 theids = "engine client other".split()
187 187 content = dict(code='whoda',stuff=object())
188 188 themsg = self.session.msg('execute',content=content)
189 189 pmsg = theids
190 190
191 191 def test_session_id(self):
192 192 session = ss.Session()
193 193 # get bs before us
194 194 bs = session.bsession
195 195 us = session.session
196 self.assertEquals(us.encode('ascii'), bs)
196 self.assertEqual(us.encode('ascii'), bs)
197 197 session = ss.Session()
198 198 # get us before bs
199 199 us = session.session
200 200 bs = session.bsession
201 self.assertEquals(us.encode('ascii'), bs)
201 self.assertEqual(us.encode('ascii'), bs)
202 202 # change propagates:
203 203 session.session = 'something else'
204 204 bs = session.bsession
205 205 us = session.session
206 self.assertEquals(us.encode('ascii'), bs)
206 self.assertEqual(us.encode('ascii'), bs)
207 207 session = ss.Session(session='stuff')
208 208 # get us before bs
209 self.assertEquals(session.bsession, session.session.encode('ascii'))
210 self.assertEquals(b'stuff', session.bsession)
209 self.assertEqual(session.bsession, session.session.encode('ascii'))
210 self.assertEqual(b'stuff', session.bsession)
211 211
212 212
General Comments 0
You need to be logged in to leave comments. Login now