##// END OF EJS Templates
Various fixes to tests in IPython.utils.
Thomas Kluyver -
Show More
@@ -1,85 +1,87 b''
1 # encoding: utf-8
1 # encoding: utf-8
2 """
2 """
3 Utilities for working with stack frames.
3 Utilities for working with stack frames.
4 """
4 """
5
5
6 #-----------------------------------------------------------------------------
6 #-----------------------------------------------------------------------------
7 # Copyright (C) 2008-2009 The IPython Development Team
7 # Copyright (C) 2008-2009 The IPython Development Team
8 #
8 #
9 # Distributed under the terms of the BSD License. The full license is in
9 # Distributed under the terms of the BSD License. The full license is in
10 # the file COPYING, distributed as part of this software.
10 # the file COPYING, distributed as part of this software.
11 #-----------------------------------------------------------------------------
11 #-----------------------------------------------------------------------------
12
12
13 #-----------------------------------------------------------------------------
13 #-----------------------------------------------------------------------------
14 # Imports
14 # Imports
15 #-----------------------------------------------------------------------------
15 #-----------------------------------------------------------------------------
16
16
17 import sys
17 import sys
18 from IPython.utils import py3compat
18
19
19 #-----------------------------------------------------------------------------
20 #-----------------------------------------------------------------------------
20 # Code
21 # Code
21 #-----------------------------------------------------------------------------
22 #-----------------------------------------------------------------------------
22
23
24 @py3compat.doctest_refactor_print
23 def extract_vars(*names,**kw):
25 def extract_vars(*names,**kw):
24 """Extract a set of variables by name from another frame.
26 """Extract a set of variables by name from another frame.
25
27
26 :Parameters:
28 :Parameters:
27 - `*names`: strings
29 - `*names`: strings
28 One or more variable names which will be extracted from the caller's
30 One or more variable names which will be extracted from the caller's
29 frame.
31 frame.
30
32
31 :Keywords:
33 :Keywords:
32 - `depth`: integer (0)
34 - `depth`: integer (0)
33 How many frames in the stack to walk when looking for your variables.
35 How many frames in the stack to walk when looking for your variables.
34
36
35
37
36 Examples:
38 Examples:
37
39
38 In [2]: def func(x):
40 In [2]: def func(x):
39 ...: y = 1
41 ...: y = 1
40 ...: print extract_vars('x','y')
42 ...: print extract_vars('x','y')
41 ...:
43 ...:
42
44
43 In [3]: func('hello')
45 In [3]: func('hello')
44 {'y': 1, 'x': 'hello'}
46 {'y': 1, 'x': 'hello'}
45 """
47 """
46
48
47 depth = kw.get('depth',0)
49 depth = kw.get('depth',0)
48
50
49 callerNS = sys._getframe(depth+1).f_locals
51 callerNS = sys._getframe(depth+1).f_locals
50 return dict((k,callerNS[k]) for k in names)
52 return dict((k,callerNS[k]) for k in names)
51
53
52
54
53 def extract_vars_above(*names):
55 def extract_vars_above(*names):
54 """Extract a set of variables by name from another frame.
56 """Extract a set of variables by name from another frame.
55
57
56 Similar to extractVars(), but with a specified depth of 1, so that names
58 Similar to extractVars(), but with a specified depth of 1, so that names
57 are exctracted exactly from above the caller.
59 are exctracted exactly from above the caller.
58
60
59 This is simply a convenience function so that the very common case (for us)
61 This is simply a convenience function so that the very common case (for us)
60 of skipping exactly 1 frame doesn't have to construct a special dict for
62 of skipping exactly 1 frame doesn't have to construct a special dict for
61 keyword passing."""
63 keyword passing."""
62
64
63 callerNS = sys._getframe(2).f_locals
65 callerNS = sys._getframe(2).f_locals
64 return dict((k,callerNS[k]) for k in names)
66 return dict((k,callerNS[k]) for k in names)
65
67
66
68
67 def debugx(expr,pre_msg=''):
69 def debugx(expr,pre_msg=''):
68 """Print the value of an expression from the caller's frame.
70 """Print the value of an expression from the caller's frame.
69
71
70 Takes an expression, evaluates it in the caller's frame and prints both
72 Takes an expression, evaluates it in the caller's frame and prints both
71 the given expression and the resulting value (as well as a debug mark
73 the given expression and the resulting value (as well as a debug mark
72 indicating the name of the calling function. The input must be of a form
74 indicating the name of the calling function. The input must be of a form
73 suitable for eval().
75 suitable for eval().
74
76
75 An optional message can be passed, which will be prepended to the printed
77 An optional message can be passed, which will be prepended to the printed
76 expr->value pair."""
78 expr->value pair."""
77
79
78 cf = sys._getframe(1)
80 cf = sys._getframe(1)
79 print '[DBG:%s] %s%s -> %r' % (cf.f_code.co_name,pre_msg,expr,
81 print '[DBG:%s] %s%s -> %r' % (cf.f_code.co_name,pre_msg,expr,
80 eval(expr,cf.f_globals,cf.f_locals))
82 eval(expr,cf.f_globals,cf.f_locals))
81
83
82
84
83 # deactivate it by uncommenting the following line, which makes it a no-op
85 # deactivate it by uncommenting the following line, which makes it a no-op
84 #def debugx(expr,pre_msg=''): pass
86 #def debugx(expr,pre_msg=''): pass
85
87
@@ -1,395 +1,395 b''
1 # encoding: utf-8
1 # encoding: utf-8
2 """A dict subclass that supports attribute style access.
2 """A dict subclass that supports attribute style access.
3
3
4 Authors:
4 Authors:
5
5
6 * Fernando Perez (original)
6 * Fernando Perez (original)
7 * Brian Granger (refactoring to a dict subclass)
7 * Brian Granger (refactoring to a dict subclass)
8 """
8 """
9
9
10 #-----------------------------------------------------------------------------
10 #-----------------------------------------------------------------------------
11 # Copyright (C) 2008-2009 The IPython Development Team
11 # Copyright (C) 2008-2009 The IPython Development Team
12 #
12 #
13 # Distributed under the terms of the BSD License. The full license is in
13 # Distributed under the terms of the BSD License. The full license is in
14 # the file COPYING, distributed as part of this software.
14 # the file COPYING, distributed as part of this software.
15 #-----------------------------------------------------------------------------
15 #-----------------------------------------------------------------------------
16
16
17 #-----------------------------------------------------------------------------
17 #-----------------------------------------------------------------------------
18 # Imports
18 # Imports
19 #-----------------------------------------------------------------------------
19 #-----------------------------------------------------------------------------
20
20
21 from IPython.utils.data import list2dict2
21 from IPython.utils.data import list2dict2
22
22
23 __all__ = ['Struct']
23 __all__ = ['Struct']
24
24
25 #-----------------------------------------------------------------------------
25 #-----------------------------------------------------------------------------
26 # Code
26 # Code
27 #-----------------------------------------------------------------------------
27 #-----------------------------------------------------------------------------
28
28
29
29
30 class Struct(dict):
30 class Struct(dict):
31 """A dict subclass with attribute style access.
31 """A dict subclass with attribute style access.
32
32
33 This dict subclass has a a few extra features:
33 This dict subclass has a a few extra features:
34
34
35 * Attribute style access.
35 * Attribute style access.
36 * Protection of class members (like keys, items) when using attribute
36 * Protection of class members (like keys, items) when using attribute
37 style access.
37 style access.
38 * The ability to restrict assignment to only existing keys.
38 * The ability to restrict assignment to only existing keys.
39 * Intelligent merging.
39 * Intelligent merging.
40 * Overloaded operators.
40 * Overloaded operators.
41 """
41 """
42 _allownew = True
42 _allownew = True
43 def __init__(self, *args, **kw):
43 def __init__(self, *args, **kw):
44 """Initialize with a dictionary, another Struct, or data.
44 """Initialize with a dictionary, another Struct, or data.
45
45
46 Parameters
46 Parameters
47 ----------
47 ----------
48 args : dict, Struct
48 args : dict, Struct
49 Initialize with one dict or Struct
49 Initialize with one dict or Struct
50 kw : dict
50 kw : dict
51 Initialize with key, value pairs.
51 Initialize with key, value pairs.
52
52
53 Examples
53 Examples
54 --------
54 --------
55
55
56 >>> s = Struct(a=10,b=30)
56 >>> s = Struct(a=10,b=30)
57 >>> s.a
57 >>> s.a
58 10
58 10
59 >>> s.b
59 >>> s.b
60 30
60 30
61 >>> s2 = Struct(s,c=30)
61 >>> s2 = Struct(s,c=30)
62 >>> s2.keys()
62 >>> s2.keys()
63 ['a', 'c', 'b']
63 ['a', 'c', 'b']
64 """
64 """
65 object.__setattr__(self, '_allownew', True)
65 object.__setattr__(self, '_allownew', True)
66 dict.__init__(self, *args, **kw)
66 dict.__init__(self, *args, **kw)
67
67
68 def __setitem__(self, key, value):
68 def __setitem__(self, key, value):
69 """Set an item with check for allownew.
69 """Set an item with check for allownew.
70
70
71 Examples
71 Examples
72 --------
72 --------
73
73
74 >>> s = Struct()
74 >>> s = Struct()
75 >>> s['a'] = 10
75 >>> s['a'] = 10
76 >>> s.allow_new_attr(False)
76 >>> s.allow_new_attr(False)
77 >>> s['a'] = 10
77 >>> s['a'] = 10
78 >>> s['a']
78 >>> s['a']
79 10
79 10
80 >>> try:
80 >>> try:
81 ... s['b'] = 20
81 ... s['b'] = 20
82 ... except KeyError:
82 ... except KeyError:
83 ... print 'this is not allowed'
83 ... print 'this is not allowed'
84 ...
84 ...
85 this is not allowed
85 this is not allowed
86 """
86 """
87 if not self._allownew and not self.has_key(key):
87 if not self._allownew and not self.has_key(key):
88 raise KeyError(
88 raise KeyError(
89 "can't create new attribute %s when allow_new_attr(False)" % key)
89 "can't create new attribute %s when allow_new_attr(False)" % key)
90 dict.__setitem__(self, key, value)
90 dict.__setitem__(self, key, value)
91
91
92 def __setattr__(self, key, value):
92 def __setattr__(self, key, value):
93 """Set an attr with protection of class members.
93 """Set an attr with protection of class members.
94
94
95 This calls :meth:`self.__setitem__` but convert :exc:`KeyError` to
95 This calls :meth:`self.__setitem__` but convert :exc:`KeyError` to
96 :exc:`AttributeError`.
96 :exc:`AttributeError`.
97
97
98 Examples
98 Examples
99 --------
99 --------
100
100
101 >>> s = Struct()
101 >>> s = Struct()
102 >>> s.a = 10
102 >>> s.a = 10
103 >>> s.a
103 >>> s.a
104 10
104 10
105 >>> try:
105 >>> try:
106 ... s.get = 10
106 ... s.get = 10
107 ... except AttributeError:
107 ... except AttributeError:
108 ... print "you can't set a class member"
108 ... print "you can't set a class member"
109 ...
109 ...
110 you can't set a class member
110 you can't set a class member
111 """
111 """
112 # If key is an str it might be a class member or instance var
112 # If key is an str it might be a class member or instance var
113 if isinstance(key, str):
113 if isinstance(key, str):
114 # I can't simply call hasattr here because it calls getattr, which
114 # I can't simply call hasattr here because it calls getattr, which
115 # calls self.__getattr__, which returns True for keys in
115 # calls self.__getattr__, which returns True for keys in
116 # self._data. But I only want keys in the class and in
116 # self._data. But I only want keys in the class and in
117 # self.__dict__
117 # self.__dict__
118 if key in self.__dict__ or hasattr(Struct, key):
118 if key in self.__dict__ or hasattr(Struct, key):
119 raise AttributeError(
119 raise AttributeError(
120 'attr %s is a protected member of class Struct.' % key
120 'attr %s is a protected member of class Struct.' % key
121 )
121 )
122 try:
122 try:
123 self.__setitem__(key, value)
123 self.__setitem__(key, value)
124 except KeyError, e:
124 except KeyError, e:
125 raise AttributeError(e)
125 raise AttributeError(e)
126
126
127 def __getattr__(self, key):
127 def __getattr__(self, key):
128 """Get an attr by calling :meth:`dict.__getitem__`.
128 """Get an attr by calling :meth:`dict.__getitem__`.
129
129
130 Like :meth:`__setattr__`, this method converts :exc:`KeyError` to
130 Like :meth:`__setattr__`, this method converts :exc:`KeyError` to
131 :exc:`AttributeError`.
131 :exc:`AttributeError`.
132
132
133 Examples
133 Examples
134 --------
134 --------
135
135
136 >>> s = Struct(a=10)
136 >>> s = Struct(a=10)
137 >>> s.a
137 >>> s.a
138 10
138 10
139 >>> type(s.get)
139 >>> type(s.get)
140 <type 'builtin_function_or_method'>
140 <... 'builtin_function_or_method'>
141 >>> try:
141 >>> try:
142 ... s.b
142 ... s.b
143 ... except AttributeError:
143 ... except AttributeError:
144 ... print "I don't have that key"
144 ... print "I don't have that key"
145 ...
145 ...
146 I don't have that key
146 I don't have that key
147 """
147 """
148 try:
148 try:
149 result = self[key]
149 result = self[key]
150 except KeyError:
150 except KeyError:
151 raise AttributeError(key)
151 raise AttributeError(key)
152 else:
152 else:
153 return result
153 return result
154
154
155 def __iadd__(self, other):
155 def __iadd__(self, other):
156 """s += s2 is a shorthand for s.merge(s2).
156 """s += s2 is a shorthand for s.merge(s2).
157
157
158 Examples
158 Examples
159 --------
159 --------
160
160
161 >>> s = Struct(a=10,b=30)
161 >>> s = Struct(a=10,b=30)
162 >>> s2 = Struct(a=20,c=40)
162 >>> s2 = Struct(a=20,c=40)
163 >>> s += s2
163 >>> s += s2
164 >>> s
164 >>> s
165 {'a': 10, 'c': 40, 'b': 30}
165 {'a': 10, 'c': 40, 'b': 30}
166 """
166 """
167 self.merge(other)
167 self.merge(other)
168 return self
168 return self
169
169
170 def __add__(self,other):
170 def __add__(self,other):
171 """s + s2 -> New Struct made from s.merge(s2).
171 """s + s2 -> New Struct made from s.merge(s2).
172
172
173 Examples
173 Examples
174 --------
174 --------
175
175
176 >>> s1 = Struct(a=10,b=30)
176 >>> s1 = Struct(a=10,b=30)
177 >>> s2 = Struct(a=20,c=40)
177 >>> s2 = Struct(a=20,c=40)
178 >>> s = s1 + s2
178 >>> s = s1 + s2
179 >>> s
179 >>> s
180 {'a': 10, 'c': 40, 'b': 30}
180 {'a': 10, 'c': 40, 'b': 30}
181 """
181 """
182 sout = self.copy()
182 sout = self.copy()
183 sout.merge(other)
183 sout.merge(other)
184 return sout
184 return sout
185
185
186 def __sub__(self,other):
186 def __sub__(self,other):
187 """s1 - s2 -> remove keys in s2 from s1.
187 """s1 - s2 -> remove keys in s2 from s1.
188
188
189 Examples
189 Examples
190 --------
190 --------
191
191
192 >>> s1 = Struct(a=10,b=30)
192 >>> s1 = Struct(a=10,b=30)
193 >>> s2 = Struct(a=40)
193 >>> s2 = Struct(a=40)
194 >>> s = s1 - s2
194 >>> s = s1 - s2
195 >>> s
195 >>> s
196 {'b': 30}
196 {'b': 30}
197 """
197 """
198 sout = self.copy()
198 sout = self.copy()
199 sout -= other
199 sout -= other
200 return sout
200 return sout
201
201
202 def __isub__(self,other):
202 def __isub__(self,other):
203 """Inplace remove keys from self that are in other.
203 """Inplace remove keys from self that are in other.
204
204
205 Examples
205 Examples
206 --------
206 --------
207
207
208 >>> s1 = Struct(a=10,b=30)
208 >>> s1 = Struct(a=10,b=30)
209 >>> s2 = Struct(a=40)
209 >>> s2 = Struct(a=40)
210 >>> s1 -= s2
210 >>> s1 -= s2
211 >>> s1
211 >>> s1
212 {'b': 30}
212 {'b': 30}
213 """
213 """
214 for k in other.keys():
214 for k in other.keys():
215 if self.has_key(k):
215 if self.has_key(k):
216 del self[k]
216 del self[k]
217 return self
217 return self
218
218
219 def __dict_invert(self, data):
219 def __dict_invert(self, data):
220 """Helper function for merge.
220 """Helper function for merge.
221
221
222 Takes a dictionary whose values are lists and returns a dict with
222 Takes a dictionary whose values are lists and returns a dict with
223 the elements of each list as keys and the original keys as values.
223 the elements of each list as keys and the original keys as values.
224 """
224 """
225 outdict = {}
225 outdict = {}
226 for k,lst in data.items():
226 for k,lst in data.items():
227 if isinstance(lst, str):
227 if isinstance(lst, str):
228 lst = lst.split()
228 lst = lst.split()
229 for entry in lst:
229 for entry in lst:
230 outdict[entry] = k
230 outdict[entry] = k
231 return outdict
231 return outdict
232
232
233 def dict(self):
233 def dict(self):
234 return self
234 return self
235
235
236 def copy(self):
236 def copy(self):
237 """Return a copy as a Struct.
237 """Return a copy as a Struct.
238
238
239 Examples
239 Examples
240 --------
240 --------
241
241
242 >>> s = Struct(a=10,b=30)
242 >>> s = Struct(a=10,b=30)
243 >>> s2 = s.copy()
243 >>> s2 = s.copy()
244 >>> s2
244 >>> s2
245 {'a': 10, 'b': 30}
245 {'a': 10, 'b': 30}
246 >>> type(s2).__name__
246 >>> type(s2).__name__
247 'Struct'
247 'Struct'
248 """
248 """
249 return Struct(dict.copy(self))
249 return Struct(dict.copy(self))
250
250
251 def hasattr(self, key):
251 def hasattr(self, key):
252 """hasattr function available as a method.
252 """hasattr function available as a method.
253
253
254 Implemented like has_key.
254 Implemented like has_key.
255
255
256 Examples
256 Examples
257 --------
257 --------
258
258
259 >>> s = Struct(a=10)
259 >>> s = Struct(a=10)
260 >>> s.hasattr('a')
260 >>> s.hasattr('a')
261 True
261 True
262 >>> s.hasattr('b')
262 >>> s.hasattr('b')
263 False
263 False
264 >>> s.hasattr('get')
264 >>> s.hasattr('get')
265 False
265 False
266 """
266 """
267 return self.has_key(key)
267 return self.has_key(key)
268
268
269 def allow_new_attr(self, allow = True):
269 def allow_new_attr(self, allow = True):
270 """Set whether new attributes can be created in this Struct.
270 """Set whether new attributes can be created in this Struct.
271
271
272 This can be used to catch typos by verifying that the attribute user
272 This can be used to catch typos by verifying that the attribute user
273 tries to change already exists in this Struct.
273 tries to change already exists in this Struct.
274 """
274 """
275 object.__setattr__(self, '_allownew', allow)
275 object.__setattr__(self, '_allownew', allow)
276
276
277 def merge(self, __loc_data__=None, __conflict_solve=None, **kw):
277 def merge(self, __loc_data__=None, __conflict_solve=None, **kw):
278 """Merge two Structs with customizable conflict resolution.
278 """Merge two Structs with customizable conflict resolution.
279
279
280 This is similar to :meth:`update`, but much more flexible. First, a
280 This is similar to :meth:`update`, but much more flexible. First, a
281 dict is made from data+key=value pairs. When merging this dict with
281 dict is made from data+key=value pairs. When merging this dict with
282 the Struct S, the optional dictionary 'conflict' is used to decide
282 the Struct S, the optional dictionary 'conflict' is used to decide
283 what to do.
283 what to do.
284
284
285 If conflict is not given, the default behavior is to preserve any keys
285 If conflict is not given, the default behavior is to preserve any keys
286 with their current value (the opposite of the :meth:`update` method's
286 with their current value (the opposite of the :meth:`update` method's
287 behavior).
287 behavior).
288
288
289 Parameters
289 Parameters
290 ----------
290 ----------
291 __loc_data : dict, Struct
291 __loc_data : dict, Struct
292 The data to merge into self
292 The data to merge into self
293 __conflict_solve : dict
293 __conflict_solve : dict
294 The conflict policy dict. The keys are binary functions used to
294 The conflict policy dict. The keys are binary functions used to
295 resolve the conflict and the values are lists of strings naming
295 resolve the conflict and the values are lists of strings naming
296 the keys the conflict resolution function applies to. Instead of
296 the keys the conflict resolution function applies to. Instead of
297 a list of strings a space separated string can be used, like
297 a list of strings a space separated string can be used, like
298 'a b c'.
298 'a b c'.
299 kw : dict
299 kw : dict
300 Additional key, value pairs to merge in
300 Additional key, value pairs to merge in
301
301
302 Notes
302 Notes
303 -----
303 -----
304
304
305 The `__conflict_solve` dict is a dictionary of binary functions which will be used to
305 The `__conflict_solve` dict is a dictionary of binary functions which will be used to
306 solve key conflicts. Here is an example::
306 solve key conflicts. Here is an example::
307
307
308 __conflict_solve = dict(
308 __conflict_solve = dict(
309 func1=['a','b','c'],
309 func1=['a','b','c'],
310 func2=['d','e']
310 func2=['d','e']
311 )
311 )
312
312
313 In this case, the function :func:`func1` will be used to resolve
313 In this case, the function :func:`func1` will be used to resolve
314 keys 'a', 'b' and 'c' and the function :func:`func2` will be used for
314 keys 'a', 'b' and 'c' and the function :func:`func2` will be used for
315 keys 'd' and 'e'. This could also be written as::
315 keys 'd' and 'e'. This could also be written as::
316
316
317 __conflict_solve = dict(func1='a b c',func2='d e')
317 __conflict_solve = dict(func1='a b c',func2='d e')
318
318
319 These functions will be called for each key they apply to with the
319 These functions will be called for each key they apply to with the
320 form::
320 form::
321
321
322 func1(self['a'], other['a'])
322 func1(self['a'], other['a'])
323
323
324 The return value is used as the final merged value.
324 The return value is used as the final merged value.
325
325
326 As a convenience, merge() provides five (the most commonly needed)
326 As a convenience, merge() provides five (the most commonly needed)
327 pre-defined policies: preserve, update, add, add_flip and add_s. The
327 pre-defined policies: preserve, update, add, add_flip and add_s. The
328 easiest explanation is their implementation::
328 easiest explanation is their implementation::
329
329
330 preserve = lambda old,new: old
330 preserve = lambda old,new: old
331 update = lambda old,new: new
331 update = lambda old,new: new
332 add = lambda old,new: old + new
332 add = lambda old,new: old + new
333 add_flip = lambda old,new: new + old # note change of order!
333 add_flip = lambda old,new: new + old # note change of order!
334 add_s = lambda old,new: old + ' ' + new # only for str!
334 add_s = lambda old,new: old + ' ' + new # only for str!
335
335
336 You can use those four words (as strings) as keys instead
336 You can use those four words (as strings) as keys instead
337 of defining them as functions, and the merge method will substitute
337 of defining them as functions, and the merge method will substitute
338 the appropriate functions for you.
338 the appropriate functions for you.
339
339
340 For more complicated conflict resolution policies, you still need to
340 For more complicated conflict resolution policies, you still need to
341 construct your own functions.
341 construct your own functions.
342
342
343 Examples
343 Examples
344 --------
344 --------
345
345
346 This show the default policy:
346 This show the default policy:
347
347
348 >>> s = Struct(a=10,b=30)
348 >>> s = Struct(a=10,b=30)
349 >>> s2 = Struct(a=20,c=40)
349 >>> s2 = Struct(a=20,c=40)
350 >>> s.merge(s2)
350 >>> s.merge(s2)
351 >>> s
351 >>> s
352 {'a': 10, 'c': 40, 'b': 30}
352 {'a': 10, 'c': 40, 'b': 30}
353
353
354 Now, show how to specify a conflict dict:
354 Now, show how to specify a conflict dict:
355
355
356 >>> s = Struct(a=10,b=30)
356 >>> s = Struct(a=10,b=30)
357 >>> s2 = Struct(a=20,b=40)
357 >>> s2 = Struct(a=20,b=40)
358 >>> conflict = {'update':'a','add':'b'}
358 >>> conflict = {'update':'a','add':'b'}
359 >>> s.merge(s2,conflict)
359 >>> s.merge(s2,conflict)
360 >>> s
360 >>> s
361 {'a': 20, 'b': 70}
361 {'a': 20, 'b': 70}
362 """
362 """
363
363
364 data_dict = dict(__loc_data__,**kw)
364 data_dict = dict(__loc_data__,**kw)
365
365
366 # policies for conflict resolution: two argument functions which return
366 # policies for conflict resolution: two argument functions which return
367 # the value that will go in the new struct
367 # the value that will go in the new struct
368 preserve = lambda old,new: old
368 preserve = lambda old,new: old
369 update = lambda old,new: new
369 update = lambda old,new: new
370 add = lambda old,new: old + new
370 add = lambda old,new: old + new
371 add_flip = lambda old,new: new + old # note change of order!
371 add_flip = lambda old,new: new + old # note change of order!
372 add_s = lambda old,new: old + ' ' + new
372 add_s = lambda old,new: old + ' ' + new
373
373
374 # default policy is to keep current keys when there's a conflict
374 # default policy is to keep current keys when there's a conflict
375 conflict_solve = list2dict2(self.keys(), default = preserve)
375 conflict_solve = list2dict2(self.keys(), default = preserve)
376
376
377 # the conflict_solve dictionary is given by the user 'inverted': we
377 # the conflict_solve dictionary is given by the user 'inverted': we
378 # need a name-function mapping, it comes as a function -> names
378 # need a name-function mapping, it comes as a function -> names
379 # dict. Make a local copy (b/c we'll make changes), replace user
379 # dict. Make a local copy (b/c we'll make changes), replace user
380 # strings for the three builtin policies and invert it.
380 # strings for the three builtin policies and invert it.
381 if __conflict_solve:
381 if __conflict_solve:
382 inv_conflict_solve_user = __conflict_solve.copy()
382 inv_conflict_solve_user = __conflict_solve.copy()
383 for name, func in [('preserve',preserve), ('update',update),
383 for name, func in [('preserve',preserve), ('update',update),
384 ('add',add), ('add_flip',add_flip),
384 ('add',add), ('add_flip',add_flip),
385 ('add_s',add_s)]:
385 ('add_s',add_s)]:
386 if name in inv_conflict_solve_user.keys():
386 if name in inv_conflict_solve_user.keys():
387 inv_conflict_solve_user[func] = inv_conflict_solve_user[name]
387 inv_conflict_solve_user[func] = inv_conflict_solve_user[name]
388 del inv_conflict_solve_user[name]
388 del inv_conflict_solve_user[name]
389 conflict_solve.update(self.__dict_invert(inv_conflict_solve_user))
389 conflict_solve.update(self.__dict_invert(inv_conflict_solve_user))
390 for key in data_dict:
390 for key in data_dict:
391 if key not in self:
391 if key not in self:
392 self[key] = data_dict[key]
392 self[key] = data_dict[key]
393 else:
393 else:
394 self[key] = conflict_solve[key](self[key],data_dict[key])
394 self[key] = conflict_solve[key](self[key],data_dict[key])
395
395
@@ -1,185 +1,187 b''
1 # encoding: utf-8
1 # encoding: utf-8
2 """
2 """
3 Utilities for getting information about IPython and the system it's running in.
3 Utilities for getting information about IPython and the system it's running in.
4 """
4 """
5
5
6 #-----------------------------------------------------------------------------
6 #-----------------------------------------------------------------------------
7 # Copyright (C) 2008-2009 The IPython Development Team
7 # Copyright (C) 2008-2009 The IPython Development Team
8 #
8 #
9 # Distributed under the terms of the BSD License. The full license is in
9 # Distributed under the terms of the BSD License. The full license is in
10 # the file COPYING, distributed as part of this software.
10 # the file COPYING, distributed as part of this software.
11 #-----------------------------------------------------------------------------
11 #-----------------------------------------------------------------------------
12
12
13 #-----------------------------------------------------------------------------
13 #-----------------------------------------------------------------------------
14 # Imports
14 # Imports
15 #-----------------------------------------------------------------------------
15 #-----------------------------------------------------------------------------
16
16
17 import os
17 import os
18 import platform
18 import platform
19 import pprint
19 import pprint
20 import sys
20 import sys
21 import subprocess
21 import subprocess
22
22
23 from ConfigParser import ConfigParser
23 from ConfigParser import ConfigParser
24
24
25 from IPython.core import release
25 from IPython.core import release
26 from IPython.utils import py3compat
26
27
27 #-----------------------------------------------------------------------------
28 #-----------------------------------------------------------------------------
28 # Globals
29 # Globals
29 #-----------------------------------------------------------------------------
30 #-----------------------------------------------------------------------------
30 COMMIT_INFO_FNAME = '.git_commit_info.ini'
31 COMMIT_INFO_FNAME = '.git_commit_info.ini'
31
32
32 #-----------------------------------------------------------------------------
33 #-----------------------------------------------------------------------------
33 # Code
34 # Code
34 #-----------------------------------------------------------------------------
35 #-----------------------------------------------------------------------------
35
36
36 def pkg_commit_hash(pkg_path):
37 def pkg_commit_hash(pkg_path):
37 """Get short form of commit hash given directory `pkg_path`
38 """Get short form of commit hash given directory `pkg_path`
38
39
39 There should be a file called 'COMMIT_INFO.txt' in `pkg_path`. This is a
40 There should be a file called 'COMMIT_INFO.txt' in `pkg_path`. This is a
40 file in INI file format, with at least one section: ``commit hash``, and two
41 file in INI file format, with at least one section: ``commit hash``, and two
41 variables ``archive_subst_hash`` and ``install_hash``. The first has a
42 variables ``archive_subst_hash`` and ``install_hash``. The first has a
42 substitution pattern in it which may have been filled by the execution of
43 substitution pattern in it which may have been filled by the execution of
43 ``git archive`` if this is an archive generated that way. The second is
44 ``git archive`` if this is an archive generated that way. The second is
44 filled in by the installation, if the installation is from a git archive.
45 filled in by the installation, if the installation is from a git archive.
45
46
46 We get the commit hash from (in order of preference):
47 We get the commit hash from (in order of preference):
47
48
48 * A substituted value in ``archive_subst_hash``
49 * A substituted value in ``archive_subst_hash``
49 * A written commit hash value in ``install_hash`
50 * A written commit hash value in ``install_hash`
50 * git output, if we are in a git repository
51 * git output, if we are in a git repository
51
52
52 If all these fail, we return a not-found placeholder tuple
53 If all these fail, we return a not-found placeholder tuple
53
54
54 Parameters
55 Parameters
55 ----------
56 ----------
56 pkg_path : str
57 pkg_path : str
57 directory containing package
58 directory containing package
58
59
59 Returns
60 Returns
60 -------
61 -------
61 hash_from : str
62 hash_from : str
62 Where we got the hash from - description
63 Where we got the hash from - description
63 hash_str : str
64 hash_str : str
64 short form of hash
65 short form of hash
65 """
66 """
66 # Try and get commit from written commit text file
67 # Try and get commit from written commit text file
67 pth = os.path.join(pkg_path, COMMIT_INFO_FNAME)
68 pth = os.path.join(pkg_path, COMMIT_INFO_FNAME)
68 if not os.path.isfile(pth):
69 if not os.path.isfile(pth):
69 raise IOError('Missing commit info file %s' % pth)
70 raise IOError('Missing commit info file %s' % pth)
70 cfg_parser = ConfigParser()
71 cfg_parser = ConfigParser()
71 cfg_parser.read(pth)
72 cfg_parser.read(pth)
72 archive_subst = cfg_parser.get('commit hash', 'archive_subst_hash')
73 archive_subst = cfg_parser.get('commit hash', 'archive_subst_hash')
73 if not archive_subst.startswith('$Format'): # it has been substituted
74 if not archive_subst.startswith('$Format'): # it has been substituted
74 return 'archive substitution', archive_subst
75 return 'archive substitution', archive_subst
75 install_subst = cfg_parser.get('commit hash', 'install_hash')
76 install_subst = cfg_parser.get('commit hash', 'install_hash')
76 if install_subst != '':
77 if install_subst != '':
77 return 'installation', install_subst
78 return 'installation', install_subst
78 # maybe we are in a repository
79 # maybe we are in a repository
79 proc = subprocess.Popen('git rev-parse --short HEAD',
80 proc = subprocess.Popen('git rev-parse --short HEAD',
80 stdout=subprocess.PIPE,
81 stdout=subprocess.PIPE,
81 stderr=subprocess.PIPE,
82 stderr=subprocess.PIPE,
82 cwd=pkg_path, shell=True)
83 cwd=pkg_path, shell=True)
83 repo_commit, _ = proc.communicate()
84 repo_commit, _ = proc.communicate()
84 if repo_commit:
85 if repo_commit:
85 return 'repository', repo_commit.strip()
86 return 'repository', repo_commit.strip()
86 return '(none found)', '<not found>'
87 return '(none found)', '<not found>'
87
88
88
89
89 def pkg_info(pkg_path):
90 def pkg_info(pkg_path):
90 """Return dict describing the context of this package
91 """Return dict describing the context of this package
91
92
92 Parameters
93 Parameters
93 ----------
94 ----------
94 pkg_path : str
95 pkg_path : str
95 path containing __init__.py for package
96 path containing __init__.py for package
96
97
97 Returns
98 Returns
98 -------
99 -------
99 context : dict
100 context : dict
100 with named parameters of interest
101 with named parameters of interest
101 """
102 """
102 src, hsh = pkg_commit_hash(pkg_path)
103 src, hsh = pkg_commit_hash(pkg_path)
103 return dict(
104 return dict(
104 ipython_version=release.version,
105 ipython_version=release.version,
105 ipython_path=pkg_path,
106 ipython_path=pkg_path,
106 commit_source=src,
107 commit_source=src,
107 commit_hash=hsh,
108 commit_hash=hsh,
108 sys_version=sys.version,
109 sys_version=sys.version,
109 sys_executable=sys.executable,
110 sys_executable=sys.executable,
110 sys_platform=sys.platform,
111 sys_platform=sys.platform,
111 platform=platform.platform(),
112 platform=platform.platform(),
112 os_name=os.name,
113 os_name=os.name,
113 )
114 )
114
115
115
116
117 @py3compat.doctest_refactor_print
116 def sys_info():
118 def sys_info():
117 """Return useful information about IPython and the system, as a string.
119 """Return useful information about IPython and the system, as a string.
118
120
119 Example
121 Example
120 -------
122 -------
121 In [2]: print sys_info()
123 In [2]: print sys_info()
122 {'commit_hash': '144fdae', # random
124 {'commit_hash': '144fdae', # random
123 'commit_source': 'repository',
125 'commit_source': 'repository',
124 'ipython_path': '/home/fperez/usr/lib/python2.6/site-packages/IPython',
126 'ipython_path': '/home/fperez/usr/lib/python2.6/site-packages/IPython',
125 'ipython_version': '0.11.dev',
127 'ipython_version': '0.11.dev',
126 'os_name': 'posix',
128 'os_name': 'posix',
127 'platform': 'Linux-2.6.35-22-generic-i686-with-Ubuntu-10.10-maverick',
129 'platform': 'Linux-2.6.35-22-generic-i686-with-Ubuntu-10.10-maverick',
128 'sys_executable': '/usr/bin/python',
130 'sys_executable': '/usr/bin/python',
129 'sys_platform': 'linux2',
131 'sys_platform': 'linux2',
130 'sys_version': '2.6.6 (r266:84292, Sep 15 2010, 15:52:39) \\n[GCC 4.4.5]'}
132 'sys_version': '2.6.6 (r266:84292, Sep 15 2010, 15:52:39) \\n[GCC 4.4.5]'}
131 """
133 """
132 p = os.path
134 p = os.path
133 path = p.dirname(p.abspath(p.join(__file__, '..')))
135 path = p.dirname(p.abspath(p.join(__file__, '..')))
134 return pprint.pformat(pkg_info(path))
136 return pprint.pformat(pkg_info(path))
135
137
136
138
137 def _num_cpus_unix():
139 def _num_cpus_unix():
138 """Return the number of active CPUs on a Unix system."""
140 """Return the number of active CPUs on a Unix system."""
139 return os.sysconf("SC_NPROCESSORS_ONLN")
141 return os.sysconf("SC_NPROCESSORS_ONLN")
140
142
141
143
142 def _num_cpus_darwin():
144 def _num_cpus_darwin():
143 """Return the number of active CPUs on a Darwin system."""
145 """Return the number of active CPUs on a Darwin system."""
144 p = subprocess.Popen(['sysctl','-n','hw.ncpu'],stdout=subprocess.PIPE)
146 p = subprocess.Popen(['sysctl','-n','hw.ncpu'],stdout=subprocess.PIPE)
145 return p.stdout.read()
147 return p.stdout.read()
146
148
147
149
148 def _num_cpus_windows():
150 def _num_cpus_windows():
149 """Return the number of active CPUs on a Windows system."""
151 """Return the number of active CPUs on a Windows system."""
150 return os.environ.get("NUMBER_OF_PROCESSORS")
152 return os.environ.get("NUMBER_OF_PROCESSORS")
151
153
152
154
153 def num_cpus():
155 def num_cpus():
154 """Return the effective number of CPUs in the system as an integer.
156 """Return the effective number of CPUs in the system as an integer.
155
157
156 This cross-platform function makes an attempt at finding the total number of
158 This cross-platform function makes an attempt at finding the total number of
157 available CPUs in the system, as returned by various underlying system and
159 available CPUs in the system, as returned by various underlying system and
158 python calls.
160 python calls.
159
161
160 If it can't find a sensible answer, it returns 1 (though an error *may* make
162 If it can't find a sensible answer, it returns 1 (though an error *may* make
161 it return a large positive number that's actually incorrect).
163 it return a large positive number that's actually incorrect).
162 """
164 """
163
165
164 # Many thanks to the Parallel Python project (http://www.parallelpython.com)
166 # Many thanks to the Parallel Python project (http://www.parallelpython.com)
165 # for the names of the keys we needed to look up for this function. This
167 # for the names of the keys we needed to look up for this function. This
166 # code was inspired by their equivalent function.
168 # code was inspired by their equivalent function.
167
169
168 ncpufuncs = {'Linux':_num_cpus_unix,
170 ncpufuncs = {'Linux':_num_cpus_unix,
169 'Darwin':_num_cpus_darwin,
171 'Darwin':_num_cpus_darwin,
170 'Windows':_num_cpus_windows,
172 'Windows':_num_cpus_windows,
171 # On Vista, python < 2.5.2 has a bug and returns 'Microsoft'
173 # On Vista, python < 2.5.2 has a bug and returns 'Microsoft'
172 # See http://bugs.python.org/issue1082 for details.
174 # See http://bugs.python.org/issue1082 for details.
173 'Microsoft':_num_cpus_windows,
175 'Microsoft':_num_cpus_windows,
174 }
176 }
175
177
176 ncpufunc = ncpufuncs.get(platform.system(),
178 ncpufunc = ncpufuncs.get(platform.system(),
177 # default to unix version (Solaris, AIX, etc)
179 # default to unix version (Solaris, AIX, etc)
178 _num_cpus_unix)
180 _num_cpus_unix)
179
181
180 try:
182 try:
181 ncpus = max(1,int(ncpufunc()))
183 ncpus = max(1,int(ncpufunc()))
182 except:
184 except:
183 ncpus = 1
185 ncpus = 1
184 return ncpus
186 return ncpus
185
187
@@ -1,71 +1,75 b''
1 # encoding: utf-8
1 # encoding: utf-8
2 """Tests for io.py"""
2 """Tests for io.py"""
3
3
4 #-----------------------------------------------------------------------------
4 #-----------------------------------------------------------------------------
5 # Copyright (C) 2008 The IPython Development Team
5 # Copyright (C) 2008 The IPython Development Team
6 #
6 #
7 # Distributed under the terms of the BSD License. The full license is in
7 # Distributed under the terms of the BSD License. The full license is in
8 # the file COPYING, distributed as part of this software.
8 # the file COPYING, distributed as part of this software.
9 #-----------------------------------------------------------------------------
9 #-----------------------------------------------------------------------------
10
10
11 #-----------------------------------------------------------------------------
11 #-----------------------------------------------------------------------------
12 # Imports
12 # Imports
13 #-----------------------------------------------------------------------------
13 #-----------------------------------------------------------------------------
14
14
15 import sys
15 import sys
16
16
17 from StringIO import StringIO
17 from StringIO import StringIO
18 from subprocess import Popen, PIPE
18 from subprocess import Popen, PIPE
19
19
20 import nose.tools as nt
20 import nose.tools as nt
21
21
22 from IPython.testing import decorators as dec
22 from IPython.testing import decorators as dec
23 from IPython.utils.io import Tee
23 from IPython.utils.io import Tee
24 from IPython.utils.py3compat import doctest_refactor_print
24
25
25 #-----------------------------------------------------------------------------
26 #-----------------------------------------------------------------------------
26 # Tests
27 # Tests
27 #-----------------------------------------------------------------------------
28 #-----------------------------------------------------------------------------
28
29
29
30
30 def test_tee_simple():
31 def test_tee_simple():
31 "Very simple check with stdout only"
32 "Very simple check with stdout only"
32 chan = StringIO()
33 chan = StringIO()
33 text = 'Hello'
34 text = 'Hello'
34 tee = Tee(chan, channel='stdout')
35 tee = Tee(chan, channel='stdout')
35 print >> chan, text,
36 print >> chan, text
36 nt.assert_equal(chan.getvalue(), text)
37 nt.assert_equal(chan.getvalue(), text+"\n")
37
38
38
39
39 class TeeTestCase(dec.ParametricTestCase):
40 class TeeTestCase(dec.ParametricTestCase):
40
41
41 def tchan(self, channel, check='close'):
42 def tchan(self, channel, check='close'):
42 trap = StringIO()
43 trap = StringIO()
43 chan = StringIO()
44 chan = StringIO()
44 text = 'Hello'
45 text = 'Hello'
45
46
46 std_ori = getattr(sys, channel)
47 std_ori = getattr(sys, channel)
47 setattr(sys, channel, trap)
48 setattr(sys, channel, trap)
48
49
49 tee = Tee(chan, channel=channel)
50 tee = Tee(chan, channel=channel)
50 print >> chan, text,
51 print >> chan, text,
51 setattr(sys, channel, std_ori)
52 setattr(sys, channel, std_ori)
52 trap_val = trap.getvalue()
53 trap_val = trap.getvalue()
53 nt.assert_equals(chan.getvalue(), text)
54 nt.assert_equals(chan.getvalue(), text)
54 if check=='close':
55 if check=='close':
55 tee.close()
56 tee.close()
56 else:
57 else:
57 del tee
58 del tee
58
59
59 def test(self):
60 def test(self):
60 for chan in ['stdout', 'stderr']:
61 for chan in ['stdout', 'stderr']:
61 for check in ['close', 'del']:
62 for check in ['close', 'del']:
62 yield self.tchan(chan, check)
63 yield self.tchan(chan, check)
63
64
64 def test_io_init():
65 def test_io_init():
65 """Test that io.stdin/out/err exist at startup"""
66 """Test that io.stdin/out/err exist at startup"""
66 for name in ('stdin', 'stdout', 'stderr'):
67 for name in ('stdin', 'stdout', 'stderr'):
67 p = Popen([sys.executable, '-c', "from IPython.utils import io;print io.%s.__class__"%name],
68 cmd = doctest_refactor_print("from IPython.utils import io;print io.%s.__class__"%name)
69 p = Popen([sys.executable, '-c', cmd],
68 stdout=PIPE)
70 stdout=PIPE)
69 p.wait()
71 p.wait()
70 classname = p.stdout.read().strip()
72 classname = p.stdout.read().strip().decode('ascii')
71 nt.assert_equals(classname, 'IPython.utils.io.IOStream')
73 # __class__ is a reference to the class object in Python 3, so we can't
74 # just test for string equality.
75 assert 'IPython.utils.io.IOStream' in classname, classname
@@ -1,450 +1,450 b''
1 # encoding: utf-8
1 # encoding: utf-8
2 """Tests for IPython.utils.path.py"""
2 """Tests for IPython.utils.path.py"""
3
3
4 #-----------------------------------------------------------------------------
4 #-----------------------------------------------------------------------------
5 # Copyright (C) 2008 The IPython Development Team
5 # Copyright (C) 2008 The IPython Development Team
6 #
6 #
7 # Distributed under the terms of the BSD License. The full license is in
7 # Distributed under the terms of the BSD License. The full license is in
8 # the file COPYING, distributed as part of this software.
8 # the file COPYING, distributed as part of this software.
9 #-----------------------------------------------------------------------------
9 #-----------------------------------------------------------------------------
10
10
11 #-----------------------------------------------------------------------------
11 #-----------------------------------------------------------------------------
12 # Imports
12 # Imports
13 #-----------------------------------------------------------------------------
13 #-----------------------------------------------------------------------------
14
14
15 from __future__ import with_statement
15 from __future__ import with_statement
16
16
17 import os
17 import os
18 import shutil
18 import shutil
19 import sys
19 import sys
20 import tempfile
20 import tempfile
21 import StringIO
21 from io import StringIO
22
22
23 from os.path import join, abspath, split
23 from os.path import join, abspath, split
24
24
25 import nose.tools as nt
25 import nose.tools as nt
26
26
27 from nose import with_setup
27 from nose import with_setup
28
28
29 import IPython
29 import IPython
30 from IPython.testing import decorators as dec
30 from IPython.testing import decorators as dec
31 from IPython.testing.decorators import skip_if_not_win32, skip_win32
31 from IPython.testing.decorators import skip_if_not_win32, skip_win32
32 from IPython.testing.tools import make_tempfile
32 from IPython.testing.tools import make_tempfile
33 from IPython.utils import path, io
33 from IPython.utils import path, io
34 from IPython.utils import py3compat
34 from IPython.utils import py3compat
35
35
36 # Platform-dependent imports
36 # Platform-dependent imports
37 try:
37 try:
38 import _winreg as wreg
38 import _winreg as wreg
39 except ImportError:
39 except ImportError:
40 #Fake _winreg module on none windows platforms
40 #Fake _winreg module on none windows platforms
41 import types
41 import types
42 wr_name = "winreg" if py3compat.PY3 else "_winreg"
42 wr_name = "winreg" if py3compat.PY3 else "_winreg"
43 sys.modules[wr_name] = types.ModuleType(wr_name)
43 sys.modules[wr_name] = types.ModuleType(wr_name)
44 import _winreg as wreg
44 import _winreg as wreg
45 #Add entries that needs to be stubbed by the testing code
45 #Add entries that needs to be stubbed by the testing code
46 (wreg.OpenKey, wreg.QueryValueEx,) = (None, None)
46 (wreg.OpenKey, wreg.QueryValueEx,) = (None, None)
47
47
48 try:
48 try:
49 reload
49 reload
50 except NameError: # Python 3
50 except NameError: # Python 3
51 from imp import reload
51 from imp import reload
52
52
53 #-----------------------------------------------------------------------------
53 #-----------------------------------------------------------------------------
54 # Globals
54 # Globals
55 #-----------------------------------------------------------------------------
55 #-----------------------------------------------------------------------------
56 env = os.environ
56 env = os.environ
57 TEST_FILE_PATH = split(abspath(__file__))[0]
57 TEST_FILE_PATH = split(abspath(__file__))[0]
58 TMP_TEST_DIR = tempfile.mkdtemp()
58 TMP_TEST_DIR = tempfile.mkdtemp()
59 HOME_TEST_DIR = join(TMP_TEST_DIR, "home_test_dir")
59 HOME_TEST_DIR = join(TMP_TEST_DIR, "home_test_dir")
60 XDG_TEST_DIR = join(HOME_TEST_DIR, "xdg_test_dir")
60 XDG_TEST_DIR = join(HOME_TEST_DIR, "xdg_test_dir")
61 IP_TEST_DIR = join(HOME_TEST_DIR,'.ipython')
61 IP_TEST_DIR = join(HOME_TEST_DIR,'.ipython')
62 #
62 #
63 # Setup/teardown functions/decorators
63 # Setup/teardown functions/decorators
64 #
64 #
65
65
66 def setup():
66 def setup():
67 """Setup testenvironment for the module:
67 """Setup testenvironment for the module:
68
68
69 - Adds dummy home dir tree
69 - Adds dummy home dir tree
70 """
70 """
71 # Do not mask exceptions here. In particular, catching WindowsError is a
71 # Do not mask exceptions here. In particular, catching WindowsError is a
72 # problem because that exception is only defined on Windows...
72 # problem because that exception is only defined on Windows...
73 os.makedirs(IP_TEST_DIR)
73 os.makedirs(IP_TEST_DIR)
74 os.makedirs(os.path.join(XDG_TEST_DIR, 'ipython'))
74 os.makedirs(os.path.join(XDG_TEST_DIR, 'ipython'))
75
75
76
76
77 def teardown():
77 def teardown():
78 """Teardown testenvironment for the module:
78 """Teardown testenvironment for the module:
79
79
80 - Remove dummy home dir tree
80 - Remove dummy home dir tree
81 """
81 """
82 # Note: we remove the parent test dir, which is the root of all test
82 # Note: we remove the parent test dir, which is the root of all test
83 # subdirs we may have created. Use shutil instead of os.removedirs, so
83 # subdirs we may have created. Use shutil instead of os.removedirs, so
84 # that non-empty directories are all recursively removed.
84 # that non-empty directories are all recursively removed.
85 shutil.rmtree(TMP_TEST_DIR)
85 shutil.rmtree(TMP_TEST_DIR)
86
86
87
87
88 def setup_environment():
88 def setup_environment():
89 """Setup testenvironment for some functions that are tested
89 """Setup testenvironment for some functions that are tested
90 in this module. In particular this functions stores attributes
90 in this module. In particular this functions stores attributes
91 and other things that we need to stub in some test functions.
91 and other things that we need to stub in some test functions.
92 This needs to be done on a function level and not module level because
92 This needs to be done on a function level and not module level because
93 each testfunction needs a pristine environment.
93 each testfunction needs a pristine environment.
94 """
94 """
95 global oldstuff, platformstuff
95 global oldstuff, platformstuff
96 oldstuff = (env.copy(), os.name, path.get_home_dir, IPython.__file__, os.getcwd())
96 oldstuff = (env.copy(), os.name, path.get_home_dir, IPython.__file__, os.getcwd())
97
97
98 if os.name == 'nt':
98 if os.name == 'nt':
99 platformstuff = (wreg.OpenKey, wreg.QueryValueEx,)
99 platformstuff = (wreg.OpenKey, wreg.QueryValueEx,)
100
100
101
101
102 def teardown_environment():
102 def teardown_environment():
103 """Restore things that were remebered by the setup_environment function
103 """Restore things that were remebered by the setup_environment function
104 """
104 """
105 (oldenv, os.name, path.get_home_dir, IPython.__file__, old_wd) = oldstuff
105 (oldenv, os.name, path.get_home_dir, IPython.__file__, old_wd) = oldstuff
106 os.chdir(old_wd)
106 os.chdir(old_wd)
107 reload(path)
107 reload(path)
108
108
109 for key in env.keys():
109 for key in env.keys():
110 if key not in oldenv:
110 if key not in oldenv:
111 del env[key]
111 del env[key]
112 env.update(oldenv)
112 env.update(oldenv)
113 if hasattr(sys, 'frozen'):
113 if hasattr(sys, 'frozen'):
114 del sys.frozen
114 del sys.frozen
115 if os.name == 'nt':
115 if os.name == 'nt':
116 (wreg.OpenKey, wreg.QueryValueEx,) = platformstuff
116 (wreg.OpenKey, wreg.QueryValueEx,) = platformstuff
117
117
118 # Build decorator that uses the setup_environment/setup_environment
118 # Build decorator that uses the setup_environment/setup_environment
119 with_environment = with_setup(setup_environment, teardown_environment)
119 with_environment = with_setup(setup_environment, teardown_environment)
120
120
121
121
122 @skip_if_not_win32
122 @skip_if_not_win32
123 @with_environment
123 @with_environment
124 def test_get_home_dir_1():
124 def test_get_home_dir_1():
125 """Testcase for py2exe logic, un-compressed lib
125 """Testcase for py2exe logic, un-compressed lib
126 """
126 """
127 sys.frozen = True
127 sys.frozen = True
128
128
129 #fake filename for IPython.__init__
129 #fake filename for IPython.__init__
130 IPython.__file__ = abspath(join(HOME_TEST_DIR, "Lib/IPython/__init__.py"))
130 IPython.__file__ = abspath(join(HOME_TEST_DIR, "Lib/IPython/__init__.py"))
131
131
132 home_dir = path.get_home_dir()
132 home_dir = path.get_home_dir()
133 nt.assert_equal(home_dir, abspath(HOME_TEST_DIR))
133 nt.assert_equal(home_dir, abspath(HOME_TEST_DIR))
134
134
135
135
136 @skip_if_not_win32
136 @skip_if_not_win32
137 @with_environment
137 @with_environment
138 def test_get_home_dir_2():
138 def test_get_home_dir_2():
139 """Testcase for py2exe logic, compressed lib
139 """Testcase for py2exe logic, compressed lib
140 """
140 """
141 sys.frozen = True
141 sys.frozen = True
142 #fake filename for IPython.__init__
142 #fake filename for IPython.__init__
143 IPython.__file__ = abspath(join(HOME_TEST_DIR, "Library.zip/IPython/__init__.py")).lower()
143 IPython.__file__ = abspath(join(HOME_TEST_DIR, "Library.zip/IPython/__init__.py")).lower()
144
144
145 home_dir = path.get_home_dir()
145 home_dir = path.get_home_dir()
146 nt.assert_equal(home_dir, abspath(HOME_TEST_DIR).lower())
146 nt.assert_equal(home_dir, abspath(HOME_TEST_DIR).lower())
147
147
148
148
149 @with_environment
149 @with_environment
150 @skip_win32
150 @skip_win32
151 def test_get_home_dir_3():
151 def test_get_home_dir_3():
152 """Testcase $HOME is set, then use its value as home directory."""
152 """Testcase $HOME is set, then use its value as home directory."""
153 env["HOME"] = HOME_TEST_DIR
153 env["HOME"] = HOME_TEST_DIR
154 home_dir = path.get_home_dir()
154 home_dir = path.get_home_dir()
155 nt.assert_equal(home_dir, env["HOME"])
155 nt.assert_equal(home_dir, env["HOME"])
156
156
157
157
158 @with_environment
158 @with_environment
159 @skip_win32
159 @skip_win32
160 def test_get_home_dir_4():
160 def test_get_home_dir_4():
161 """Testcase $HOME is not set, os=='posix'.
161 """Testcase $HOME is not set, os=='posix'.
162 This should fail with HomeDirError"""
162 This should fail with HomeDirError"""
163
163
164 os.name = 'posix'
164 os.name = 'posix'
165 if 'HOME' in env: del env['HOME']
165 if 'HOME' in env: del env['HOME']
166 nt.assert_raises(path.HomeDirError, path.get_home_dir)
166 nt.assert_raises(path.HomeDirError, path.get_home_dir)
167
167
168
168
169 @skip_if_not_win32
169 @skip_if_not_win32
170 @with_environment
170 @with_environment
171 def test_get_home_dir_5():
171 def test_get_home_dir_5():
172 """Using HOMEDRIVE + HOMEPATH, os=='nt'.
172 """Using HOMEDRIVE + HOMEPATH, os=='nt'.
173
173
174 HOMESHARE is missing.
174 HOMESHARE is missing.
175 """
175 """
176
176
177 os.name = 'nt'
177 os.name = 'nt'
178 env.pop('HOMESHARE', None)
178 env.pop('HOMESHARE', None)
179 env['HOMEDRIVE'], env['HOMEPATH'] = os.path.splitdrive(HOME_TEST_DIR)
179 env['HOMEDRIVE'], env['HOMEPATH'] = os.path.splitdrive(HOME_TEST_DIR)
180 home_dir = path.get_home_dir()
180 home_dir = path.get_home_dir()
181 nt.assert_equal(home_dir, abspath(HOME_TEST_DIR))
181 nt.assert_equal(home_dir, abspath(HOME_TEST_DIR))
182
182
183
183
184 @skip_if_not_win32
184 @skip_if_not_win32
185 @with_environment
185 @with_environment
186 def test_get_home_dir_6():
186 def test_get_home_dir_6():
187 """Using USERPROFILE, os=='nt'.
187 """Using USERPROFILE, os=='nt'.
188
188
189 HOMESHARE, HOMEDRIVE, HOMEPATH are missing.
189 HOMESHARE, HOMEDRIVE, HOMEPATH are missing.
190 """
190 """
191
191
192 os.name = 'nt'
192 os.name = 'nt'
193 env.pop('HOMESHARE', None)
193 env.pop('HOMESHARE', None)
194 env.pop('HOMEDRIVE', None)
194 env.pop('HOMEDRIVE', None)
195 env.pop('HOMEPATH', None)
195 env.pop('HOMEPATH', None)
196 env["USERPROFILE"] = abspath(HOME_TEST_DIR)
196 env["USERPROFILE"] = abspath(HOME_TEST_DIR)
197 home_dir = path.get_home_dir()
197 home_dir = path.get_home_dir()
198 nt.assert_equal(home_dir, abspath(HOME_TEST_DIR))
198 nt.assert_equal(home_dir, abspath(HOME_TEST_DIR))
199
199
200
200
201 @skip_if_not_win32
201 @skip_if_not_win32
202 @with_environment
202 @with_environment
203 def test_get_home_dir_7():
203 def test_get_home_dir_7():
204 """Using HOMESHARE, os=='nt'."""
204 """Using HOMESHARE, os=='nt'."""
205
205
206 os.name = 'nt'
206 os.name = 'nt'
207 env["HOMESHARE"] = abspath(HOME_TEST_DIR)
207 env["HOMESHARE"] = abspath(HOME_TEST_DIR)
208 home_dir = path.get_home_dir()
208 home_dir = path.get_home_dir()
209 nt.assert_equal(home_dir, abspath(HOME_TEST_DIR))
209 nt.assert_equal(home_dir, abspath(HOME_TEST_DIR))
210
210
211
211
212 # Should we stub wreg fully so we can run the test on all platforms?
212 # Should we stub wreg fully so we can run the test on all platforms?
213 @skip_if_not_win32
213 @skip_if_not_win32
214 @with_environment
214 @with_environment
215 def test_get_home_dir_8():
215 def test_get_home_dir_8():
216 """Using registry hack for 'My Documents', os=='nt'
216 """Using registry hack for 'My Documents', os=='nt'
217
217
218 HOMESHARE, HOMEDRIVE, HOMEPATH, USERPROFILE and others are missing.
218 HOMESHARE, HOMEDRIVE, HOMEPATH, USERPROFILE and others are missing.
219 """
219 """
220 os.name = 'nt'
220 os.name = 'nt'
221 # Remove from stub environment all keys that may be set
221 # Remove from stub environment all keys that may be set
222 for key in ['HOME', 'HOMESHARE', 'HOMEDRIVE', 'HOMEPATH', 'USERPROFILE']:
222 for key in ['HOME', 'HOMESHARE', 'HOMEDRIVE', 'HOMEPATH', 'USERPROFILE']:
223 env.pop(key, None)
223 env.pop(key, None)
224
224
225 #Stub windows registry functions
225 #Stub windows registry functions
226 def OpenKey(x, y):
226 def OpenKey(x, y):
227 class key:
227 class key:
228 def Close(self):
228 def Close(self):
229 pass
229 pass
230 return key()
230 return key()
231 def QueryValueEx(x, y):
231 def QueryValueEx(x, y):
232 return [abspath(HOME_TEST_DIR)]
232 return [abspath(HOME_TEST_DIR)]
233
233
234 wreg.OpenKey = OpenKey
234 wreg.OpenKey = OpenKey
235 wreg.QueryValueEx = QueryValueEx
235 wreg.QueryValueEx = QueryValueEx
236
236
237 home_dir = path.get_home_dir()
237 home_dir = path.get_home_dir()
238 nt.assert_equal(home_dir, abspath(HOME_TEST_DIR))
238 nt.assert_equal(home_dir, abspath(HOME_TEST_DIR))
239
239
240
240
241 @with_environment
241 @with_environment
242 def test_get_ipython_dir_1():
242 def test_get_ipython_dir_1():
243 """test_get_ipython_dir_1, Testcase to see if we can call get_ipython_dir without Exceptions."""
243 """test_get_ipython_dir_1, Testcase to see if we can call get_ipython_dir without Exceptions."""
244 env_ipdir = os.path.join("someplace", ".ipython")
244 env_ipdir = os.path.join("someplace", ".ipython")
245 path._writable_dir = lambda path: True
245 path._writable_dir = lambda path: True
246 env['IPYTHON_DIR'] = env_ipdir
246 env['IPYTHON_DIR'] = env_ipdir
247 ipdir = path.get_ipython_dir()
247 ipdir = path.get_ipython_dir()
248 nt.assert_equal(ipdir, env_ipdir)
248 nt.assert_equal(ipdir, env_ipdir)
249
249
250
250
251 @with_environment
251 @with_environment
252 def test_get_ipython_dir_2():
252 def test_get_ipython_dir_2():
253 """test_get_ipython_dir_2, Testcase to see if we can call get_ipython_dir without Exceptions."""
253 """test_get_ipython_dir_2, Testcase to see if we can call get_ipython_dir without Exceptions."""
254 path.get_home_dir = lambda : "someplace"
254 path.get_home_dir = lambda : "someplace"
255 path.get_xdg_dir = lambda : None
255 path.get_xdg_dir = lambda : None
256 path._writable_dir = lambda path: True
256 path._writable_dir = lambda path: True
257 os.name = "posix"
257 os.name = "posix"
258 env.pop('IPYTHON_DIR', None)
258 env.pop('IPYTHON_DIR', None)
259 env.pop('IPYTHONDIR', None)
259 env.pop('IPYTHONDIR', None)
260 env.pop('XDG_CONFIG_HOME', None)
260 env.pop('XDG_CONFIG_HOME', None)
261 ipdir = path.get_ipython_dir()
261 ipdir = path.get_ipython_dir()
262 nt.assert_equal(ipdir, os.path.join("someplace", ".ipython"))
262 nt.assert_equal(ipdir, os.path.join("someplace", ".ipython"))
263
263
264 @with_environment
264 @with_environment
265 def test_get_ipython_dir_3():
265 def test_get_ipython_dir_3():
266 """test_get_ipython_dir_3, use XDG if defined, and .ipython doesn't exist."""
266 """test_get_ipython_dir_3, use XDG if defined, and .ipython doesn't exist."""
267 path.get_home_dir = lambda : "someplace"
267 path.get_home_dir = lambda : "someplace"
268 path._writable_dir = lambda path: True
268 path._writable_dir = lambda path: True
269 os.name = "posix"
269 os.name = "posix"
270 env.pop('IPYTHON_DIR', None)
270 env.pop('IPYTHON_DIR', None)
271 env.pop('IPYTHONDIR', None)
271 env.pop('IPYTHONDIR', None)
272 env['XDG_CONFIG_HOME'] = XDG_TEST_DIR
272 env['XDG_CONFIG_HOME'] = XDG_TEST_DIR
273 ipdir = path.get_ipython_dir()
273 ipdir = path.get_ipython_dir()
274 nt.assert_equal(ipdir, os.path.join(XDG_TEST_DIR, "ipython"))
274 nt.assert_equal(ipdir, os.path.join(XDG_TEST_DIR, "ipython"))
275
275
276 @with_environment
276 @with_environment
277 def test_get_ipython_dir_4():
277 def test_get_ipython_dir_4():
278 """test_get_ipython_dir_4, use XDG if both exist."""
278 """test_get_ipython_dir_4, use XDG if both exist."""
279 path.get_home_dir = lambda : HOME_TEST_DIR
279 path.get_home_dir = lambda : HOME_TEST_DIR
280 os.name = "posix"
280 os.name = "posix"
281 env.pop('IPYTHON_DIR', None)
281 env.pop('IPYTHON_DIR', None)
282 env.pop('IPYTHONDIR', None)
282 env.pop('IPYTHONDIR', None)
283 env['XDG_CONFIG_HOME'] = XDG_TEST_DIR
283 env['XDG_CONFIG_HOME'] = XDG_TEST_DIR
284 xdg_ipdir = os.path.join(XDG_TEST_DIR, "ipython")
284 xdg_ipdir = os.path.join(XDG_TEST_DIR, "ipython")
285 ipdir = path.get_ipython_dir()
285 ipdir = path.get_ipython_dir()
286 nt.assert_equal(ipdir, xdg_ipdir)
286 nt.assert_equal(ipdir, xdg_ipdir)
287
287
288 @with_environment
288 @with_environment
289 def test_get_ipython_dir_5():
289 def test_get_ipython_dir_5():
290 """test_get_ipython_dir_5, use .ipython if exists and XDG defined, but doesn't exist."""
290 """test_get_ipython_dir_5, use .ipython if exists and XDG defined, but doesn't exist."""
291 path.get_home_dir = lambda : HOME_TEST_DIR
291 path.get_home_dir = lambda : HOME_TEST_DIR
292 os.name = "posix"
292 os.name = "posix"
293 env.pop('IPYTHON_DIR', None)
293 env.pop('IPYTHON_DIR', None)
294 env.pop('IPYTHONDIR', None)
294 env.pop('IPYTHONDIR', None)
295 env['XDG_CONFIG_HOME'] = XDG_TEST_DIR
295 env['XDG_CONFIG_HOME'] = XDG_TEST_DIR
296 os.rmdir(os.path.join(XDG_TEST_DIR, 'ipython'))
296 os.rmdir(os.path.join(XDG_TEST_DIR, 'ipython'))
297 ipdir = path.get_ipython_dir()
297 ipdir = path.get_ipython_dir()
298 nt.assert_equal(ipdir, IP_TEST_DIR)
298 nt.assert_equal(ipdir, IP_TEST_DIR)
299
299
300 @with_environment
300 @with_environment
301 def test_get_ipython_dir_6():
301 def test_get_ipython_dir_6():
302 """test_get_ipython_dir_6, use XDG if defined and neither exist."""
302 """test_get_ipython_dir_6, use XDG if defined and neither exist."""
303 xdg = os.path.join(HOME_TEST_DIR, 'somexdg')
303 xdg = os.path.join(HOME_TEST_DIR, 'somexdg')
304 os.mkdir(xdg)
304 os.mkdir(xdg)
305 shutil.rmtree(os.path.join(HOME_TEST_DIR, '.ipython'))
305 shutil.rmtree(os.path.join(HOME_TEST_DIR, '.ipython'))
306 path.get_home_dir = lambda : HOME_TEST_DIR
306 path.get_home_dir = lambda : HOME_TEST_DIR
307 path.get_xdg_dir = lambda : xdg
307 path.get_xdg_dir = lambda : xdg
308 os.name = "posix"
308 os.name = "posix"
309 env.pop('IPYTHON_DIR', None)
309 env.pop('IPYTHON_DIR', None)
310 env.pop('IPYTHONDIR', None)
310 env.pop('IPYTHONDIR', None)
311 env.pop('XDG_CONFIG_HOME', None)
311 env.pop('XDG_CONFIG_HOME', None)
312 xdg_ipdir = os.path.join(xdg, "ipython")
312 xdg_ipdir = os.path.join(xdg, "ipython")
313 ipdir = path.get_ipython_dir()
313 ipdir = path.get_ipython_dir()
314 nt.assert_equal(ipdir, xdg_ipdir)
314 nt.assert_equal(ipdir, xdg_ipdir)
315
315
316 @with_environment
316 @with_environment
317 def test_get_ipython_dir_7():
317 def test_get_ipython_dir_7():
318 """test_get_ipython_dir_7, test home directory expansion on IPYTHON_DIR"""
318 """test_get_ipython_dir_7, test home directory expansion on IPYTHON_DIR"""
319 path._writable_dir = lambda path: True
319 path._writable_dir = lambda path: True
320 home_dir = os.path.expanduser('~')
320 home_dir = os.path.expanduser('~')
321 env['IPYTHON_DIR'] = os.path.join('~', 'somewhere')
321 env['IPYTHON_DIR'] = os.path.join('~', 'somewhere')
322 ipdir = path.get_ipython_dir()
322 ipdir = path.get_ipython_dir()
323 nt.assert_equal(ipdir, os.path.join(home_dir, 'somewhere'))
323 nt.assert_equal(ipdir, os.path.join(home_dir, 'somewhere'))
324
324
325
325
326 @with_environment
326 @with_environment
327 def test_get_xdg_dir_1():
327 def test_get_xdg_dir_1():
328 """test_get_xdg_dir_1, check xdg_dir"""
328 """test_get_xdg_dir_1, check xdg_dir"""
329 reload(path)
329 reload(path)
330 path._writable_dir = lambda path: True
330 path._writable_dir = lambda path: True
331 path.get_home_dir = lambda : 'somewhere'
331 path.get_home_dir = lambda : 'somewhere'
332 os.name = "posix"
332 os.name = "posix"
333 env.pop('IPYTHON_DIR', None)
333 env.pop('IPYTHON_DIR', None)
334 env.pop('IPYTHONDIR', None)
334 env.pop('IPYTHONDIR', None)
335 env.pop('XDG_CONFIG_HOME', None)
335 env.pop('XDG_CONFIG_HOME', None)
336
336
337 nt.assert_equal(path.get_xdg_dir(), os.path.join('somewhere', '.config'))
337 nt.assert_equal(path.get_xdg_dir(), os.path.join('somewhere', '.config'))
338
338
339
339
340 @with_environment
340 @with_environment
341 def test_get_xdg_dir_1():
341 def test_get_xdg_dir_1():
342 """test_get_xdg_dir_1, check nonexistant xdg_dir"""
342 """test_get_xdg_dir_1, check nonexistant xdg_dir"""
343 reload(path)
343 reload(path)
344 path.get_home_dir = lambda : HOME_TEST_DIR
344 path.get_home_dir = lambda : HOME_TEST_DIR
345 os.name = "posix"
345 os.name = "posix"
346 env.pop('IPYTHON_DIR', None)
346 env.pop('IPYTHON_DIR', None)
347 env.pop('IPYTHONDIR', None)
347 env.pop('IPYTHONDIR', None)
348 env.pop('XDG_CONFIG_HOME', None)
348 env.pop('XDG_CONFIG_HOME', None)
349 nt.assert_equal(path.get_xdg_dir(), None)
349 nt.assert_equal(path.get_xdg_dir(), None)
350
350
351 @with_environment
351 @with_environment
352 def test_get_xdg_dir_2():
352 def test_get_xdg_dir_2():
353 """test_get_xdg_dir_2, check xdg_dir default to ~/.config"""
353 """test_get_xdg_dir_2, check xdg_dir default to ~/.config"""
354 reload(path)
354 reload(path)
355 path.get_home_dir = lambda : HOME_TEST_DIR
355 path.get_home_dir = lambda : HOME_TEST_DIR
356 os.name = "posix"
356 os.name = "posix"
357 env.pop('IPYTHON_DIR', None)
357 env.pop('IPYTHON_DIR', None)
358 env.pop('IPYTHONDIR', None)
358 env.pop('IPYTHONDIR', None)
359 env.pop('XDG_CONFIG_HOME', None)
359 env.pop('XDG_CONFIG_HOME', None)
360 cfgdir=os.path.join(path.get_home_dir(), '.config')
360 cfgdir=os.path.join(path.get_home_dir(), '.config')
361 os.makedirs(cfgdir)
361 os.makedirs(cfgdir)
362
362
363 nt.assert_equal(path.get_xdg_dir(), cfgdir)
363 nt.assert_equal(path.get_xdg_dir(), cfgdir)
364
364
365 def test_filefind():
365 def test_filefind():
366 """Various tests for filefind"""
366 """Various tests for filefind"""
367 f = tempfile.NamedTemporaryFile()
367 f = tempfile.NamedTemporaryFile()
368 # print 'fname:',f.name
368 # print 'fname:',f.name
369 alt_dirs = path.get_ipython_dir()
369 alt_dirs = path.get_ipython_dir()
370 t = path.filefind(f.name, alt_dirs)
370 t = path.filefind(f.name, alt_dirs)
371 # print 'found:',t
371 # print 'found:',t
372
372
373
373
374 def test_get_ipython_package_dir():
374 def test_get_ipython_package_dir():
375 ipdir = path.get_ipython_package_dir()
375 ipdir = path.get_ipython_package_dir()
376 nt.assert_true(os.path.isdir(ipdir))
376 nt.assert_true(os.path.isdir(ipdir))
377
377
378
378
379 def test_get_ipython_module_path():
379 def test_get_ipython_module_path():
380 ipapp_path = path.get_ipython_module_path('IPython.frontend.terminal.ipapp')
380 ipapp_path = path.get_ipython_module_path('IPython.frontend.terminal.ipapp')
381 nt.assert_true(os.path.isfile(ipapp_path))
381 nt.assert_true(os.path.isfile(ipapp_path))
382
382
383
383
384 @dec.skip_if_not_win32
384 @dec.skip_if_not_win32
385 def test_get_long_path_name_win32():
385 def test_get_long_path_name_win32():
386 p = path.get_long_path_name('c:\\docume~1')
386 p = path.get_long_path_name('c:\\docume~1')
387 nt.assert_equals(p,u'c:\\Documents and Settings')
387 nt.assert_equals(p,u'c:\\Documents and Settings')
388
388
389
389
390 @dec.skip_win32
390 @dec.skip_win32
391 def test_get_long_path_name():
391 def test_get_long_path_name():
392 p = path.get_long_path_name('/usr/local')
392 p = path.get_long_path_name('/usr/local')
393 nt.assert_equals(p,'/usr/local')
393 nt.assert_equals(p,'/usr/local')
394
394
395 @dec.skip_win32 # can't create not-user-writable dir on win
395 @dec.skip_win32 # can't create not-user-writable dir on win
396 @with_environment
396 @with_environment
397 def test_not_writable_ipdir():
397 def test_not_writable_ipdir():
398 tmpdir = tempfile.mkdtemp()
398 tmpdir = tempfile.mkdtemp()
399 os.name = "posix"
399 os.name = "posix"
400 env.pop('IPYTHON_DIR', None)
400 env.pop('IPYTHON_DIR', None)
401 env.pop('IPYTHONDIR', None)
401 env.pop('IPYTHONDIR', None)
402 env.pop('XDG_CONFIG_HOME', None)
402 env.pop('XDG_CONFIG_HOME', None)
403 env['HOME'] = tmpdir
403 env['HOME'] = tmpdir
404 ipdir = os.path.join(tmpdir, '.ipython')
404 ipdir = os.path.join(tmpdir, '.ipython')
405 os.mkdir(ipdir)
405 os.mkdir(ipdir)
406 os.chmod(ipdir, 600)
406 os.chmod(ipdir, 600)
407 stderr = io.stderr
407 stderr = io.stderr
408 pipe = StringIO.StringIO()
408 pipe = StringIO()
409 io.stderr = pipe
409 io.stderr = pipe
410 ipdir = path.get_ipython_dir()
410 ipdir = path.get_ipython_dir()
411 io.stderr.flush()
411 io.stderr.flush()
412 io.stderr = stderr
412 io.stderr = stderr
413 nt.assert_true('WARNING' in pipe.getvalue())
413 nt.assert_true('WARNING' in pipe.getvalue())
414 env.pop('IPYTHON_DIR', None)
414 env.pop('IPYTHON_DIR', None)
415
415
416 def test_unquote_filename():
416 def test_unquote_filename():
417 for win32 in (True, False):
417 for win32 in (True, False):
418 nt.assert_equals(path.unquote_filename('foo.py', win32=win32), 'foo.py')
418 nt.assert_equals(path.unquote_filename('foo.py', win32=win32), 'foo.py')
419 nt.assert_equals(path.unquote_filename('foo bar.py', win32=win32), 'foo bar.py')
419 nt.assert_equals(path.unquote_filename('foo bar.py', win32=win32), 'foo bar.py')
420 nt.assert_equals(path.unquote_filename('"foo.py"', win32=True), 'foo.py')
420 nt.assert_equals(path.unquote_filename('"foo.py"', win32=True), 'foo.py')
421 nt.assert_equals(path.unquote_filename('"foo bar.py"', win32=True), 'foo bar.py')
421 nt.assert_equals(path.unquote_filename('"foo bar.py"', win32=True), 'foo bar.py')
422 nt.assert_equals(path.unquote_filename("'foo.py'", win32=True), 'foo.py')
422 nt.assert_equals(path.unquote_filename("'foo.py'", win32=True), 'foo.py')
423 nt.assert_equals(path.unquote_filename("'foo bar.py'", win32=True), 'foo bar.py')
423 nt.assert_equals(path.unquote_filename("'foo bar.py'", win32=True), 'foo bar.py')
424 nt.assert_equals(path.unquote_filename('"foo.py"', win32=False), '"foo.py"')
424 nt.assert_equals(path.unquote_filename('"foo.py"', win32=False), '"foo.py"')
425 nt.assert_equals(path.unquote_filename('"foo bar.py"', win32=False), '"foo bar.py"')
425 nt.assert_equals(path.unquote_filename('"foo bar.py"', win32=False), '"foo bar.py"')
426 nt.assert_equals(path.unquote_filename("'foo.py'", win32=False), "'foo.py'")
426 nt.assert_equals(path.unquote_filename("'foo.py'", win32=False), "'foo.py'")
427 nt.assert_equals(path.unquote_filename("'foo bar.py'", win32=False), "'foo bar.py'")
427 nt.assert_equals(path.unquote_filename("'foo bar.py'", win32=False), "'foo bar.py'")
428
428
429 @with_environment
429 @with_environment
430 def test_get_py_filename():
430 def test_get_py_filename():
431 os.chdir(TMP_TEST_DIR)
431 os.chdir(TMP_TEST_DIR)
432 for win32 in (True, False):
432 for win32 in (True, False):
433 with make_tempfile('foo.py'):
433 with make_tempfile('foo.py'):
434 nt.assert_equals(path.get_py_filename('foo.py', force_win32=win32), 'foo.py')
434 nt.assert_equals(path.get_py_filename('foo.py', force_win32=win32), 'foo.py')
435 nt.assert_equals(path.get_py_filename('foo', force_win32=win32), 'foo.py')
435 nt.assert_equals(path.get_py_filename('foo', force_win32=win32), 'foo.py')
436 with make_tempfile('foo'):
436 with make_tempfile('foo'):
437 nt.assert_equals(path.get_py_filename('foo', force_win32=win32), 'foo')
437 nt.assert_equals(path.get_py_filename('foo', force_win32=win32), 'foo')
438 nt.assert_raises(IOError, path.get_py_filename, 'foo.py', force_win32=win32)
438 nt.assert_raises(IOError, path.get_py_filename, 'foo.py', force_win32=win32)
439 nt.assert_raises(IOError, path.get_py_filename, 'foo', force_win32=win32)
439 nt.assert_raises(IOError, path.get_py_filename, 'foo', force_win32=win32)
440 nt.assert_raises(IOError, path.get_py_filename, 'foo.py', force_win32=win32)
440 nt.assert_raises(IOError, path.get_py_filename, 'foo.py', force_win32=win32)
441 true_fn = 'foo with spaces.py'
441 true_fn = 'foo with spaces.py'
442 with make_tempfile(true_fn):
442 with make_tempfile(true_fn):
443 nt.assert_equals(path.get_py_filename('foo with spaces', force_win32=win32), true_fn)
443 nt.assert_equals(path.get_py_filename('foo with spaces', force_win32=win32), true_fn)
444 nt.assert_equals(path.get_py_filename('foo with spaces.py', force_win32=win32), true_fn)
444 nt.assert_equals(path.get_py_filename('foo with spaces.py', force_win32=win32), true_fn)
445 if win32:
445 if win32:
446 nt.assert_equals(path.get_py_filename('"foo with spaces.py"', force_win32=True), true_fn)
446 nt.assert_equals(path.get_py_filename('"foo with spaces.py"', force_win32=True), true_fn)
447 nt.assert_equals(path.get_py_filename("'foo with spaces.py'", force_win32=True), true_fn)
447 nt.assert_equals(path.get_py_filename("'foo with spaces.py'", force_win32=True), true_fn)
448 else:
448 else:
449 nt.assert_raises(IOError, path.get_py_filename, '"foo with spaces.py"', force_win32=False)
449 nt.assert_raises(IOError, path.get_py_filename, '"foo with spaces.py"', force_win32=False)
450 nt.assert_raises(IOError, path.get_py_filename, "'foo with spaces.py'", force_win32=False)
450 nt.assert_raises(IOError, path.get_py_filename, "'foo with spaces.py'", force_win32=False)
@@ -1,87 +1,87 b''
1 # encoding: utf-8
1 # encoding: utf-8
2 """Tests for IPython.utils.text"""
2 """Tests for IPython.utils.text"""
3
3
4 #-----------------------------------------------------------------------------
4 #-----------------------------------------------------------------------------
5 # Copyright (C) 2011 The IPython Development Team
5 # Copyright (C) 2011 The IPython Development Team
6 #
6 #
7 # Distributed under the terms of the BSD License. The full license is in
7 # Distributed under the terms of the BSD License. The full license is in
8 # the file COPYING, distributed as part of this software.
8 # the file COPYING, distributed as part of this software.
9 #-----------------------------------------------------------------------------
9 #-----------------------------------------------------------------------------
10
10
11 #-----------------------------------------------------------------------------
11 #-----------------------------------------------------------------------------
12 # Imports
12 # Imports
13 #-----------------------------------------------------------------------------
13 #-----------------------------------------------------------------------------
14
14
15 import os
15 import os
16 import math
16 import math
17
17
18 import nose.tools as nt
18 import nose.tools as nt
19
19
20 from nose import with_setup
20 from nose import with_setup
21
21
22 from IPython.testing import decorators as dec
22 from IPython.testing import decorators as dec
23 from IPython.utils import text
23 from IPython.utils import text
24
24
25 #-----------------------------------------------------------------------------
25 #-----------------------------------------------------------------------------
26 # Globals
26 # Globals
27 #-----------------------------------------------------------------------------
27 #-----------------------------------------------------------------------------
28
28
29 def test_columnize():
29 def test_columnize():
30 """Basic columnize tests."""
30 """Basic columnize tests."""
31 size = 5
31 size = 5
32 items = [l*size for l in 'abc']
32 items = [l*size for l in 'abc']
33 out = text.columnize(items, displaywidth=80)
33 out = text.columnize(items, displaywidth=80)
34 nt.assert_equals(out, 'aaaaa bbbbb ccccc\n')
34 nt.assert_equals(out, 'aaaaa bbbbb ccccc\n')
35 out = text.columnize(items, displaywidth=10)
35 out = text.columnize(items, displaywidth=10)
36 nt.assert_equals(out, 'aaaaa ccccc\nbbbbb\n')
36 nt.assert_equals(out, 'aaaaa ccccc\nbbbbb\n')
37
37
38
38
39 def test_columnize_long():
39 def test_columnize_long():
40 """Test columnize with inputs longer than the display window"""
40 """Test columnize with inputs longer than the display window"""
41 text.columnize(['a'*81, 'b'*81], displaywidth=80)
41 text.columnize(['a'*81, 'b'*81], displaywidth=80)
42 size = 11
42 size = 11
43 items = [l*size for l in 'abc']
43 items = [l*size for l in 'abc']
44 out = text.columnize(items, displaywidth=size-1)
44 out = text.columnize(items, displaywidth=size-1)
45 nt.assert_equals(out, '\n'.join(items+['']))
45 nt.assert_equals(out, '\n'.join(items+['']))
46
46
47 def test_eval_formatter():
47 def test_eval_formatter():
48 f = text.EvalFormatter()
48 f = text.EvalFormatter()
49 ns = dict(n=12, pi=math.pi, stuff='hello there', os=os)
49 ns = dict(n=12, pi=math.pi, stuff='hello there', os=os)
50 s = f.format("{n} {n/4} {stuff.split()[0]}", **ns)
50 s = f.format("{n} {n//4} {stuff.split()[0]}", **ns)
51 nt.assert_equals(s, "12 3 hello")
51 nt.assert_equals(s, "12 3 hello")
52 s = f.format(' '.join(['{n//%i}'%i for i in range(1,8)]), **ns)
52 s = f.format(' '.join(['{n//%i}'%i for i in range(1,8)]), **ns)
53 nt.assert_equals(s, "12 6 4 3 2 2 1")
53 nt.assert_equals(s, "12 6 4 3 2 2 1")
54 s = f.format('{[n//i for i in range(1,8)]}', **ns)
54 s = f.format('{[n//i for i in range(1,8)]}', **ns)
55 nt.assert_equals(s, "[12, 6, 4, 3, 2, 2, 1]")
55 nt.assert_equals(s, "[12, 6, 4, 3, 2, 2, 1]")
56 s = f.format("{stuff!s}", **ns)
56 s = f.format("{stuff!s}", **ns)
57 nt.assert_equals(s, ns['stuff'])
57 nt.assert_equals(s, ns['stuff'])
58 s = f.format("{stuff!r}", **ns)
58 s = f.format("{stuff!r}", **ns)
59 nt.assert_equals(s, repr(ns['stuff']))
59 nt.assert_equals(s, repr(ns['stuff']))
60
60
61 nt.assert_raises(NameError, f.format, '{dne}', **ns)
61 nt.assert_raises(NameError, f.format, '{dne}', **ns)
62
62
63
63
64 def test_eval_formatter_slicing():
64 def test_eval_formatter_slicing():
65 f = text.EvalFormatter()
65 f = text.EvalFormatter()
66 f.allow_slicing = True
66 f.allow_slicing = True
67 ns = dict(n=12, pi=math.pi, stuff='hello there', os=os)
67 ns = dict(n=12, pi=math.pi, stuff='hello there', os=os)
68 s = f.format(" {stuff.split()[:]} ", **ns)
68 s = f.format(" {stuff.split()[:]} ", **ns)
69 nt.assert_equals(s, " ['hello', 'there'] ")
69 nt.assert_equals(s, " ['hello', 'there'] ")
70 s = f.format(" {stuff.split()[::-1]} ", **ns)
70 s = f.format(" {stuff.split()[::-1]} ", **ns)
71 nt.assert_equals(s, " ['there', 'hello'] ")
71 nt.assert_equals(s, " ['there', 'hello'] ")
72 s = f.format("{stuff[::2]}", **ns)
72 s = f.format("{stuff[::2]}", **ns)
73 nt.assert_equals(s, ns['stuff'][::2])
73 nt.assert_equals(s, ns['stuff'][::2])
74
74
75 nt.assert_raises(SyntaxError, f.format, "{n:x}", **ns)
75 nt.assert_raises(SyntaxError, f.format, "{n:x}", **ns)
76
76
77
77
78 def test_eval_formatter_no_slicing():
78 def test_eval_formatter_no_slicing():
79 f = text.EvalFormatter()
79 f = text.EvalFormatter()
80 f.allow_slicing = False
80 f.allow_slicing = False
81 ns = dict(n=12, pi=math.pi, stuff='hello there', os=os)
81 ns = dict(n=12, pi=math.pi, stuff='hello there', os=os)
82
82
83 s = f.format('{n:x} {pi**2:+f}', **ns)
83 s = f.format('{n:x} {pi**2:+f}', **ns)
84 nt.assert_equals(s, "c +9.869604")
84 nt.assert_equals(s, "c +9.869604")
85
85
86 nt.assert_raises(SyntaxError, f.format, "{a[:]}")
86 nt.assert_raises(SyntaxError, f.format, "{a[:]}")
87
87
@@ -1,846 +1,855 b''
1 # encoding: utf-8
1 # encoding: utf-8
2 """
2 """
3 Tests for IPython.utils.traitlets.
3 Tests for IPython.utils.traitlets.
4
4
5 Authors:
5 Authors:
6
6
7 * Brian Granger
7 * Brian Granger
8 * Enthought, Inc. Some of the code in this file comes from enthought.traits
8 * Enthought, Inc. Some of the code in this file comes from enthought.traits
9 and is licensed under the BSD license. Also, many of the ideas also come
9 and is licensed under the BSD license. Also, many of the ideas also come
10 from enthought.traits even though our implementation is very different.
10 from enthought.traits even though our implementation is very different.
11 """
11 """
12
12
13 #-----------------------------------------------------------------------------
13 #-----------------------------------------------------------------------------
14 # Copyright (C) 2008-2009 The IPython Development Team
14 # Copyright (C) 2008-2009 The IPython Development Team
15 #
15 #
16 # Distributed under the terms of the BSD License. The full license is in
16 # Distributed under the terms of the BSD License. The full license is in
17 # the file COPYING, distributed as part of this software.
17 # the file COPYING, distributed as part of this software.
18 #-----------------------------------------------------------------------------
18 #-----------------------------------------------------------------------------
19
19
20 #-----------------------------------------------------------------------------
20 #-----------------------------------------------------------------------------
21 # Imports
21 # Imports
22 #-----------------------------------------------------------------------------
22 #-----------------------------------------------------------------------------
23
23
24 import sys
24 import sys
25 from unittest import TestCase
25 from unittest import TestCase
26
26
27 from IPython.utils.traitlets import (
27 from IPython.utils.traitlets import (
28 HasTraits, MetaHasTraits, TraitType, Any, CBytes,
28 HasTraits, MetaHasTraits, TraitType, Any, CBytes,
29 Int, Long, Float, Complex, Bytes, Unicode, TraitError,
29 Int, Long, Float, Complex, Bytes, Unicode, TraitError,
30 Undefined, Type, This, Instance, TCPAddress, List, Tuple,
30 Undefined, Type, This, Instance, TCPAddress, List, Tuple,
31 ObjectName, DottedObjectName
31 ObjectName, DottedObjectName
32 )
32 )
33
33 from IPython.utils import py3compat
34
34
35 #-----------------------------------------------------------------------------
35 #-----------------------------------------------------------------------------
36 # Helper classes for testing
36 # Helper classes for testing
37 #-----------------------------------------------------------------------------
37 #-----------------------------------------------------------------------------
38
38
39
39
40 class HasTraitsStub(HasTraits):
40 class HasTraitsStub(HasTraits):
41
41
42 def _notify_trait(self, name, old, new):
42 def _notify_trait(self, name, old, new):
43 self._notify_name = name
43 self._notify_name = name
44 self._notify_old = old
44 self._notify_old = old
45 self._notify_new = new
45 self._notify_new = new
46
46
47
47
48 #-----------------------------------------------------------------------------
48 #-----------------------------------------------------------------------------
49 # Test classes
49 # Test classes
50 #-----------------------------------------------------------------------------
50 #-----------------------------------------------------------------------------
51
51
52
52
53 class TestTraitType(TestCase):
53 class TestTraitType(TestCase):
54
54
55 def test_get_undefined(self):
55 def test_get_undefined(self):
56 class A(HasTraits):
56 class A(HasTraits):
57 a = TraitType
57 a = TraitType
58 a = A()
58 a = A()
59 self.assertEquals(a.a, Undefined)
59 self.assertEquals(a.a, Undefined)
60
60
61 def test_set(self):
61 def test_set(self):
62 class A(HasTraitsStub):
62 class A(HasTraitsStub):
63 a = TraitType
63 a = TraitType
64
64
65 a = A()
65 a = A()
66 a.a = 10
66 a.a = 10
67 self.assertEquals(a.a, 10)
67 self.assertEquals(a.a, 10)
68 self.assertEquals(a._notify_name, 'a')
68 self.assertEquals(a._notify_name, 'a')
69 self.assertEquals(a._notify_old, Undefined)
69 self.assertEquals(a._notify_old, Undefined)
70 self.assertEquals(a._notify_new, 10)
70 self.assertEquals(a._notify_new, 10)
71
71
72 def test_validate(self):
72 def test_validate(self):
73 class MyTT(TraitType):
73 class MyTT(TraitType):
74 def validate(self, inst, value):
74 def validate(self, inst, value):
75 return -1
75 return -1
76 class A(HasTraitsStub):
76 class A(HasTraitsStub):
77 tt = MyTT
77 tt = MyTT
78
78
79 a = A()
79 a = A()
80 a.tt = 10
80 a.tt = 10
81 self.assertEquals(a.tt, -1)
81 self.assertEquals(a.tt, -1)
82
82
83 def test_default_validate(self):
83 def test_default_validate(self):
84 class MyIntTT(TraitType):
84 class MyIntTT(TraitType):
85 def validate(self, obj, value):
85 def validate(self, obj, value):
86 if isinstance(value, int):
86 if isinstance(value, int):
87 return value
87 return value
88 self.error(obj, value)
88 self.error(obj, value)
89 class A(HasTraits):
89 class A(HasTraits):
90 tt = MyIntTT(10)
90 tt = MyIntTT(10)
91 a = A()
91 a = A()
92 self.assertEquals(a.tt, 10)
92 self.assertEquals(a.tt, 10)
93
93
94 # Defaults are validated when the HasTraits is instantiated
94 # Defaults are validated when the HasTraits is instantiated
95 class B(HasTraits):
95 class B(HasTraits):
96 tt = MyIntTT('bad default')
96 tt = MyIntTT('bad default')
97 self.assertRaises(TraitError, B)
97 self.assertRaises(TraitError, B)
98
98
99 def test_is_valid_for(self):
99 def test_is_valid_for(self):
100 class MyTT(TraitType):
100 class MyTT(TraitType):
101 def is_valid_for(self, value):
101 def is_valid_for(self, value):
102 return True
102 return True
103 class A(HasTraits):
103 class A(HasTraits):
104 tt = MyTT
104 tt = MyTT
105
105
106 a = A()
106 a = A()
107 a.tt = 10
107 a.tt = 10
108 self.assertEquals(a.tt, 10)
108 self.assertEquals(a.tt, 10)
109
109
110 def test_value_for(self):
110 def test_value_for(self):
111 class MyTT(TraitType):
111 class MyTT(TraitType):
112 def value_for(self, value):
112 def value_for(self, value):
113 return 20
113 return 20
114 class A(HasTraits):
114 class A(HasTraits):
115 tt = MyTT
115 tt = MyTT
116
116
117 a = A()
117 a = A()
118 a.tt = 10
118 a.tt = 10
119 self.assertEquals(a.tt, 20)
119 self.assertEquals(a.tt, 20)
120
120
121 def test_info(self):
121 def test_info(self):
122 class A(HasTraits):
122 class A(HasTraits):
123 tt = TraitType
123 tt = TraitType
124 a = A()
124 a = A()
125 self.assertEquals(A.tt.info(), 'any value')
125 self.assertEquals(A.tt.info(), 'any value')
126
126
127 def test_error(self):
127 def test_error(self):
128 class A(HasTraits):
128 class A(HasTraits):
129 tt = TraitType
129 tt = TraitType
130 a = A()
130 a = A()
131 self.assertRaises(TraitError, A.tt.error, a, 10)
131 self.assertRaises(TraitError, A.tt.error, a, 10)
132
132
133 def test_dynamic_initializer(self):
133 def test_dynamic_initializer(self):
134 class A(HasTraits):
134 class A(HasTraits):
135 x = Int(10)
135 x = Int(10)
136 def _x_default(self):
136 def _x_default(self):
137 return 11
137 return 11
138 class B(A):
138 class B(A):
139 x = Int(20)
139 x = Int(20)
140 class C(A):
140 class C(A):
141 def _x_default(self):
141 def _x_default(self):
142 return 21
142 return 21
143
143
144 a = A()
144 a = A()
145 self.assertEquals(a._trait_values, {})
145 self.assertEquals(a._trait_values, {})
146 self.assertEquals(a._trait_dyn_inits.keys(), ['x'])
146 self.assertEquals(a._trait_dyn_inits.keys(), ['x'])
147 self.assertEquals(a.x, 11)
147 self.assertEquals(a.x, 11)
148 self.assertEquals(a._trait_values, {'x': 11})
148 self.assertEquals(a._trait_values, {'x': 11})
149 b = B()
149 b = B()
150 self.assertEquals(b._trait_values, {'x': 20})
150 self.assertEquals(b._trait_values, {'x': 20})
151 self.assertEquals(a._trait_dyn_inits.keys(), ['x'])
151 self.assertEquals(a._trait_dyn_inits.keys(), ['x'])
152 self.assertEquals(b.x, 20)
152 self.assertEquals(b.x, 20)
153 c = C()
153 c = C()
154 self.assertEquals(c._trait_values, {})
154 self.assertEquals(c._trait_values, {})
155 self.assertEquals(a._trait_dyn_inits.keys(), ['x'])
155 self.assertEquals(a._trait_dyn_inits.keys(), ['x'])
156 self.assertEquals(c.x, 21)
156 self.assertEquals(c.x, 21)
157 self.assertEquals(c._trait_values, {'x': 21})
157 self.assertEquals(c._trait_values, {'x': 21})
158 # Ensure that the base class remains unmolested when the _default
158 # Ensure that the base class remains unmolested when the _default
159 # initializer gets overridden in a subclass.
159 # initializer gets overridden in a subclass.
160 a = A()
160 a = A()
161 c = C()
161 c = C()
162 self.assertEquals(a._trait_values, {})
162 self.assertEquals(a._trait_values, {})
163 self.assertEquals(a._trait_dyn_inits.keys(), ['x'])
163 self.assertEquals(a._trait_dyn_inits.keys(), ['x'])
164 self.assertEquals(a.x, 11)
164 self.assertEquals(a.x, 11)
165 self.assertEquals(a._trait_values, {'x': 11})
165 self.assertEquals(a._trait_values, {'x': 11})
166
166
167
167
168
168
169 class TestHasTraitsMeta(TestCase):
169 class TestHasTraitsMeta(TestCase):
170
170
171 def test_metaclass(self):
171 def test_metaclass(self):
172 self.assertEquals(type(HasTraits), MetaHasTraits)
172 self.assertEquals(type(HasTraits), MetaHasTraits)
173
173
174 class A(HasTraits):
174 class A(HasTraits):
175 a = Int
175 a = Int
176
176
177 a = A()
177 a = A()
178 self.assertEquals(type(a.__class__), MetaHasTraits)
178 self.assertEquals(type(a.__class__), MetaHasTraits)
179 self.assertEquals(a.a,0)
179 self.assertEquals(a.a,0)
180 a.a = 10
180 a.a = 10
181 self.assertEquals(a.a,10)
181 self.assertEquals(a.a,10)
182
182
183 class B(HasTraits):
183 class B(HasTraits):
184 b = Int()
184 b = Int()
185
185
186 b = B()
186 b = B()
187 self.assertEquals(b.b,0)
187 self.assertEquals(b.b,0)
188 b.b = 10
188 b.b = 10
189 self.assertEquals(b.b,10)
189 self.assertEquals(b.b,10)
190
190
191 class C(HasTraits):
191 class C(HasTraits):
192 c = Int(30)
192 c = Int(30)
193
193
194 c = C()
194 c = C()
195 self.assertEquals(c.c,30)
195 self.assertEquals(c.c,30)
196 c.c = 10
196 c.c = 10
197 self.assertEquals(c.c,10)
197 self.assertEquals(c.c,10)
198
198
199 def test_this_class(self):
199 def test_this_class(self):
200 class A(HasTraits):
200 class A(HasTraits):
201 t = This()
201 t = This()
202 tt = This()
202 tt = This()
203 class B(A):
203 class B(A):
204 tt = This()
204 tt = This()
205 ttt = This()
205 ttt = This()
206 self.assertEquals(A.t.this_class, A)
206 self.assertEquals(A.t.this_class, A)
207 self.assertEquals(B.t.this_class, A)
207 self.assertEquals(B.t.this_class, A)
208 self.assertEquals(B.tt.this_class, B)
208 self.assertEquals(B.tt.this_class, B)
209 self.assertEquals(B.ttt.this_class, B)
209 self.assertEquals(B.ttt.this_class, B)
210
210
211 class TestHasTraitsNotify(TestCase):
211 class TestHasTraitsNotify(TestCase):
212
212
213 def setUp(self):
213 def setUp(self):
214 self._notify1 = []
214 self._notify1 = []
215 self._notify2 = []
215 self._notify2 = []
216
216
217 def notify1(self, name, old, new):
217 def notify1(self, name, old, new):
218 self._notify1.append((name, old, new))
218 self._notify1.append((name, old, new))
219
219
220 def notify2(self, name, old, new):
220 def notify2(self, name, old, new):
221 self._notify2.append((name, old, new))
221 self._notify2.append((name, old, new))
222
222
223 def test_notify_all(self):
223 def test_notify_all(self):
224
224
225 class A(HasTraits):
225 class A(HasTraits):
226 a = Int
226 a = Int
227 b = Float
227 b = Float
228
228
229 a = A()
229 a = A()
230 a.on_trait_change(self.notify1)
230 a.on_trait_change(self.notify1)
231 a.a = 0
231 a.a = 0
232 self.assertEquals(len(self._notify1),0)
232 self.assertEquals(len(self._notify1),0)
233 a.b = 0.0
233 a.b = 0.0
234 self.assertEquals(len(self._notify1),0)
234 self.assertEquals(len(self._notify1),0)
235 a.a = 10
235 a.a = 10
236 self.assert_(('a',0,10) in self._notify1)
236 self.assert_(('a',0,10) in self._notify1)
237 a.b = 10.0
237 a.b = 10.0
238 self.assert_(('b',0.0,10.0) in self._notify1)
238 self.assert_(('b',0.0,10.0) in self._notify1)
239 self.assertRaises(TraitError,setattr,a,'a','bad string')
239 self.assertRaises(TraitError,setattr,a,'a','bad string')
240 self.assertRaises(TraitError,setattr,a,'b','bad string')
240 self.assertRaises(TraitError,setattr,a,'b','bad string')
241 self._notify1 = []
241 self._notify1 = []
242 a.on_trait_change(self.notify1,remove=True)
242 a.on_trait_change(self.notify1,remove=True)
243 a.a = 20
243 a.a = 20
244 a.b = 20.0
244 a.b = 20.0
245 self.assertEquals(len(self._notify1),0)
245 self.assertEquals(len(self._notify1),0)
246
246
247 def test_notify_one(self):
247 def test_notify_one(self):
248
248
249 class A(HasTraits):
249 class A(HasTraits):
250 a = Int
250 a = Int
251 b = Float
251 b = Float
252
252
253 a = A()
253 a = A()
254 a.on_trait_change(self.notify1, 'a')
254 a.on_trait_change(self.notify1, 'a')
255 a.a = 0
255 a.a = 0
256 self.assertEquals(len(self._notify1),0)
256 self.assertEquals(len(self._notify1),0)
257 a.a = 10
257 a.a = 10
258 self.assert_(('a',0,10) in self._notify1)
258 self.assert_(('a',0,10) in self._notify1)
259 self.assertRaises(TraitError,setattr,a,'a','bad string')
259 self.assertRaises(TraitError,setattr,a,'a','bad string')
260
260
261 def test_subclass(self):
261 def test_subclass(self):
262
262
263 class A(HasTraits):
263 class A(HasTraits):
264 a = Int
264 a = Int
265
265
266 class B(A):
266 class B(A):
267 b = Float
267 b = Float
268
268
269 b = B()
269 b = B()
270 self.assertEquals(b.a,0)
270 self.assertEquals(b.a,0)
271 self.assertEquals(b.b,0.0)
271 self.assertEquals(b.b,0.0)
272 b.a = 100
272 b.a = 100
273 b.b = 100.0
273 b.b = 100.0
274 self.assertEquals(b.a,100)
274 self.assertEquals(b.a,100)
275 self.assertEquals(b.b,100.0)
275 self.assertEquals(b.b,100.0)
276
276
277 def test_notify_subclass(self):
277 def test_notify_subclass(self):
278
278
279 class A(HasTraits):
279 class A(HasTraits):
280 a = Int
280 a = Int
281
281
282 class B(A):
282 class B(A):
283 b = Float
283 b = Float
284
284
285 b = B()
285 b = B()
286 b.on_trait_change(self.notify1, 'a')
286 b.on_trait_change(self.notify1, 'a')
287 b.on_trait_change(self.notify2, 'b')
287 b.on_trait_change(self.notify2, 'b')
288 b.a = 0
288 b.a = 0
289 b.b = 0.0
289 b.b = 0.0
290 self.assertEquals(len(self._notify1),0)
290 self.assertEquals(len(self._notify1),0)
291 self.assertEquals(len(self._notify2),0)
291 self.assertEquals(len(self._notify2),0)
292 b.a = 10
292 b.a = 10
293 b.b = 10.0
293 b.b = 10.0
294 self.assert_(('a',0,10) in self._notify1)
294 self.assert_(('a',0,10) in self._notify1)
295 self.assert_(('b',0.0,10.0) in self._notify2)
295 self.assert_(('b',0.0,10.0) in self._notify2)
296
296
297 def test_static_notify(self):
297 def test_static_notify(self):
298
298
299 class A(HasTraits):
299 class A(HasTraits):
300 a = Int
300 a = Int
301 _notify1 = []
301 _notify1 = []
302 def _a_changed(self, name, old, new):
302 def _a_changed(self, name, old, new):
303 self._notify1.append((name, old, new))
303 self._notify1.append((name, old, new))
304
304
305 a = A()
305 a = A()
306 a.a = 0
306 a.a = 0
307 # This is broken!!!
307 # This is broken!!!
308 self.assertEquals(len(a._notify1),0)
308 self.assertEquals(len(a._notify1),0)
309 a.a = 10
309 a.a = 10
310 self.assert_(('a',0,10) in a._notify1)
310 self.assert_(('a',0,10) in a._notify1)
311
311
312 class B(A):
312 class B(A):
313 b = Float
313 b = Float
314 _notify2 = []
314 _notify2 = []
315 def _b_changed(self, name, old, new):
315 def _b_changed(self, name, old, new):
316 self._notify2.append((name, old, new))
316 self._notify2.append((name, old, new))
317
317
318 b = B()
318 b = B()
319 b.a = 10
319 b.a = 10
320 b.b = 10.0
320 b.b = 10.0
321 self.assert_(('a',0,10) in b._notify1)
321 self.assert_(('a',0,10) in b._notify1)
322 self.assert_(('b',0.0,10.0) in b._notify2)
322 self.assert_(('b',0.0,10.0) in b._notify2)
323
323
324 def test_notify_args(self):
324 def test_notify_args(self):
325
325
326 def callback0():
326 def callback0():
327 self.cb = ()
327 self.cb = ()
328 def callback1(name):
328 def callback1(name):
329 self.cb = (name,)
329 self.cb = (name,)
330 def callback2(name, new):
330 def callback2(name, new):
331 self.cb = (name, new)
331 self.cb = (name, new)
332 def callback3(name, old, new):
332 def callback3(name, old, new):
333 self.cb = (name, old, new)
333 self.cb = (name, old, new)
334
334
335 class A(HasTraits):
335 class A(HasTraits):
336 a = Int
336 a = Int
337
337
338 a = A()
338 a = A()
339 a.on_trait_change(callback0, 'a')
339 a.on_trait_change(callback0, 'a')
340 a.a = 10
340 a.a = 10
341 self.assertEquals(self.cb,())
341 self.assertEquals(self.cb,())
342 a.on_trait_change(callback0, 'a', remove=True)
342 a.on_trait_change(callback0, 'a', remove=True)
343
343
344 a.on_trait_change(callback1, 'a')
344 a.on_trait_change(callback1, 'a')
345 a.a = 100
345 a.a = 100
346 self.assertEquals(self.cb,('a',))
346 self.assertEquals(self.cb,('a',))
347 a.on_trait_change(callback1, 'a', remove=True)
347 a.on_trait_change(callback1, 'a', remove=True)
348
348
349 a.on_trait_change(callback2, 'a')
349 a.on_trait_change(callback2, 'a')
350 a.a = 1000
350 a.a = 1000
351 self.assertEquals(self.cb,('a',1000))
351 self.assertEquals(self.cb,('a',1000))
352 a.on_trait_change(callback2, 'a', remove=True)
352 a.on_trait_change(callback2, 'a', remove=True)
353
353
354 a.on_trait_change(callback3, 'a')
354 a.on_trait_change(callback3, 'a')
355 a.a = 10000
355 a.a = 10000
356 self.assertEquals(self.cb,('a',1000,10000))
356 self.assertEquals(self.cb,('a',1000,10000))
357 a.on_trait_change(callback3, 'a', remove=True)
357 a.on_trait_change(callback3, 'a', remove=True)
358
358
359 self.assertEquals(len(a._trait_notifiers['a']),0)
359 self.assertEquals(len(a._trait_notifiers['a']),0)
360
360
361
361
362 class TestHasTraits(TestCase):
362 class TestHasTraits(TestCase):
363
363
364 def test_trait_names(self):
364 def test_trait_names(self):
365 class A(HasTraits):
365 class A(HasTraits):
366 i = Int
366 i = Int
367 f = Float
367 f = Float
368 a = A()
368 a = A()
369 self.assertEquals(a.trait_names(),['i','f'])
369 self.assertEquals(a.trait_names(),['i','f'])
370 self.assertEquals(A.class_trait_names(),['i','f'])
370 self.assertEquals(A.class_trait_names(),['i','f'])
371
371
372 def test_trait_metadata(self):
372 def test_trait_metadata(self):
373 class A(HasTraits):
373 class A(HasTraits):
374 i = Int(config_key='MY_VALUE')
374 i = Int(config_key='MY_VALUE')
375 a = A()
375 a = A()
376 self.assertEquals(a.trait_metadata('i','config_key'), 'MY_VALUE')
376 self.assertEquals(a.trait_metadata('i','config_key'), 'MY_VALUE')
377
377
378 def test_traits(self):
378 def test_traits(self):
379 class A(HasTraits):
379 class A(HasTraits):
380 i = Int
380 i = Int
381 f = Float
381 f = Float
382 a = A()
382 a = A()
383 self.assertEquals(a.traits(), dict(i=A.i, f=A.f))
383 self.assertEquals(a.traits(), dict(i=A.i, f=A.f))
384 self.assertEquals(A.class_traits(), dict(i=A.i, f=A.f))
384 self.assertEquals(A.class_traits(), dict(i=A.i, f=A.f))
385
385
386 def test_traits_metadata(self):
386 def test_traits_metadata(self):
387 class A(HasTraits):
387 class A(HasTraits):
388 i = Int(config_key='VALUE1', other_thing='VALUE2')
388 i = Int(config_key='VALUE1', other_thing='VALUE2')
389 f = Float(config_key='VALUE3', other_thing='VALUE2')
389 f = Float(config_key='VALUE3', other_thing='VALUE2')
390 j = Int(0)
390 j = Int(0)
391 a = A()
391 a = A()
392 self.assertEquals(a.traits(), dict(i=A.i, f=A.f, j=A.j))
392 self.assertEquals(a.traits(), dict(i=A.i, f=A.f, j=A.j))
393 traits = a.traits(config_key='VALUE1', other_thing='VALUE2')
393 traits = a.traits(config_key='VALUE1', other_thing='VALUE2')
394 self.assertEquals(traits, dict(i=A.i))
394 self.assertEquals(traits, dict(i=A.i))
395
395
396 # This passes, but it shouldn't because I am replicating a bug in
396 # This passes, but it shouldn't because I am replicating a bug in
397 # traits.
397 # traits.
398 traits = a.traits(config_key=lambda v: True)
398 traits = a.traits(config_key=lambda v: True)
399 self.assertEquals(traits, dict(i=A.i, f=A.f, j=A.j))
399 self.assertEquals(traits, dict(i=A.i, f=A.f, j=A.j))
400
400
401 def test_init(self):
401 def test_init(self):
402 class A(HasTraits):
402 class A(HasTraits):
403 i = Int()
403 i = Int()
404 x = Float()
404 x = Float()
405 a = A(i=1, x=10.0)
405 a = A(i=1, x=10.0)
406 self.assertEquals(a.i, 1)
406 self.assertEquals(a.i, 1)
407 self.assertEquals(a.x, 10.0)
407 self.assertEquals(a.x, 10.0)
408
408
409 #-----------------------------------------------------------------------------
409 #-----------------------------------------------------------------------------
410 # Tests for specific trait types
410 # Tests for specific trait types
411 #-----------------------------------------------------------------------------
411 #-----------------------------------------------------------------------------
412
412
413
413
414 class TestType(TestCase):
414 class TestType(TestCase):
415
415
416 def test_default(self):
416 def test_default(self):
417
417
418 class B(object): pass
418 class B(object): pass
419 class A(HasTraits):
419 class A(HasTraits):
420 klass = Type
420 klass = Type
421
421
422 a = A()
422 a = A()
423 self.assertEquals(a.klass, None)
423 self.assertEquals(a.klass, None)
424
424
425 a.klass = B
425 a.klass = B
426 self.assertEquals(a.klass, B)
426 self.assertEquals(a.klass, B)
427 self.assertRaises(TraitError, setattr, a, 'klass', 10)
427 self.assertRaises(TraitError, setattr, a, 'klass', 10)
428
428
429 def test_value(self):
429 def test_value(self):
430
430
431 class B(object): pass
431 class B(object): pass
432 class C(object): pass
432 class C(object): pass
433 class A(HasTraits):
433 class A(HasTraits):
434 klass = Type(B)
434 klass = Type(B)
435
435
436 a = A()
436 a = A()
437 self.assertEquals(a.klass, B)
437 self.assertEquals(a.klass, B)
438 self.assertRaises(TraitError, setattr, a, 'klass', C)
438 self.assertRaises(TraitError, setattr, a, 'klass', C)
439 self.assertRaises(TraitError, setattr, a, 'klass', object)
439 self.assertRaises(TraitError, setattr, a, 'klass', object)
440 a.klass = B
440 a.klass = B
441
441
442 def test_allow_none(self):
442 def test_allow_none(self):
443
443
444 class B(object): pass
444 class B(object): pass
445 class C(B): pass
445 class C(B): pass
446 class A(HasTraits):
446 class A(HasTraits):
447 klass = Type(B, allow_none=False)
447 klass = Type(B, allow_none=False)
448
448
449 a = A()
449 a = A()
450 self.assertEquals(a.klass, B)
450 self.assertEquals(a.klass, B)
451 self.assertRaises(TraitError, setattr, a, 'klass', None)
451 self.assertRaises(TraitError, setattr, a, 'klass', None)
452 a.klass = C
452 a.klass = C
453 self.assertEquals(a.klass, C)
453 self.assertEquals(a.klass, C)
454
454
455 def test_validate_klass(self):
455 def test_validate_klass(self):
456
456
457 class A(HasTraits):
457 class A(HasTraits):
458 klass = Type('no strings allowed')
458 klass = Type('no strings allowed')
459
459
460 self.assertRaises(ImportError, A)
460 self.assertRaises(ImportError, A)
461
461
462 class A(HasTraits):
462 class A(HasTraits):
463 klass = Type('rub.adub.Duck')
463 klass = Type('rub.adub.Duck')
464
464
465 self.assertRaises(ImportError, A)
465 self.assertRaises(ImportError, A)
466
466
467 def test_validate_default(self):
467 def test_validate_default(self):
468
468
469 class B(object): pass
469 class B(object): pass
470 class A(HasTraits):
470 class A(HasTraits):
471 klass = Type('bad default', B)
471 klass = Type('bad default', B)
472
472
473 self.assertRaises(ImportError, A)
473 self.assertRaises(ImportError, A)
474
474
475 class C(HasTraits):
475 class C(HasTraits):
476 klass = Type(None, B, allow_none=False)
476 klass = Type(None, B, allow_none=False)
477
477
478 self.assertRaises(TraitError, C)
478 self.assertRaises(TraitError, C)
479
479
480 def test_str_klass(self):
480 def test_str_klass(self):
481
481
482 class A(HasTraits):
482 class A(HasTraits):
483 klass = Type('IPython.utils.ipstruct.Struct')
483 klass = Type('IPython.utils.ipstruct.Struct')
484
484
485 from IPython.utils.ipstruct import Struct
485 from IPython.utils.ipstruct import Struct
486 a = A()
486 a = A()
487 a.klass = Struct
487 a.klass = Struct
488 self.assertEquals(a.klass, Struct)
488 self.assertEquals(a.klass, Struct)
489
489
490 self.assertRaises(TraitError, setattr, a, 'klass', 10)
490 self.assertRaises(TraitError, setattr, a, 'klass', 10)
491
491
492 class TestInstance(TestCase):
492 class TestInstance(TestCase):
493
493
494 def test_basic(self):
494 def test_basic(self):
495 class Foo(object): pass
495 class Foo(object): pass
496 class Bar(Foo): pass
496 class Bar(Foo): pass
497 class Bah(object): pass
497 class Bah(object): pass
498
498
499 class A(HasTraits):
499 class A(HasTraits):
500 inst = Instance(Foo)
500 inst = Instance(Foo)
501
501
502 a = A()
502 a = A()
503 self.assert_(a.inst is None)
503 self.assert_(a.inst is None)
504 a.inst = Foo()
504 a.inst = Foo()
505 self.assert_(isinstance(a.inst, Foo))
505 self.assert_(isinstance(a.inst, Foo))
506 a.inst = Bar()
506 a.inst = Bar()
507 self.assert_(isinstance(a.inst, Foo))
507 self.assert_(isinstance(a.inst, Foo))
508 self.assertRaises(TraitError, setattr, a, 'inst', Foo)
508 self.assertRaises(TraitError, setattr, a, 'inst', Foo)
509 self.assertRaises(TraitError, setattr, a, 'inst', Bar)
509 self.assertRaises(TraitError, setattr, a, 'inst', Bar)
510 self.assertRaises(TraitError, setattr, a, 'inst', Bah())
510 self.assertRaises(TraitError, setattr, a, 'inst', Bah())
511
511
512 def test_unique_default_value(self):
512 def test_unique_default_value(self):
513 class Foo(object): pass
513 class Foo(object): pass
514 class A(HasTraits):
514 class A(HasTraits):
515 inst = Instance(Foo,(),{})
515 inst = Instance(Foo,(),{})
516
516
517 a = A()
517 a = A()
518 b = A()
518 b = A()
519 self.assert_(a.inst is not b.inst)
519 self.assert_(a.inst is not b.inst)
520
520
521 def test_args_kw(self):
521 def test_args_kw(self):
522 class Foo(object):
522 class Foo(object):
523 def __init__(self, c): self.c = c
523 def __init__(self, c): self.c = c
524 class Bar(object): pass
524 class Bar(object): pass
525 class Bah(object):
525 class Bah(object):
526 def __init__(self, c, d):
526 def __init__(self, c, d):
527 self.c = c; self.d = d
527 self.c = c; self.d = d
528
528
529 class A(HasTraits):
529 class A(HasTraits):
530 inst = Instance(Foo, (10,))
530 inst = Instance(Foo, (10,))
531 a = A()
531 a = A()
532 self.assertEquals(a.inst.c, 10)
532 self.assertEquals(a.inst.c, 10)
533
533
534 class B(HasTraits):
534 class B(HasTraits):
535 inst = Instance(Bah, args=(10,), kw=dict(d=20))
535 inst = Instance(Bah, args=(10,), kw=dict(d=20))
536 b = B()
536 b = B()
537 self.assertEquals(b.inst.c, 10)
537 self.assertEquals(b.inst.c, 10)
538 self.assertEquals(b.inst.d, 20)
538 self.assertEquals(b.inst.d, 20)
539
539
540 class C(HasTraits):
540 class C(HasTraits):
541 inst = Instance(Foo)
541 inst = Instance(Foo)
542 c = C()
542 c = C()
543 self.assert_(c.inst is None)
543 self.assert_(c.inst is None)
544
544
545 def test_bad_default(self):
545 def test_bad_default(self):
546 class Foo(object): pass
546 class Foo(object): pass
547
547
548 class A(HasTraits):
548 class A(HasTraits):
549 inst = Instance(Foo, allow_none=False)
549 inst = Instance(Foo, allow_none=False)
550
550
551 self.assertRaises(TraitError, A)
551 self.assertRaises(TraitError, A)
552
552
553 def test_instance(self):
553 def test_instance(self):
554 class Foo(object): pass
554 class Foo(object): pass
555
555
556 def inner():
556 def inner():
557 class A(HasTraits):
557 class A(HasTraits):
558 inst = Instance(Foo())
558 inst = Instance(Foo())
559
559
560 self.assertRaises(TraitError, inner)
560 self.assertRaises(TraitError, inner)
561
561
562
562
563 class TestThis(TestCase):
563 class TestThis(TestCase):
564
564
565 def test_this_class(self):
565 def test_this_class(self):
566 class Foo(HasTraits):
566 class Foo(HasTraits):
567 this = This
567 this = This
568
568
569 f = Foo()
569 f = Foo()
570 self.assertEquals(f.this, None)
570 self.assertEquals(f.this, None)
571 g = Foo()
571 g = Foo()
572 f.this = g
572 f.this = g
573 self.assertEquals(f.this, g)
573 self.assertEquals(f.this, g)
574 self.assertRaises(TraitError, setattr, f, 'this', 10)
574 self.assertRaises(TraitError, setattr, f, 'this', 10)
575
575
576 def test_this_inst(self):
576 def test_this_inst(self):
577 class Foo(HasTraits):
577 class Foo(HasTraits):
578 this = This()
578 this = This()
579
579
580 f = Foo()
580 f = Foo()
581 f.this = Foo()
581 f.this = Foo()
582 self.assert_(isinstance(f.this, Foo))
582 self.assert_(isinstance(f.this, Foo))
583
583
584 def test_subclass(self):
584 def test_subclass(self):
585 class Foo(HasTraits):
585 class Foo(HasTraits):
586 t = This()
586 t = This()
587 class Bar(Foo):
587 class Bar(Foo):
588 pass
588 pass
589 f = Foo()
589 f = Foo()
590 b = Bar()
590 b = Bar()
591 f.t = b
591 f.t = b
592 b.t = f
592 b.t = f
593 self.assertEquals(f.t, b)
593 self.assertEquals(f.t, b)
594 self.assertEquals(b.t, f)
594 self.assertEquals(b.t, f)
595
595
596 def test_subclass_override(self):
596 def test_subclass_override(self):
597 class Foo(HasTraits):
597 class Foo(HasTraits):
598 t = This()
598 t = This()
599 class Bar(Foo):
599 class Bar(Foo):
600 t = This()
600 t = This()
601 f = Foo()
601 f = Foo()
602 b = Bar()
602 b = Bar()
603 f.t = b
603 f.t = b
604 self.assertEquals(f.t, b)
604 self.assertEquals(f.t, b)
605 self.assertRaises(TraitError, setattr, b, 't', f)
605 self.assertRaises(TraitError, setattr, b, 't', f)
606
606
607 class TraitTestBase(TestCase):
607 class TraitTestBase(TestCase):
608 """A best testing class for basic trait types."""
608 """A best testing class for basic trait types."""
609
609
610 def assign(self, value):
610 def assign(self, value):
611 self.obj.value = value
611 self.obj.value = value
612
612
613 def coerce(self, value):
613 def coerce(self, value):
614 return value
614 return value
615
615
616 def test_good_values(self):
616 def test_good_values(self):
617 if hasattr(self, '_good_values'):
617 if hasattr(self, '_good_values'):
618 for value in self._good_values:
618 for value in self._good_values:
619 self.assign(value)
619 self.assign(value)
620 self.assertEquals(self.obj.value, self.coerce(value))
620 self.assertEquals(self.obj.value, self.coerce(value))
621
621
622 def test_bad_values(self):
622 def test_bad_values(self):
623 if hasattr(self, '_bad_values'):
623 if hasattr(self, '_bad_values'):
624 for value in self._bad_values:
624 for value in self._bad_values:
625 try:
625 self.assertRaises(TraitError, self.assign, value)
626 self.assertRaises(TraitError, self.assign, value)
627 except AssertionError:
628 assert False, value
626
629
627 def test_default_value(self):
630 def test_default_value(self):
628 if hasattr(self, '_default_value'):
631 if hasattr(self, '_default_value'):
629 self.assertEquals(self._default_value, self.obj.value)
632 self.assertEquals(self._default_value, self.obj.value)
630
633
631
634
632 class AnyTrait(HasTraits):
635 class AnyTrait(HasTraits):
633
636
634 value = Any
637 value = Any
635
638
636 class AnyTraitTest(TraitTestBase):
639 class AnyTraitTest(TraitTestBase):
637
640
638 obj = AnyTrait()
641 obj = AnyTrait()
639
642
640 _default_value = None
643 _default_value = None
641 _good_values = [10.0, 'ten', u'ten', [10], {'ten': 10},(10,), None, 1j]
644 _good_values = [10.0, 'ten', u'ten', [10], {'ten': 10},(10,), None, 1j]
642 _bad_values = []
645 _bad_values = []
643
646
644
647
645 class IntTrait(HasTraits):
648 class IntTrait(HasTraits):
646
649
647 value = Int(99)
650 value = Int(99)
648
651
649 class TestInt(TraitTestBase):
652 class TestInt(TraitTestBase):
650
653
651 obj = IntTrait()
654 obj = IntTrait()
652 _default_value = 99
655 _default_value = 99
653 _good_values = [10, -10]
656 _good_values = [10, -10]
654 _bad_values = ['ten', u'ten', [10], {'ten': 10},(10,), None, 1j, 10L,
657 _bad_values = ['ten', u'ten', [10], {'ten': 10},(10,), None, 1j,
655 -10L, 10.1, -10.1, '10L', '-10L', '10.1', '-10.1', u'10L',
658 10.1, -10.1, '10L', '-10L', '10.1', '-10.1', u'10L',
656 u'-10L', u'10.1', u'-10.1', '10', '-10', u'10', u'-10']
659 u'-10L', u'10.1', u'-10.1', '10', '-10', u'10', u'-10']
660 if not py3compat.PY3:
661 _bad_values.extend([10L, -10L])
657
662
658
663
659 class LongTrait(HasTraits):
664 class LongTrait(HasTraits):
660
665
661 value = Long(99L)
666 value = Long(99L)
662
667
663 class TestLong(TraitTestBase):
668 class TestLong(TraitTestBase):
664
669
665 obj = LongTrait()
670 obj = LongTrait()
666
671
667 _default_value = 99L
672 _default_value = 99L
668 _good_values = [10, -10, 10L, -10L]
673 _good_values = [10, -10, 10L, -10L]
669 _bad_values = ['ten', u'ten', [10], [10l], {'ten': 10},(10,),(10L,),
674 _bad_values = ['ten', u'ten', [10], [10l], {'ten': 10},(10,),(10L,),
670 None, 1j, 10.1, -10.1, '10', '-10', '10L', '-10L', '10.1',
675 None, 1j, 10.1, -10.1, '10', '-10', '10L', '-10L', '10.1',
671 '-10.1', u'10', u'-10', u'10L', u'-10L', u'10.1',
676 '-10.1', u'10', u'-10', u'10L', u'-10L', u'10.1',
672 u'-10.1']
677 u'-10.1']
673
678
674
679
675 class FloatTrait(HasTraits):
680 class FloatTrait(HasTraits):
676
681
677 value = Float(99.0)
682 value = Float(99.0)
678
683
679 class TestFloat(TraitTestBase):
684 class TestFloat(TraitTestBase):
680
685
681 obj = FloatTrait()
686 obj = FloatTrait()
682
687
683 _default_value = 99.0
688 _default_value = 99.0
684 _good_values = [10, -10, 10.1, -10.1]
689 _good_values = [10, -10, 10.1, -10.1]
685 _bad_values = [10L, -10L, 'ten', u'ten', [10], {'ten': 10},(10,), None,
690 _bad_values = ['ten', u'ten', [10], {'ten': 10},(10,), None,
686 1j, '10', '-10', '10L', '-10L', '10.1', '-10.1', u'10',
691 1j, '10', '-10', '10L', '-10L', '10.1', '-10.1', u'10',
687 u'-10', u'10L', u'-10L', u'10.1', u'-10.1']
692 u'-10', u'10L', u'-10L', u'10.1', u'-10.1']
693 if not py3compat.PY3:
694 _bad_values.extend([10L, -10L])
688
695
689
696
690 class ComplexTrait(HasTraits):
697 class ComplexTrait(HasTraits):
691
698
692 value = Complex(99.0-99.0j)
699 value = Complex(99.0-99.0j)
693
700
694 class TestComplex(TraitTestBase):
701 class TestComplex(TraitTestBase):
695
702
696 obj = ComplexTrait()
703 obj = ComplexTrait()
697
704
698 _default_value = 99.0-99.0j
705 _default_value = 99.0-99.0j
699 _good_values = [10, -10, 10.1, -10.1, 10j, 10+10j, 10-10j,
706 _good_values = [10, -10, 10.1, -10.1, 10j, 10+10j, 10-10j,
700 10.1j, 10.1+10.1j, 10.1-10.1j]
707 10.1j, 10.1+10.1j, 10.1-10.1j]
701 _bad_values = [10L, -10L, u'10L', u'-10L', 'ten', [10], {'ten': 10},(10,), None]
708 _bad_values = [u'10L', u'-10L', 'ten', [10], {'ten': 10},(10,), None]
709 if not py3compat.PY3:
710 _bad_values.extend([10L, -10L])
702
711
703
712
704 class BytesTrait(HasTraits):
713 class BytesTrait(HasTraits):
705
714
706 value = Bytes(b'string')
715 value = Bytes(b'string')
707
716
708 class TestBytes(TraitTestBase):
717 class TestBytes(TraitTestBase):
709
718
710 obj = BytesTrait()
719 obj = BytesTrait()
711
720
712 _default_value = b'string'
721 _default_value = b'string'
713 _good_values = [b'10', b'-10', b'10L',
722 _good_values = [b'10', b'-10', b'10L',
714 b'-10L', b'10.1', b'-10.1', b'string']
723 b'-10L', b'10.1', b'-10.1', b'string']
715 _bad_values = [10, -10, 10L, -10L, 10.1, -10.1, 1j, [10],
724 _bad_values = [10, -10, 10L, -10L, 10.1, -10.1, 1j, [10],
716 ['ten'],{'ten': 10},(10,), None, u'string']
725 ['ten'],{'ten': 10},(10,), None, u'string']
717
726
718
727
719 class UnicodeTrait(HasTraits):
728 class UnicodeTrait(HasTraits):
720
729
721 value = Unicode(u'unicode')
730 value = Unicode(u'unicode')
722
731
723 class TestUnicode(TraitTestBase):
732 class TestUnicode(TraitTestBase):
724
733
725 obj = UnicodeTrait()
734 obj = UnicodeTrait()
726
735
727 _default_value = u'unicode'
736 _default_value = u'unicode'
728 _good_values = ['10', '-10', '10L', '-10L', '10.1',
737 _good_values = ['10', '-10', '10L', '-10L', '10.1',
729 '-10.1', '', u'', 'string', u'string', u"€"]
738 '-10.1', '', u'', 'string', u'string', u"€"]
730 _bad_values = [10, -10, 10L, -10L, 10.1, -10.1, 1j,
739 _bad_values = [10, -10, 10L, -10L, 10.1, -10.1, 1j,
731 [10], ['ten'], [u'ten'], {'ten': 10},(10,), None]
740 [10], ['ten'], [u'ten'], {'ten': 10},(10,), None]
732
741
733
742
734 class ObjectNameTrait(HasTraits):
743 class ObjectNameTrait(HasTraits):
735 value = ObjectName("abc")
744 value = ObjectName("abc")
736
745
737 class TestObjectName(TraitTestBase):
746 class TestObjectName(TraitTestBase):
738 obj = ObjectNameTrait()
747 obj = ObjectNameTrait()
739
748
740 _default_value = "abc"
749 _default_value = "abc"
741 _good_values = ["a", "gh", "g9", "g_", "_G", u"a345_"]
750 _good_values = ["a", "gh", "g9", "g_", "_G", u"a345_"]
742 _bad_values = [1, "", u"€", "9g", "!", "#abc", "aj@", "a.b", "a()", "a[0]",
751 _bad_values = [1, "", u"€", "9g", "!", "#abc", "aj@", "a.b", "a()", "a[0]",
743 object(), object]
752 object(), object]
744 if sys.version_info[0] < 3:
753 if sys.version_info[0] < 3:
745 _bad_values.append(u"ΓΎ")
754 _bad_values.append(u"ΓΎ")
746 else:
755 else:
747 _good_values.append(u"ΓΎ") # ΓΎ=1 is valid in Python 3 (PEP 3131).
756 _good_values.append(u"ΓΎ") # ΓΎ=1 is valid in Python 3 (PEP 3131).
748
757
749
758
750 class DottedObjectNameTrait(HasTraits):
759 class DottedObjectNameTrait(HasTraits):
751 value = DottedObjectName("a.b")
760 value = DottedObjectName("a.b")
752
761
753 class TestDottedObjectName(TraitTestBase):
762 class TestDottedObjectName(TraitTestBase):
754 obj = DottedObjectNameTrait()
763 obj = DottedObjectNameTrait()
755
764
756 _default_value = "a.b"
765 _default_value = "a.b"
757 _good_values = ["A", "y.t", "y765.__repr__", "os.path.join", u"os.path.join"]
766 _good_values = ["A", "y.t", "y765.__repr__", "os.path.join", u"os.path.join"]
758 _bad_values = [1, u"abc.€", "_.@", ".", ".abc", "abc.", ".abc."]
767 _bad_values = [1, u"abc.€", "_.@", ".", ".abc", "abc.", ".abc."]
759 if sys.version_info[0] < 3:
768 if sys.version_info[0] < 3:
760 _bad_values.append(u"t.ΓΎ")
769 _bad_values.append(u"t.ΓΎ")
761 else:
770 else:
762 _good_values.append(u"t.ΓΎ")
771 _good_values.append(u"t.ΓΎ")
763
772
764
773
765 class TCPAddressTrait(HasTraits):
774 class TCPAddressTrait(HasTraits):
766
775
767 value = TCPAddress()
776 value = TCPAddress()
768
777
769 class TestTCPAddress(TraitTestBase):
778 class TestTCPAddress(TraitTestBase):
770
779
771 obj = TCPAddressTrait()
780 obj = TCPAddressTrait()
772
781
773 _default_value = ('127.0.0.1',0)
782 _default_value = ('127.0.0.1',0)
774 _good_values = [('localhost',0),('192.168.0.1',1000),('www.google.com',80)]
783 _good_values = [('localhost',0),('192.168.0.1',1000),('www.google.com',80)]
775 _bad_values = [(0,0),('localhost',10.0),('localhost',-1)]
784 _bad_values = [(0,0),('localhost',10.0),('localhost',-1)]
776
785
777 class ListTrait(HasTraits):
786 class ListTrait(HasTraits):
778
787
779 value = List(Int)
788 value = List(Int)
780
789
781 class TestList(TraitTestBase):
790 class TestList(TraitTestBase):
782
791
783 obj = ListTrait()
792 obj = ListTrait()
784
793
785 _default_value = []
794 _default_value = []
786 _good_values = [[], [1], range(10)]
795 _good_values = [[], [1], range(10)]
787 _bad_values = [10, [1,'a'], 'a', (1,2)]
796 _bad_values = [10, [1,'a'], 'a', (1,2)]
788
797
789 class LenListTrait(HasTraits):
798 class LenListTrait(HasTraits):
790
799
791 value = List(Int, [0], minlen=1, maxlen=2)
800 value = List(Int, [0], minlen=1, maxlen=2)
792
801
793 class TestLenList(TraitTestBase):
802 class TestLenList(TraitTestBase):
794
803
795 obj = LenListTrait()
804 obj = LenListTrait()
796
805
797 _default_value = [0]
806 _default_value = [0]
798 _good_values = [[1], range(2)]
807 _good_values = [[1], range(2)]
799 _bad_values = [10, [1,'a'], 'a', (1,2), [], range(3)]
808 _bad_values = [10, [1,'a'], 'a', (1,2), [], range(3)]
800
809
801 class TupleTrait(HasTraits):
810 class TupleTrait(HasTraits):
802
811
803 value = Tuple(Int)
812 value = Tuple(Int)
804
813
805 class TestTupleTrait(TraitTestBase):
814 class TestTupleTrait(TraitTestBase):
806
815
807 obj = TupleTrait()
816 obj = TupleTrait()
808
817
809 _default_value = None
818 _default_value = None
810 _good_values = [(1,), None,(0,)]
819 _good_values = [(1,), None,(0,)]
811 _bad_values = [10, (1,2), [1],('a'), ()]
820 _bad_values = [10, (1,2), [1],('a'), ()]
812
821
813 def test_invalid_args(self):
822 def test_invalid_args(self):
814 self.assertRaises(TypeError, Tuple, 5)
823 self.assertRaises(TypeError, Tuple, 5)
815 self.assertRaises(TypeError, Tuple, default_value='hello')
824 self.assertRaises(TypeError, Tuple, default_value='hello')
816 t = Tuple(Int, CBytes, default_value=(1,5))
825 t = Tuple(Int, CBytes, default_value=(1,5))
817
826
818 class LooseTupleTrait(HasTraits):
827 class LooseTupleTrait(HasTraits):
819
828
820 value = Tuple((1,2,3))
829 value = Tuple((1,2,3))
821
830
822 class TestLooseTupleTrait(TraitTestBase):
831 class TestLooseTupleTrait(TraitTestBase):
823
832
824 obj = LooseTupleTrait()
833 obj = LooseTupleTrait()
825
834
826 _default_value = (1,2,3)
835 _default_value = (1,2,3)
827 _good_values = [(1,), None, (0,), tuple(range(5)), tuple('hello'), ('a',5), ()]
836 _good_values = [(1,), None, (0,), tuple(range(5)), tuple('hello'), ('a',5), ()]
828 _bad_values = [10, 'hello', [1], []]
837 _bad_values = [10, 'hello', [1], []]
829
838
830 def test_invalid_args(self):
839 def test_invalid_args(self):
831 self.assertRaises(TypeError, Tuple, 5)
840 self.assertRaises(TypeError, Tuple, 5)
832 self.assertRaises(TypeError, Tuple, default_value='hello')
841 self.assertRaises(TypeError, Tuple, default_value='hello')
833 t = Tuple(Int, CBytes, default_value=(1,5))
842 t = Tuple(Int, CBytes, default_value=(1,5))
834
843
835
844
836 class MultiTupleTrait(HasTraits):
845 class MultiTupleTrait(HasTraits):
837
846
838 value = Tuple(Int, Bytes, default_value=[99,'bottles'])
847 value = Tuple(Int, Bytes, default_value=[99,b'bottles'])
839
848
840 class TestMultiTuple(TraitTestBase):
849 class TestMultiTuple(TraitTestBase):
841
850
842 obj = MultiTupleTrait()
851 obj = MultiTupleTrait()
843
852
844 _default_value = (99,'bottles')
853 _default_value = (99,b'bottles')
845 _good_values = [(1,'a'), (2,'b')]
854 _good_values = [(1,b'a'), (2,b'b')]
846 _bad_values = ((),10, 'a', (1,'a',3), ('a',1))
855 _bad_values = ((),10, b'a', (1,b'a',3), (b'a',1), (1, u'a'))
General Comments 0
You need to be logged in to leave comments. Login now