##// END OF EJS Templates
More Python 3 compatibility fixes.
Thomas Kluyver -
Show More
@@ -1,715 +1,715 b''
1 # encoding: utf-8
1 # encoding: utf-8
2 """
2 """
3 Utilities for working with strings and text.
3 Utilities for working with strings and text.
4 """
4 """
5
5
6 #-----------------------------------------------------------------------------
6 #-----------------------------------------------------------------------------
7 # Copyright (C) 2008-2009 The IPython Development Team
7 # Copyright (C) 2008-2009 The IPython Development Team
8 #
8 #
9 # Distributed under the terms of the BSD License. The full license is in
9 # Distributed under the terms of the BSD License. The full license is in
10 # the file COPYING, distributed as part of this software.
10 # the file COPYING, distributed as part of this software.
11 #-----------------------------------------------------------------------------
11 #-----------------------------------------------------------------------------
12
12
13 #-----------------------------------------------------------------------------
13 #-----------------------------------------------------------------------------
14 # Imports
14 # Imports
15 #-----------------------------------------------------------------------------
15 #-----------------------------------------------------------------------------
16
16
17 import __main__
17 import __main__
18
18
19 import os
19 import os
20 import re
20 import re
21 import shutil
21 import shutil
22 import textwrap
22 import textwrap
23 from string import Formatter
23 from string import Formatter
24
24
25 from IPython.external.path import path
25 from IPython.external.path import path
26
26 from IPython.utils import py3compat
27 from IPython.utils.io import nlprint
27 from IPython.utils.io import nlprint
28 from IPython.utils.data import flatten
28 from IPython.utils.data import flatten
29
29
30 #-----------------------------------------------------------------------------
30 #-----------------------------------------------------------------------------
31 # Code
31 # Code
32 #-----------------------------------------------------------------------------
32 #-----------------------------------------------------------------------------
33
33
34
34
35 def unquote_ends(istr):
35 def unquote_ends(istr):
36 """Remove a single pair of quotes from the endpoints of a string."""
36 """Remove a single pair of quotes from the endpoints of a string."""
37
37
38 if not istr:
38 if not istr:
39 return istr
39 return istr
40 if (istr[0]=="'" and istr[-1]=="'") or \
40 if (istr[0]=="'" and istr[-1]=="'") or \
41 (istr[0]=='"' and istr[-1]=='"'):
41 (istr[0]=='"' and istr[-1]=='"'):
42 return istr[1:-1]
42 return istr[1:-1]
43 else:
43 else:
44 return istr
44 return istr
45
45
46
46
47 class LSString(str):
47 class LSString(str):
48 """String derivative with a special access attributes.
48 """String derivative with a special access attributes.
49
49
50 These are normal strings, but with the special attributes:
50 These are normal strings, but with the special attributes:
51
51
52 .l (or .list) : value as list (split on newlines).
52 .l (or .list) : value as list (split on newlines).
53 .n (or .nlstr): original value (the string itself).
53 .n (or .nlstr): original value (the string itself).
54 .s (or .spstr): value as whitespace-separated string.
54 .s (or .spstr): value as whitespace-separated string.
55 .p (or .paths): list of path objects
55 .p (or .paths): list of path objects
56
56
57 Any values which require transformations are computed only once and
57 Any values which require transformations are computed only once and
58 cached.
58 cached.
59
59
60 Such strings are very useful to efficiently interact with the shell, which
60 Such strings are very useful to efficiently interact with the shell, which
61 typically only understands whitespace-separated options for commands."""
61 typically only understands whitespace-separated options for commands."""
62
62
63 def get_list(self):
63 def get_list(self):
64 try:
64 try:
65 return self.__list
65 return self.__list
66 except AttributeError:
66 except AttributeError:
67 self.__list = self.split('\n')
67 self.__list = self.split('\n')
68 return self.__list
68 return self.__list
69
69
70 l = list = property(get_list)
70 l = list = property(get_list)
71
71
72 def get_spstr(self):
72 def get_spstr(self):
73 try:
73 try:
74 return self.__spstr
74 return self.__spstr
75 except AttributeError:
75 except AttributeError:
76 self.__spstr = self.replace('\n',' ')
76 self.__spstr = self.replace('\n',' ')
77 return self.__spstr
77 return self.__spstr
78
78
79 s = spstr = property(get_spstr)
79 s = spstr = property(get_spstr)
80
80
81 def get_nlstr(self):
81 def get_nlstr(self):
82 return self
82 return self
83
83
84 n = nlstr = property(get_nlstr)
84 n = nlstr = property(get_nlstr)
85
85
86 def get_paths(self):
86 def get_paths(self):
87 try:
87 try:
88 return self.__paths
88 return self.__paths
89 except AttributeError:
89 except AttributeError:
90 self.__paths = [path(p) for p in self.split('\n') if os.path.exists(p)]
90 self.__paths = [path(p) for p in self.split('\n') if os.path.exists(p)]
91 return self.__paths
91 return self.__paths
92
92
93 p = paths = property(get_paths)
93 p = paths = property(get_paths)
94
94
95 # FIXME: We need to reimplement type specific displayhook and then add this
95 # FIXME: We need to reimplement type specific displayhook and then add this
96 # back as a custom printer. This should also be moved outside utils into the
96 # back as a custom printer. This should also be moved outside utils into the
97 # core.
97 # core.
98
98
99 # def print_lsstring(arg):
99 # def print_lsstring(arg):
100 # """ Prettier (non-repr-like) and more informative printer for LSString """
100 # """ Prettier (non-repr-like) and more informative printer for LSString """
101 # print "LSString (.p, .n, .l, .s available). Value:"
101 # print "LSString (.p, .n, .l, .s available). Value:"
102 # print arg
102 # print arg
103 #
103 #
104 #
104 #
105 # print_lsstring = result_display.when_type(LSString)(print_lsstring)
105 # print_lsstring = result_display.when_type(LSString)(print_lsstring)
106
106
107
107
108 class SList(list):
108 class SList(list):
109 """List derivative with a special access attributes.
109 """List derivative with a special access attributes.
110
110
111 These are normal lists, but with the special attributes:
111 These are normal lists, but with the special attributes:
112
112
113 .l (or .list) : value as list (the list itself).
113 .l (or .list) : value as list (the list itself).
114 .n (or .nlstr): value as a string, joined on newlines.
114 .n (or .nlstr): value as a string, joined on newlines.
115 .s (or .spstr): value as a string, joined on spaces.
115 .s (or .spstr): value as a string, joined on spaces.
116 .p (or .paths): list of path objects
116 .p (or .paths): list of path objects
117
117
118 Any values which require transformations are computed only once and
118 Any values which require transformations are computed only once and
119 cached."""
119 cached."""
120
120
121 def get_list(self):
121 def get_list(self):
122 return self
122 return self
123
123
124 l = list = property(get_list)
124 l = list = property(get_list)
125
125
126 def get_spstr(self):
126 def get_spstr(self):
127 try:
127 try:
128 return self.__spstr
128 return self.__spstr
129 except AttributeError:
129 except AttributeError:
130 self.__spstr = ' '.join(self)
130 self.__spstr = ' '.join(self)
131 return self.__spstr
131 return self.__spstr
132
132
133 s = spstr = property(get_spstr)
133 s = spstr = property(get_spstr)
134
134
135 def get_nlstr(self):
135 def get_nlstr(self):
136 try:
136 try:
137 return self.__nlstr
137 return self.__nlstr
138 except AttributeError:
138 except AttributeError:
139 self.__nlstr = '\n'.join(self)
139 self.__nlstr = '\n'.join(self)
140 return self.__nlstr
140 return self.__nlstr
141
141
142 n = nlstr = property(get_nlstr)
142 n = nlstr = property(get_nlstr)
143
143
144 def get_paths(self):
144 def get_paths(self):
145 try:
145 try:
146 return self.__paths
146 return self.__paths
147 except AttributeError:
147 except AttributeError:
148 self.__paths = [path(p) for p in self if os.path.exists(p)]
148 self.__paths = [path(p) for p in self if os.path.exists(p)]
149 return self.__paths
149 return self.__paths
150
150
151 p = paths = property(get_paths)
151 p = paths = property(get_paths)
152
152
153 def grep(self, pattern, prune = False, field = None):
153 def grep(self, pattern, prune = False, field = None):
154 """ Return all strings matching 'pattern' (a regex or callable)
154 """ Return all strings matching 'pattern' (a regex or callable)
155
155
156 This is case-insensitive. If prune is true, return all items
156 This is case-insensitive. If prune is true, return all items
157 NOT matching the pattern.
157 NOT matching the pattern.
158
158
159 If field is specified, the match must occur in the specified
159 If field is specified, the match must occur in the specified
160 whitespace-separated field.
160 whitespace-separated field.
161
161
162 Examples::
162 Examples::
163
163
164 a.grep( lambda x: x.startswith('C') )
164 a.grep( lambda x: x.startswith('C') )
165 a.grep('Cha.*log', prune=1)
165 a.grep('Cha.*log', prune=1)
166 a.grep('chm', field=-1)
166 a.grep('chm', field=-1)
167 """
167 """
168
168
169 def match_target(s):
169 def match_target(s):
170 if field is None:
170 if field is None:
171 return s
171 return s
172 parts = s.split()
172 parts = s.split()
173 try:
173 try:
174 tgt = parts[field]
174 tgt = parts[field]
175 return tgt
175 return tgt
176 except IndexError:
176 except IndexError:
177 return ""
177 return ""
178
178
179 if isinstance(pattern, basestring):
179 if isinstance(pattern, basestring):
180 pred = lambda x : re.search(pattern, x, re.IGNORECASE)
180 pred = lambda x : re.search(pattern, x, re.IGNORECASE)
181 else:
181 else:
182 pred = pattern
182 pred = pattern
183 if not prune:
183 if not prune:
184 return SList([el for el in self if pred(match_target(el))])
184 return SList([el for el in self if pred(match_target(el))])
185 else:
185 else:
186 return SList([el for el in self if not pred(match_target(el))])
186 return SList([el for el in self if not pred(match_target(el))])
187
187
188 def fields(self, *fields):
188 def fields(self, *fields):
189 """ Collect whitespace-separated fields from string list
189 """ Collect whitespace-separated fields from string list
190
190
191 Allows quick awk-like usage of string lists.
191 Allows quick awk-like usage of string lists.
192
192
193 Example data (in var a, created by 'a = !ls -l')::
193 Example data (in var a, created by 'a = !ls -l')::
194 -rwxrwxrwx 1 ville None 18 Dec 14 2006 ChangeLog
194 -rwxrwxrwx 1 ville None 18 Dec 14 2006 ChangeLog
195 drwxrwxrwx+ 6 ville None 0 Oct 24 18:05 IPython
195 drwxrwxrwx+ 6 ville None 0 Oct 24 18:05 IPython
196
196
197 a.fields(0) is ['-rwxrwxrwx', 'drwxrwxrwx+']
197 a.fields(0) is ['-rwxrwxrwx', 'drwxrwxrwx+']
198 a.fields(1,0) is ['1 -rwxrwxrwx', '6 drwxrwxrwx+']
198 a.fields(1,0) is ['1 -rwxrwxrwx', '6 drwxrwxrwx+']
199 (note the joining by space).
199 (note the joining by space).
200 a.fields(-1) is ['ChangeLog', 'IPython']
200 a.fields(-1) is ['ChangeLog', 'IPython']
201
201
202 IndexErrors are ignored.
202 IndexErrors are ignored.
203
203
204 Without args, fields() just split()'s the strings.
204 Without args, fields() just split()'s the strings.
205 """
205 """
206 if len(fields) == 0:
206 if len(fields) == 0:
207 return [el.split() for el in self]
207 return [el.split() for el in self]
208
208
209 res = SList()
209 res = SList()
210 for el in [f.split() for f in self]:
210 for el in [f.split() for f in self]:
211 lineparts = []
211 lineparts = []
212
212
213 for fd in fields:
213 for fd in fields:
214 try:
214 try:
215 lineparts.append(el[fd])
215 lineparts.append(el[fd])
216 except IndexError:
216 except IndexError:
217 pass
217 pass
218 if lineparts:
218 if lineparts:
219 res.append(" ".join(lineparts))
219 res.append(" ".join(lineparts))
220
220
221 return res
221 return res
222
222
223 def sort(self,field= None, nums = False):
223 def sort(self,field= None, nums = False):
224 """ sort by specified fields (see fields())
224 """ sort by specified fields (see fields())
225
225
226 Example::
226 Example::
227 a.sort(1, nums = True)
227 a.sort(1, nums = True)
228
228
229 Sorts a by second field, in numerical order (so that 21 > 3)
229 Sorts a by second field, in numerical order (so that 21 > 3)
230
230
231 """
231 """
232
232
233 #decorate, sort, undecorate
233 #decorate, sort, undecorate
234 if field is not None:
234 if field is not None:
235 dsu = [[SList([line]).fields(field), line] for line in self]
235 dsu = [[SList([line]).fields(field), line] for line in self]
236 else:
236 else:
237 dsu = [[line, line] for line in self]
237 dsu = [[line, line] for line in self]
238 if nums:
238 if nums:
239 for i in range(len(dsu)):
239 for i in range(len(dsu)):
240 numstr = "".join([ch for ch in dsu[i][0] if ch.isdigit()])
240 numstr = "".join([ch for ch in dsu[i][0] if ch.isdigit()])
241 try:
241 try:
242 n = int(numstr)
242 n = int(numstr)
243 except ValueError:
243 except ValueError:
244 n = 0;
244 n = 0;
245 dsu[i][0] = n
245 dsu[i][0] = n
246
246
247
247
248 dsu.sort()
248 dsu.sort()
249 return SList([t[1] for t in dsu])
249 return SList([t[1] for t in dsu])
250
250
251
251
252 # FIXME: We need to reimplement type specific displayhook and then add this
252 # FIXME: We need to reimplement type specific displayhook and then add this
253 # back as a custom printer. This should also be moved outside utils into the
253 # back as a custom printer. This should also be moved outside utils into the
254 # core.
254 # core.
255
255
256 # def print_slist(arg):
256 # def print_slist(arg):
257 # """ Prettier (non-repr-like) and more informative printer for SList """
257 # """ Prettier (non-repr-like) and more informative printer for SList """
258 # print "SList (.p, .n, .l, .s, .grep(), .fields(), sort() available):"
258 # print "SList (.p, .n, .l, .s, .grep(), .fields(), sort() available):"
259 # if hasattr(arg, 'hideonce') and arg.hideonce:
259 # if hasattr(arg, 'hideonce') and arg.hideonce:
260 # arg.hideonce = False
260 # arg.hideonce = False
261 # return
261 # return
262 #
262 #
263 # nlprint(arg)
263 # nlprint(arg)
264 #
264 #
265 # print_slist = result_display.when_type(SList)(print_slist)
265 # print_slist = result_display.when_type(SList)(print_slist)
266
266
267
267
268 def esc_quotes(strng):
268 def esc_quotes(strng):
269 """Return the input string with single and double quotes escaped out"""
269 """Return the input string with single and double quotes escaped out"""
270
270
271 return strng.replace('"','\\"').replace("'","\\'")
271 return strng.replace('"','\\"').replace("'","\\'")
272
272
273
273
274 def make_quoted_expr(s):
274 def make_quoted_expr(s):
275 """Return string s in appropriate quotes, using raw string if possible.
275 """Return string s in appropriate quotes, using raw string if possible.
276
276
277 XXX - example removed because it caused encoding errors in documentation
277 XXX - example removed because it caused encoding errors in documentation
278 generation. We need a new example that doesn't contain invalid chars.
278 generation. We need a new example that doesn't contain invalid chars.
279
279
280 Note the use of raw string and padding at the end to allow trailing
280 Note the use of raw string and padding at the end to allow trailing
281 backslash.
281 backslash.
282 """
282 """
283
283
284 tail = ''
284 tail = ''
285 tailpadding = ''
285 tailpadding = ''
286 raw = ''
286 raw = ''
287 ucode = 'u'
287 ucode = '' if py3compat.PY3 else 'u'
288 if "\\" in s:
288 if "\\" in s:
289 raw = 'r'
289 raw = 'r'
290 if s.endswith('\\'):
290 if s.endswith('\\'):
291 tail = '[:-1]'
291 tail = '[:-1]'
292 tailpadding = '_'
292 tailpadding = '_'
293 if '"' not in s:
293 if '"' not in s:
294 quote = '"'
294 quote = '"'
295 elif "'" not in s:
295 elif "'" not in s:
296 quote = "'"
296 quote = "'"
297 elif '"""' not in s and not s.endswith('"'):
297 elif '"""' not in s and not s.endswith('"'):
298 quote = '"""'
298 quote = '"""'
299 elif "'''" not in s and not s.endswith("'"):
299 elif "'''" not in s and not s.endswith("'"):
300 quote = "'''"
300 quote = "'''"
301 else:
301 else:
302 # give up, backslash-escaped string will do
302 # give up, backslash-escaped string will do
303 return '"%s"' % esc_quotes(s)
303 return '"%s"' % esc_quotes(s)
304 res = ucode + raw + quote + s + tailpadding + quote + tail
304 res = ucode + raw + quote + s + tailpadding + quote + tail
305 return res
305 return res
306
306
307
307
308 def qw(words,flat=0,sep=None,maxsplit=-1):
308 def qw(words,flat=0,sep=None,maxsplit=-1):
309 """Similar to Perl's qw() operator, but with some more options.
309 """Similar to Perl's qw() operator, but with some more options.
310
310
311 qw(words,flat=0,sep=' ',maxsplit=-1) -> words.split(sep,maxsplit)
311 qw(words,flat=0,sep=' ',maxsplit=-1) -> words.split(sep,maxsplit)
312
312
313 words can also be a list itself, and with flat=1, the output will be
313 words can also be a list itself, and with flat=1, the output will be
314 recursively flattened.
314 recursively flattened.
315
315
316 Examples:
316 Examples:
317
317
318 >>> qw('1 2')
318 >>> qw('1 2')
319 ['1', '2']
319 ['1', '2']
320
320
321 >>> qw(['a b','1 2',['m n','p q']])
321 >>> qw(['a b','1 2',['m n','p q']])
322 [['a', 'b'], ['1', '2'], [['m', 'n'], ['p', 'q']]]
322 [['a', 'b'], ['1', '2'], [['m', 'n'], ['p', 'q']]]
323
323
324 >>> qw(['a b','1 2',['m n','p q']],flat=1)
324 >>> qw(['a b','1 2',['m n','p q']],flat=1)
325 ['a', 'b', '1', '2', 'm', 'n', 'p', 'q']
325 ['a', 'b', '1', '2', 'm', 'n', 'p', 'q']
326 """
326 """
327
327
328 if isinstance(words, basestring):
328 if isinstance(words, basestring):
329 return [word.strip() for word in words.split(sep,maxsplit)
329 return [word.strip() for word in words.split(sep,maxsplit)
330 if word and not word.isspace() ]
330 if word and not word.isspace() ]
331 if flat:
331 if flat:
332 return flatten(map(qw,words,[1]*len(words)))
332 return flatten(map(qw,words,[1]*len(words)))
333 return map(qw,words)
333 return map(qw,words)
334
334
335
335
336 def qwflat(words,sep=None,maxsplit=-1):
336 def qwflat(words,sep=None,maxsplit=-1):
337 """Calls qw(words) in flat mode. It's just a convenient shorthand."""
337 """Calls qw(words) in flat mode. It's just a convenient shorthand."""
338 return qw(words,1,sep,maxsplit)
338 return qw(words,1,sep,maxsplit)
339
339
340
340
341 def qw_lol(indata):
341 def qw_lol(indata):
342 """qw_lol('a b') -> [['a','b']],
342 """qw_lol('a b') -> [['a','b']],
343 otherwise it's just a call to qw().
343 otherwise it's just a call to qw().
344
344
345 We need this to make sure the modules_some keys *always* end up as a
345 We need this to make sure the modules_some keys *always* end up as a
346 list of lists."""
346 list of lists."""
347
347
348 if isinstance(indata, basestring):
348 if isinstance(indata, basestring):
349 return [qw(indata)]
349 return [qw(indata)]
350 else:
350 else:
351 return qw(indata)
351 return qw(indata)
352
352
353
353
354 def grep(pat,list,case=1):
354 def grep(pat,list,case=1):
355 """Simple minded grep-like function.
355 """Simple minded grep-like function.
356 grep(pat,list) returns occurrences of pat in list, None on failure.
356 grep(pat,list) returns occurrences of pat in list, None on failure.
357
357
358 It only does simple string matching, with no support for regexps. Use the
358 It only does simple string matching, with no support for regexps. Use the
359 option case=0 for case-insensitive matching."""
359 option case=0 for case-insensitive matching."""
360
360
361 # This is pretty crude. At least it should implement copying only references
361 # This is pretty crude. At least it should implement copying only references
362 # to the original data in case it's big. Now it copies the data for output.
362 # to the original data in case it's big. Now it copies the data for output.
363 out=[]
363 out=[]
364 if case:
364 if case:
365 for term in list:
365 for term in list:
366 if term.find(pat)>-1: out.append(term)
366 if term.find(pat)>-1: out.append(term)
367 else:
367 else:
368 lpat=pat.lower()
368 lpat=pat.lower()
369 for term in list:
369 for term in list:
370 if term.lower().find(lpat)>-1: out.append(term)
370 if term.lower().find(lpat)>-1: out.append(term)
371
371
372 if len(out): return out
372 if len(out): return out
373 else: return None
373 else: return None
374
374
375
375
376 def dgrep(pat,*opts):
376 def dgrep(pat,*opts):
377 """Return grep() on dir()+dir(__builtins__).
377 """Return grep() on dir()+dir(__builtins__).
378
378
379 A very common use of grep() when working interactively."""
379 A very common use of grep() when working interactively."""
380
380
381 return grep(pat,dir(__main__)+dir(__main__.__builtins__),*opts)
381 return grep(pat,dir(__main__)+dir(__main__.__builtins__),*opts)
382
382
383
383
384 def idgrep(pat):
384 def idgrep(pat):
385 """Case-insensitive dgrep()"""
385 """Case-insensitive dgrep()"""
386
386
387 return dgrep(pat,0)
387 return dgrep(pat,0)
388
388
389
389
390 def igrep(pat,list):
390 def igrep(pat,list):
391 """Synonym for case-insensitive grep."""
391 """Synonym for case-insensitive grep."""
392
392
393 return grep(pat,list,case=0)
393 return grep(pat,list,case=0)
394
394
395
395
396 def indent(instr,nspaces=4, ntabs=0, flatten=False):
396 def indent(instr,nspaces=4, ntabs=0, flatten=False):
397 """Indent a string a given number of spaces or tabstops.
397 """Indent a string a given number of spaces or tabstops.
398
398
399 indent(str,nspaces=4,ntabs=0) -> indent str by ntabs+nspaces.
399 indent(str,nspaces=4,ntabs=0) -> indent str by ntabs+nspaces.
400
400
401 Parameters
401 Parameters
402 ----------
402 ----------
403
403
404 instr : basestring
404 instr : basestring
405 The string to be indented.
405 The string to be indented.
406 nspaces : int (default: 4)
406 nspaces : int (default: 4)
407 The number of spaces to be indented.
407 The number of spaces to be indented.
408 ntabs : int (default: 0)
408 ntabs : int (default: 0)
409 The number of tabs to be indented.
409 The number of tabs to be indented.
410 flatten : bool (default: False)
410 flatten : bool (default: False)
411 Whether to scrub existing indentation. If True, all lines will be
411 Whether to scrub existing indentation. If True, all lines will be
412 aligned to the same indentation. If False, existing indentation will
412 aligned to the same indentation. If False, existing indentation will
413 be strictly increased.
413 be strictly increased.
414
414
415 Returns
415 Returns
416 -------
416 -------
417
417
418 str|unicode : string indented by ntabs and nspaces.
418 str|unicode : string indented by ntabs and nspaces.
419
419
420 """
420 """
421 if instr is None:
421 if instr is None:
422 return
422 return
423 ind = '\t'*ntabs+' '*nspaces
423 ind = '\t'*ntabs+' '*nspaces
424 if flatten:
424 if flatten:
425 pat = re.compile(r'^\s*', re.MULTILINE)
425 pat = re.compile(r'^\s*', re.MULTILINE)
426 else:
426 else:
427 pat = re.compile(r'^', re.MULTILINE)
427 pat = re.compile(r'^', re.MULTILINE)
428 outstr = re.sub(pat, ind, instr)
428 outstr = re.sub(pat, ind, instr)
429 if outstr.endswith(os.linesep+ind):
429 if outstr.endswith(os.linesep+ind):
430 return outstr[:-len(ind)]
430 return outstr[:-len(ind)]
431 else:
431 else:
432 return outstr
432 return outstr
433
433
434 def native_line_ends(filename,backup=1):
434 def native_line_ends(filename,backup=1):
435 """Convert (in-place) a file to line-ends native to the current OS.
435 """Convert (in-place) a file to line-ends native to the current OS.
436
436
437 If the optional backup argument is given as false, no backup of the
437 If the optional backup argument is given as false, no backup of the
438 original file is left. """
438 original file is left. """
439
439
440 backup_suffixes = {'posix':'~','dos':'.bak','nt':'.bak','mac':'.bak'}
440 backup_suffixes = {'posix':'~','dos':'.bak','nt':'.bak','mac':'.bak'}
441
441
442 bak_filename = filename + backup_suffixes[os.name]
442 bak_filename = filename + backup_suffixes[os.name]
443
443
444 original = open(filename).read()
444 original = open(filename).read()
445 shutil.copy2(filename,bak_filename)
445 shutil.copy2(filename,bak_filename)
446 try:
446 try:
447 new = open(filename,'wb')
447 new = open(filename,'wb')
448 new.write(os.linesep.join(original.splitlines()))
448 new.write(os.linesep.join(original.splitlines()))
449 new.write(os.linesep) # ALWAYS put an eol at the end of the file
449 new.write(os.linesep) # ALWAYS put an eol at the end of the file
450 new.close()
450 new.close()
451 except:
451 except:
452 os.rename(bak_filename,filename)
452 os.rename(bak_filename,filename)
453 if not backup:
453 if not backup:
454 try:
454 try:
455 os.remove(bak_filename)
455 os.remove(bak_filename)
456 except:
456 except:
457 pass
457 pass
458
458
459
459
460 def list_strings(arg):
460 def list_strings(arg):
461 """Always return a list of strings, given a string or list of strings
461 """Always return a list of strings, given a string or list of strings
462 as input.
462 as input.
463
463
464 :Examples:
464 :Examples:
465
465
466 In [7]: list_strings('A single string')
466 In [7]: list_strings('A single string')
467 Out[7]: ['A single string']
467 Out[7]: ['A single string']
468
468
469 In [8]: list_strings(['A single string in a list'])
469 In [8]: list_strings(['A single string in a list'])
470 Out[8]: ['A single string in a list']
470 Out[8]: ['A single string in a list']
471
471
472 In [9]: list_strings(['A','list','of','strings'])
472 In [9]: list_strings(['A','list','of','strings'])
473 Out[9]: ['A', 'list', 'of', 'strings']
473 Out[9]: ['A', 'list', 'of', 'strings']
474 """
474 """
475
475
476 if isinstance(arg,basestring): return [arg]
476 if isinstance(arg,basestring): return [arg]
477 else: return arg
477 else: return arg
478
478
479
479
480 def marquee(txt='',width=78,mark='*'):
480 def marquee(txt='',width=78,mark='*'):
481 """Return the input string centered in a 'marquee'.
481 """Return the input string centered in a 'marquee'.
482
482
483 :Examples:
483 :Examples:
484
484
485 In [16]: marquee('A test',40)
485 In [16]: marquee('A test',40)
486 Out[16]: '**************** A test ****************'
486 Out[16]: '**************** A test ****************'
487
487
488 In [17]: marquee('A test',40,'-')
488 In [17]: marquee('A test',40,'-')
489 Out[17]: '---------------- A test ----------------'
489 Out[17]: '---------------- A test ----------------'
490
490
491 In [18]: marquee('A test',40,' ')
491 In [18]: marquee('A test',40,' ')
492 Out[18]: ' A test '
492 Out[18]: ' A test '
493
493
494 """
494 """
495 if not txt:
495 if not txt:
496 return (mark*width)[:width]
496 return (mark*width)[:width]
497 nmark = (width-len(txt)-2)/len(mark)/2
497 nmark = (width-len(txt)-2)//len(mark)//2
498 if nmark < 0: nmark =0
498 if nmark < 0: nmark =0
499 marks = mark*nmark
499 marks = mark*nmark
500 return '%s %s %s' % (marks,txt,marks)
500 return '%s %s %s' % (marks,txt,marks)
501
501
502
502
503 ini_spaces_re = re.compile(r'^(\s+)')
503 ini_spaces_re = re.compile(r'^(\s+)')
504
504
505 def num_ini_spaces(strng):
505 def num_ini_spaces(strng):
506 """Return the number of initial spaces in a string"""
506 """Return the number of initial spaces in a string"""
507
507
508 ini_spaces = ini_spaces_re.match(strng)
508 ini_spaces = ini_spaces_re.match(strng)
509 if ini_spaces:
509 if ini_spaces:
510 return ini_spaces.end()
510 return ini_spaces.end()
511 else:
511 else:
512 return 0
512 return 0
513
513
514
514
515 def format_screen(strng):
515 def format_screen(strng):
516 """Format a string for screen printing.
516 """Format a string for screen printing.
517
517
518 This removes some latex-type format codes."""
518 This removes some latex-type format codes."""
519 # Paragraph continue
519 # Paragraph continue
520 par_re = re.compile(r'\\$',re.MULTILINE)
520 par_re = re.compile(r'\\$',re.MULTILINE)
521 strng = par_re.sub('',strng)
521 strng = par_re.sub('',strng)
522 return strng
522 return strng
523
523
524 def dedent(text):
524 def dedent(text):
525 """Equivalent of textwrap.dedent that ignores unindented first line.
525 """Equivalent of textwrap.dedent that ignores unindented first line.
526
526
527 This means it will still dedent strings like:
527 This means it will still dedent strings like:
528 '''foo
528 '''foo
529 is a bar
529 is a bar
530 '''
530 '''
531
531
532 For use in wrap_paragraphs.
532 For use in wrap_paragraphs.
533 """
533 """
534
534
535 if text.startswith('\n'):
535 if text.startswith('\n'):
536 # text starts with blank line, don't ignore the first line
536 # text starts with blank line, don't ignore the first line
537 return textwrap.dedent(text)
537 return textwrap.dedent(text)
538
538
539 # split first line
539 # split first line
540 splits = text.split('\n',1)
540 splits = text.split('\n',1)
541 if len(splits) == 1:
541 if len(splits) == 1:
542 # only one line
542 # only one line
543 return textwrap.dedent(text)
543 return textwrap.dedent(text)
544
544
545 first, rest = splits
545 first, rest = splits
546 # dedent everything but the first line
546 # dedent everything but the first line
547 rest = textwrap.dedent(rest)
547 rest = textwrap.dedent(rest)
548 return '\n'.join([first, rest])
548 return '\n'.join([first, rest])
549
549
550 def wrap_paragraphs(text, ncols=80):
550 def wrap_paragraphs(text, ncols=80):
551 """Wrap multiple paragraphs to fit a specified width.
551 """Wrap multiple paragraphs to fit a specified width.
552
552
553 This is equivalent to textwrap.wrap, but with support for multiple
553 This is equivalent to textwrap.wrap, but with support for multiple
554 paragraphs, as separated by empty lines.
554 paragraphs, as separated by empty lines.
555
555
556 Returns
556 Returns
557 -------
557 -------
558
558
559 list of complete paragraphs, wrapped to fill `ncols` columns.
559 list of complete paragraphs, wrapped to fill `ncols` columns.
560 """
560 """
561 paragraph_re = re.compile(r'\n(\s*\n)+', re.MULTILINE)
561 paragraph_re = re.compile(r'\n(\s*\n)+', re.MULTILINE)
562 text = dedent(text).strip()
562 text = dedent(text).strip()
563 paragraphs = paragraph_re.split(text)[::2] # every other entry is space
563 paragraphs = paragraph_re.split(text)[::2] # every other entry is space
564 out_ps = []
564 out_ps = []
565 indent_re = re.compile(r'\n\s+', re.MULTILINE)
565 indent_re = re.compile(r'\n\s+', re.MULTILINE)
566 for p in paragraphs:
566 for p in paragraphs:
567 # presume indentation that survives dedent is meaningful formatting,
567 # presume indentation that survives dedent is meaningful formatting,
568 # so don't fill unless text is flush.
568 # so don't fill unless text is flush.
569 if indent_re.search(p) is None:
569 if indent_re.search(p) is None:
570 # wrap paragraph
570 # wrap paragraph
571 p = textwrap.fill(p, ncols)
571 p = textwrap.fill(p, ncols)
572 out_ps.append(p)
572 out_ps.append(p)
573 return out_ps
573 return out_ps
574
574
575
575
576
576
577 class EvalFormatter(Formatter):
577 class EvalFormatter(Formatter):
578 """A String Formatter that allows evaluation of simple expressions.
578 """A String Formatter that allows evaluation of simple expressions.
579
579
580 Any time a format key is not found in the kwargs,
580 Any time a format key is not found in the kwargs,
581 it will be tried as an expression in the kwargs namespace.
581 it will be tried as an expression in the kwargs namespace.
582
582
583 This is to be used in templating cases, such as the parallel batch
583 This is to be used in templating cases, such as the parallel batch
584 script templates, where simple arithmetic on arguments is useful.
584 script templates, where simple arithmetic on arguments is useful.
585
585
586 Examples
586 Examples
587 --------
587 --------
588
588
589 In [1]: f = EvalFormatter()
589 In [1]: f = EvalFormatter()
590 In [2]: f.format('{n/4}', n=8)
590 In [2]: f.format('{n/4}', n=8)
591 Out[2]: '2'
591 Out[2]: '2'
592
592
593 In [3]: f.format('{range(3)}')
593 In [3]: f.format('{range(3)}')
594 Out[3]: '[0, 1, 2]'
594 Out[3]: '[0, 1, 2]'
595
595
596 In [4]: f.format('{3*2}')
596 In [4]: f.format('{3*2}')
597 Out[4]: '6'
597 Out[4]: '6'
598 """
598 """
599
599
600 # should we allow slicing by disabling the format_spec feature?
600 # should we allow slicing by disabling the format_spec feature?
601 allow_slicing = True
601 allow_slicing = True
602
602
603 # copied from Formatter._vformat with minor changes to allow eval
603 # copied from Formatter._vformat with minor changes to allow eval
604 # and replace the format_spec code with slicing
604 # and replace the format_spec code with slicing
605 def _vformat(self, format_string, args, kwargs, used_args, recursion_depth):
605 def _vformat(self, format_string, args, kwargs, used_args, recursion_depth):
606 if recursion_depth < 0:
606 if recursion_depth < 0:
607 raise ValueError('Max string recursion exceeded')
607 raise ValueError('Max string recursion exceeded')
608 result = []
608 result = []
609 for literal_text, field_name, format_spec, conversion in \
609 for literal_text, field_name, format_spec, conversion in \
610 self.parse(format_string):
610 self.parse(format_string):
611
611
612 # output the literal text
612 # output the literal text
613 if literal_text:
613 if literal_text:
614 result.append(literal_text)
614 result.append(literal_text)
615
615
616 # if there's a field, output it
616 # if there's a field, output it
617 if field_name is not None:
617 if field_name is not None:
618 # this is some markup, find the object and do
618 # this is some markup, find the object and do
619 # the formatting
619 # the formatting
620
620
621 if self.allow_slicing and format_spec:
621 if self.allow_slicing and format_spec:
622 # override format spec, to allow slicing:
622 # override format spec, to allow slicing:
623 field_name = ':'.join([field_name, format_spec])
623 field_name = ':'.join([field_name, format_spec])
624 format_spec = ''
624 format_spec = ''
625
625
626 # eval the contents of the field for the object
626 # eval the contents of the field for the object
627 # to be formatted
627 # to be formatted
628 obj = eval(field_name, kwargs)
628 obj = eval(field_name, kwargs)
629
629
630 # do any conversion on the resulting object
630 # do any conversion on the resulting object
631 obj = self.convert_field(obj, conversion)
631 obj = self.convert_field(obj, conversion)
632
632
633 # expand the format spec, if needed
633 # expand the format spec, if needed
634 format_spec = self._vformat(format_spec, args, kwargs,
634 format_spec = self._vformat(format_spec, args, kwargs,
635 used_args, recursion_depth-1)
635 used_args, recursion_depth-1)
636
636
637 # format the object and append to the result
637 # format the object and append to the result
638 result.append(self.format_field(obj, format_spec))
638 result.append(self.format_field(obj, format_spec))
639
639
640 return ''.join(result)
640 return ''.join(result)
641
641
642
642
643 def columnize(items, separator=' ', displaywidth=80):
643 def columnize(items, separator=' ', displaywidth=80):
644 """ Transform a list of strings into a single string with columns.
644 """ Transform a list of strings into a single string with columns.
645
645
646 Parameters
646 Parameters
647 ----------
647 ----------
648 items : sequence of strings
648 items : sequence of strings
649 The strings to process.
649 The strings to process.
650
650
651 separator : str, optional [default is two spaces]
651 separator : str, optional [default is two spaces]
652 The string that separates columns.
652 The string that separates columns.
653
653
654 displaywidth : int, optional [default is 80]
654 displaywidth : int, optional [default is 80]
655 Width of the display in number of characters.
655 Width of the display in number of characters.
656
656
657 Returns
657 Returns
658 -------
658 -------
659 The formatted string.
659 The formatted string.
660 """
660 """
661 # Note: this code is adapted from columnize 0.3.2.
661 # Note: this code is adapted from columnize 0.3.2.
662 # See http://code.google.com/p/pycolumnize/
662 # See http://code.google.com/p/pycolumnize/
663
663
664 # Some degenerate cases.
664 # Some degenerate cases.
665 size = len(items)
665 size = len(items)
666 if size == 0:
666 if size == 0:
667 return '\n'
667 return '\n'
668 elif size == 1:
668 elif size == 1:
669 return '%s\n' % items[0]
669 return '%s\n' % items[0]
670
670
671 # Special case: if any item is longer than the maximum width, there's no
671 # Special case: if any item is longer than the maximum width, there's no
672 # point in triggering the logic below...
672 # point in triggering the logic below...
673 item_len = map(len, items) # save these, we can reuse them below
673 item_len = map(len, items) # save these, we can reuse them below
674 longest = max(item_len)
674 longest = max(item_len)
675 if longest >= displaywidth:
675 if longest >= displaywidth:
676 return '\n'.join(items+[''])
676 return '\n'.join(items+[''])
677
677
678 # Try every row count from 1 upwards
678 # Try every row count from 1 upwards
679 array_index = lambda nrows, row, col: nrows*col + row
679 array_index = lambda nrows, row, col: nrows*col + row
680 for nrows in range(1, size):
680 for nrows in range(1, size):
681 ncols = (size + nrows - 1) // nrows
681 ncols = (size + nrows - 1) // nrows
682 colwidths = []
682 colwidths = []
683 totwidth = -len(separator)
683 totwidth = -len(separator)
684 for col in range(ncols):
684 for col in range(ncols):
685 # Get max column width for this column
685 # Get max column width for this column
686 colwidth = 0
686 colwidth = 0
687 for row in range(nrows):
687 for row in range(nrows):
688 i = array_index(nrows, row, col)
688 i = array_index(nrows, row, col)
689 if i >= size: break
689 if i >= size: break
690 x, len_x = items[i], item_len[i]
690 x, len_x = items[i], item_len[i]
691 colwidth = max(colwidth, len_x)
691 colwidth = max(colwidth, len_x)
692 colwidths.append(colwidth)
692 colwidths.append(colwidth)
693 totwidth += colwidth + len(separator)
693 totwidth += colwidth + len(separator)
694 if totwidth > displaywidth:
694 if totwidth > displaywidth:
695 break
695 break
696 if totwidth <= displaywidth:
696 if totwidth <= displaywidth:
697 break
697 break
698
698
699 # The smallest number of rows computed and the max widths for each
699 # The smallest number of rows computed and the max widths for each
700 # column has been obtained. Now we just have to format each of the rows.
700 # column has been obtained. Now we just have to format each of the rows.
701 string = ''
701 string = ''
702 for row in range(nrows):
702 for row in range(nrows):
703 texts = []
703 texts = []
704 for col in range(ncols):
704 for col in range(ncols):
705 i = row + nrows*col
705 i = row + nrows*col
706 if i >= size:
706 if i >= size:
707 texts.append('')
707 texts.append('')
708 else:
708 else:
709 texts.append(items[i])
709 texts.append(items[i])
710 while texts and not texts[-1]:
710 while texts and not texts[-1]:
711 del texts[-1]
711 del texts[-1]
712 for col in range(len(texts)):
712 for col in range(len(texts)):
713 texts[col] = texts[col].ljust(colwidths[col])
713 texts[col] = texts[col].ljust(colwidths[col])
714 string += '%s\n' % separator.join(texts)
714 string += '%s\n' % separator.join(texts)
715 return string
715 return string
@@ -1,1396 +1,1398 b''
1 # encoding: utf-8
1 # encoding: utf-8
2 """
2 """
3 A lightweight Traits like module.
3 A lightweight Traits like module.
4
4
5 This is designed to provide a lightweight, simple, pure Python version of
5 This is designed to provide a lightweight, simple, pure Python version of
6 many of the capabilities of enthought.traits. This includes:
6 many of the capabilities of enthought.traits. This includes:
7
7
8 * Validation
8 * Validation
9 * Type specification with defaults
9 * Type specification with defaults
10 * Static and dynamic notification
10 * Static and dynamic notification
11 * Basic predefined types
11 * Basic predefined types
12 * An API that is similar to enthought.traits
12 * An API that is similar to enthought.traits
13
13
14 We don't support:
14 We don't support:
15
15
16 * Delegation
16 * Delegation
17 * Automatic GUI generation
17 * Automatic GUI generation
18 * A full set of trait types. Most importantly, we don't provide container
18 * A full set of trait types. Most importantly, we don't provide container
19 traits (list, dict, tuple) that can trigger notifications if their
19 traits (list, dict, tuple) that can trigger notifications if their
20 contents change.
20 contents change.
21 * API compatibility with enthought.traits
21 * API compatibility with enthought.traits
22
22
23 There are also some important difference in our design:
23 There are also some important difference in our design:
24
24
25 * enthought.traits does not validate default values. We do.
25 * enthought.traits does not validate default values. We do.
26
26
27 We choose to create this module because we need these capabilities, but
27 We choose to create this module because we need these capabilities, but
28 we need them to be pure Python so they work in all Python implementations,
28 we need them to be pure Python so they work in all Python implementations,
29 including Jython and IronPython.
29 including Jython and IronPython.
30
30
31 Authors:
31 Authors:
32
32
33 * Brian Granger
33 * Brian Granger
34 * Enthought, Inc. Some of the code in this file comes from enthought.traits
34 * Enthought, Inc. Some of the code in this file comes from enthought.traits
35 and is licensed under the BSD license. Also, many of the ideas also come
35 and is licensed under the BSD license. Also, many of the ideas also come
36 from enthought.traits even though our implementation is very different.
36 from enthought.traits even though our implementation is very different.
37 """
37 """
38
38
39 #-----------------------------------------------------------------------------
39 #-----------------------------------------------------------------------------
40 # Copyright (C) 2008-2009 The IPython Development Team
40 # Copyright (C) 2008-2009 The IPython Development Team
41 #
41 #
42 # Distributed under the terms of the BSD License. The full license is in
42 # Distributed under the terms of the BSD License. The full license is in
43 # the file COPYING, distributed as part of this software.
43 # the file COPYING, distributed as part of this software.
44 #-----------------------------------------------------------------------------
44 #-----------------------------------------------------------------------------
45
45
46 #-----------------------------------------------------------------------------
46 #-----------------------------------------------------------------------------
47 # Imports
47 # Imports
48 #-----------------------------------------------------------------------------
48 #-----------------------------------------------------------------------------
49
49
50
50
51 import inspect
51 import inspect
52 import re
52 import re
53 import sys
53 import sys
54 import types
54 import types
55 from types import (
55 from types import FunctionType
56 InstanceType, ClassType, FunctionType,
56 try:
57 ListType, TupleType
57 from types import ClassType, InstanceType
58 )
58 ClassTypes = (ClassType, type)
59 except:
60 ClassTypes = (type,)
61
59 from .importstring import import_item
62 from .importstring import import_item
63 from IPython.utils import py3compat
60
64
61 ClassTypes = (ClassType, type)
65 SequenceTypes = (list, tuple, set, frozenset)
62
63 SequenceTypes = (ListType, TupleType, set, frozenset)
64
66
65 #-----------------------------------------------------------------------------
67 #-----------------------------------------------------------------------------
66 # Basic classes
68 # Basic classes
67 #-----------------------------------------------------------------------------
69 #-----------------------------------------------------------------------------
68
70
69
71
70 class NoDefaultSpecified ( object ): pass
72 class NoDefaultSpecified ( object ): pass
71 NoDefaultSpecified = NoDefaultSpecified()
73 NoDefaultSpecified = NoDefaultSpecified()
72
74
73
75
74 class Undefined ( object ): pass
76 class Undefined ( object ): pass
75 Undefined = Undefined()
77 Undefined = Undefined()
76
78
77 class TraitError(Exception):
79 class TraitError(Exception):
78 pass
80 pass
79
81
80 #-----------------------------------------------------------------------------
82 #-----------------------------------------------------------------------------
81 # Utilities
83 # Utilities
82 #-----------------------------------------------------------------------------
84 #-----------------------------------------------------------------------------
83
85
84
86
85 def class_of ( object ):
87 def class_of ( object ):
86 """ Returns a string containing the class name of an object with the
88 """ Returns a string containing the class name of an object with the
87 correct indefinite article ('a' or 'an') preceding it (e.g., 'an Image',
89 correct indefinite article ('a' or 'an') preceding it (e.g., 'an Image',
88 'a PlotValue').
90 'a PlotValue').
89 """
91 """
90 if isinstance( object, basestring ):
92 if isinstance( object, basestring ):
91 return add_article( object )
93 return add_article( object )
92
94
93 return add_article( object.__class__.__name__ )
95 return add_article( object.__class__.__name__ )
94
96
95
97
96 def add_article ( name ):
98 def add_article ( name ):
97 """ Returns a string containing the correct indefinite article ('a' or 'an')
99 """ Returns a string containing the correct indefinite article ('a' or 'an')
98 prefixed to the specified string.
100 prefixed to the specified string.
99 """
101 """
100 if name[:1].lower() in 'aeiou':
102 if name[:1].lower() in 'aeiou':
101 return 'an ' + name
103 return 'an ' + name
102
104
103 return 'a ' + name
105 return 'a ' + name
104
106
105
107
106 def repr_type(obj):
108 def repr_type(obj):
107 """ Return a string representation of a value and its type for readable
109 """ Return a string representation of a value and its type for readable
108 error messages.
110 error messages.
109 """
111 """
110 the_type = type(obj)
112 the_type = type(obj)
111 if the_type is InstanceType:
113 if (not py3compat.PY3) and the_type is InstanceType:
112 # Old-style class.
114 # Old-style class.
113 the_type = obj.__class__
115 the_type = obj.__class__
114 msg = '%r %r' % (obj, the_type)
116 msg = '%r %r' % (obj, the_type)
115 return msg
117 return msg
116
118
117
119
118 def parse_notifier_name(name):
120 def parse_notifier_name(name):
119 """Convert the name argument to a list of names.
121 """Convert the name argument to a list of names.
120
122
121 Examples
123 Examples
122 --------
124 --------
123
125
124 >>> parse_notifier_name('a')
126 >>> parse_notifier_name('a')
125 ['a']
127 ['a']
126 >>> parse_notifier_name(['a','b'])
128 >>> parse_notifier_name(['a','b'])
127 ['a', 'b']
129 ['a', 'b']
128 >>> parse_notifier_name(None)
130 >>> parse_notifier_name(None)
129 ['anytrait']
131 ['anytrait']
130 """
132 """
131 if isinstance(name, str):
133 if isinstance(name, str):
132 return [name]
134 return [name]
133 elif name is None:
135 elif name is None:
134 return ['anytrait']
136 return ['anytrait']
135 elif isinstance(name, (list, tuple)):
137 elif isinstance(name, (list, tuple)):
136 for n in name:
138 for n in name:
137 assert isinstance(n, str), "names must be strings"
139 assert isinstance(n, str), "names must be strings"
138 return name
140 return name
139
141
140
142
141 class _SimpleTest:
143 class _SimpleTest:
142 def __init__ ( self, value ): self.value = value
144 def __init__ ( self, value ): self.value = value
143 def __call__ ( self, test ):
145 def __call__ ( self, test ):
144 return test == self.value
146 return test == self.value
145 def __repr__(self):
147 def __repr__(self):
146 return "<SimpleTest(%r)" % self.value
148 return "<SimpleTest(%r)" % self.value
147 def __str__(self):
149 def __str__(self):
148 return self.__repr__()
150 return self.__repr__()
149
151
150
152
151 def getmembers(object, predicate=None):
153 def getmembers(object, predicate=None):
152 """A safe version of inspect.getmembers that handles missing attributes.
154 """A safe version of inspect.getmembers that handles missing attributes.
153
155
154 This is useful when there are descriptor based attributes that for
156 This is useful when there are descriptor based attributes that for
155 some reason raise AttributeError even though they exist. This happens
157 some reason raise AttributeError even though they exist. This happens
156 in zope.inteface with the __provides__ attribute.
158 in zope.inteface with the __provides__ attribute.
157 """
159 """
158 results = []
160 results = []
159 for key in dir(object):
161 for key in dir(object):
160 try:
162 try:
161 value = getattr(object, key)
163 value = getattr(object, key)
162 except AttributeError:
164 except AttributeError:
163 pass
165 pass
164 else:
166 else:
165 if not predicate or predicate(value):
167 if not predicate or predicate(value):
166 results.append((key, value))
168 results.append((key, value))
167 results.sort()
169 results.sort()
168 return results
170 return results
169
171
170
172
171 #-----------------------------------------------------------------------------
173 #-----------------------------------------------------------------------------
172 # Base TraitType for all traits
174 # Base TraitType for all traits
173 #-----------------------------------------------------------------------------
175 #-----------------------------------------------------------------------------
174
176
175
177
176 class TraitType(object):
178 class TraitType(object):
177 """A base class for all trait descriptors.
179 """A base class for all trait descriptors.
178
180
179 Notes
181 Notes
180 -----
182 -----
181 Our implementation of traits is based on Python's descriptor
183 Our implementation of traits is based on Python's descriptor
182 prototol. This class is the base class for all such descriptors. The
184 prototol. This class is the base class for all such descriptors. The
183 only magic we use is a custom metaclass for the main :class:`HasTraits`
185 only magic we use is a custom metaclass for the main :class:`HasTraits`
184 class that does the following:
186 class that does the following:
185
187
186 1. Sets the :attr:`name` attribute of every :class:`TraitType`
188 1. Sets the :attr:`name` attribute of every :class:`TraitType`
187 instance in the class dict to the name of the attribute.
189 instance in the class dict to the name of the attribute.
188 2. Sets the :attr:`this_class` attribute of every :class:`TraitType`
190 2. Sets the :attr:`this_class` attribute of every :class:`TraitType`
189 instance in the class dict to the *class* that declared the trait.
191 instance in the class dict to the *class* that declared the trait.
190 This is used by the :class:`This` trait to allow subclasses to
192 This is used by the :class:`This` trait to allow subclasses to
191 accept superclasses for :class:`This` values.
193 accept superclasses for :class:`This` values.
192 """
194 """
193
195
194
196
195 metadata = {}
197 metadata = {}
196 default_value = Undefined
198 default_value = Undefined
197 info_text = 'any value'
199 info_text = 'any value'
198
200
199 def __init__(self, default_value=NoDefaultSpecified, **metadata):
201 def __init__(self, default_value=NoDefaultSpecified, **metadata):
200 """Create a TraitType.
202 """Create a TraitType.
201 """
203 """
202 if default_value is not NoDefaultSpecified:
204 if default_value is not NoDefaultSpecified:
203 self.default_value = default_value
205 self.default_value = default_value
204
206
205 if len(metadata) > 0:
207 if len(metadata) > 0:
206 if len(self.metadata) > 0:
208 if len(self.metadata) > 0:
207 self._metadata = self.metadata.copy()
209 self._metadata = self.metadata.copy()
208 self._metadata.update(metadata)
210 self._metadata.update(metadata)
209 else:
211 else:
210 self._metadata = metadata
212 self._metadata = metadata
211 else:
213 else:
212 self._metadata = self.metadata
214 self._metadata = self.metadata
213
215
214 self.init()
216 self.init()
215
217
216 def init(self):
218 def init(self):
217 pass
219 pass
218
220
219 def get_default_value(self):
221 def get_default_value(self):
220 """Create a new instance of the default value."""
222 """Create a new instance of the default value."""
221 return self.default_value
223 return self.default_value
222
224
223 def instance_init(self, obj):
225 def instance_init(self, obj):
224 """This is called by :meth:`HasTraits.__new__` to finish init'ing.
226 """This is called by :meth:`HasTraits.__new__` to finish init'ing.
225
227
226 Some stages of initialization must be delayed until the parent
228 Some stages of initialization must be delayed until the parent
227 :class:`HasTraits` instance has been created. This method is
229 :class:`HasTraits` instance has been created. This method is
228 called in :meth:`HasTraits.__new__` after the instance has been
230 called in :meth:`HasTraits.__new__` after the instance has been
229 created.
231 created.
230
232
231 This method trigger the creation and validation of default values
233 This method trigger the creation and validation of default values
232 and also things like the resolution of str given class names in
234 and also things like the resolution of str given class names in
233 :class:`Type` and :class`Instance`.
235 :class:`Type` and :class`Instance`.
234
236
235 Parameters
237 Parameters
236 ----------
238 ----------
237 obj : :class:`HasTraits` instance
239 obj : :class:`HasTraits` instance
238 The parent :class:`HasTraits` instance that has just been
240 The parent :class:`HasTraits` instance that has just been
239 created.
241 created.
240 """
242 """
241 self.set_default_value(obj)
243 self.set_default_value(obj)
242
244
243 def set_default_value(self, obj):
245 def set_default_value(self, obj):
244 """Set the default value on a per instance basis.
246 """Set the default value on a per instance basis.
245
247
246 This method is called by :meth:`instance_init` to create and
248 This method is called by :meth:`instance_init` to create and
247 validate the default value. The creation and validation of
249 validate the default value. The creation and validation of
248 default values must be delayed until the parent :class:`HasTraits`
250 default values must be delayed until the parent :class:`HasTraits`
249 class has been instantiated.
251 class has been instantiated.
250 """
252 """
251 # Check for a deferred initializer defined in the same class as the
253 # Check for a deferred initializer defined in the same class as the
252 # trait declaration or above.
254 # trait declaration or above.
253 mro = type(obj).mro()
255 mro = type(obj).mro()
254 meth_name = '_%s_default' % self.name
256 meth_name = '_%s_default' % self.name
255 for cls in mro[:mro.index(self.this_class)+1]:
257 for cls in mro[:mro.index(self.this_class)+1]:
256 if meth_name in cls.__dict__:
258 if meth_name in cls.__dict__:
257 break
259 break
258 else:
260 else:
259 # We didn't find one. Do static initialization.
261 # We didn't find one. Do static initialization.
260 dv = self.get_default_value()
262 dv = self.get_default_value()
261 newdv = self._validate(obj, dv)
263 newdv = self._validate(obj, dv)
262 obj._trait_values[self.name] = newdv
264 obj._trait_values[self.name] = newdv
263 return
265 return
264 # Complete the dynamic initialization.
266 # Complete the dynamic initialization.
265 obj._trait_dyn_inits[self.name] = cls.__dict__[meth_name]
267 obj._trait_dyn_inits[self.name] = cls.__dict__[meth_name]
266
268
267 def __get__(self, obj, cls=None):
269 def __get__(self, obj, cls=None):
268 """Get the value of the trait by self.name for the instance.
270 """Get the value of the trait by self.name for the instance.
269
271
270 Default values are instantiated when :meth:`HasTraits.__new__`
272 Default values are instantiated when :meth:`HasTraits.__new__`
271 is called. Thus by the time this method gets called either the
273 is called. Thus by the time this method gets called either the
272 default value or a user defined value (they called :meth:`__set__`)
274 default value or a user defined value (they called :meth:`__set__`)
273 is in the :class:`HasTraits` instance.
275 is in the :class:`HasTraits` instance.
274 """
276 """
275 if obj is None:
277 if obj is None:
276 return self
278 return self
277 else:
279 else:
278 try:
280 try:
279 value = obj._trait_values[self.name]
281 value = obj._trait_values[self.name]
280 except KeyError:
282 except KeyError:
281 # Check for a dynamic initializer.
283 # Check for a dynamic initializer.
282 if self.name in obj._trait_dyn_inits:
284 if self.name in obj._trait_dyn_inits:
283 value = obj._trait_dyn_inits[self.name](obj)
285 value = obj._trait_dyn_inits[self.name](obj)
284 # FIXME: Do we really validate here?
286 # FIXME: Do we really validate here?
285 value = self._validate(obj, value)
287 value = self._validate(obj, value)
286 obj._trait_values[self.name] = value
288 obj._trait_values[self.name] = value
287 return value
289 return value
288 else:
290 else:
289 raise TraitError('Unexpected error in TraitType: '
291 raise TraitError('Unexpected error in TraitType: '
290 'both default value and dynamic initializer are '
292 'both default value and dynamic initializer are '
291 'absent.')
293 'absent.')
292 except Exception:
294 except Exception:
293 # HasTraits should call set_default_value to populate
295 # HasTraits should call set_default_value to populate
294 # this. So this should never be reached.
296 # this. So this should never be reached.
295 raise TraitError('Unexpected error in TraitType: '
297 raise TraitError('Unexpected error in TraitType: '
296 'default value not set properly')
298 'default value not set properly')
297 else:
299 else:
298 return value
300 return value
299
301
300 def __set__(self, obj, value):
302 def __set__(self, obj, value):
301 new_value = self._validate(obj, value)
303 new_value = self._validate(obj, value)
302 old_value = self.__get__(obj)
304 old_value = self.__get__(obj)
303 if old_value != new_value:
305 if old_value != new_value:
304 obj._trait_values[self.name] = new_value
306 obj._trait_values[self.name] = new_value
305 obj._notify_trait(self.name, old_value, new_value)
307 obj._notify_trait(self.name, old_value, new_value)
306
308
307 def _validate(self, obj, value):
309 def _validate(self, obj, value):
308 if hasattr(self, 'validate'):
310 if hasattr(self, 'validate'):
309 return self.validate(obj, value)
311 return self.validate(obj, value)
310 elif hasattr(self, 'is_valid_for'):
312 elif hasattr(self, 'is_valid_for'):
311 valid = self.is_valid_for(value)
313 valid = self.is_valid_for(value)
312 if valid:
314 if valid:
313 return value
315 return value
314 else:
316 else:
315 raise TraitError('invalid value for type: %r' % value)
317 raise TraitError('invalid value for type: %r' % value)
316 elif hasattr(self, 'value_for'):
318 elif hasattr(self, 'value_for'):
317 return self.value_for(value)
319 return self.value_for(value)
318 else:
320 else:
319 return value
321 return value
320
322
321 def info(self):
323 def info(self):
322 return self.info_text
324 return self.info_text
323
325
324 def error(self, obj, value):
326 def error(self, obj, value):
325 if obj is not None:
327 if obj is not None:
326 e = "The '%s' trait of %s instance must be %s, but a value of %s was specified." \
328 e = "The '%s' trait of %s instance must be %s, but a value of %s was specified." \
327 % (self.name, class_of(obj),
329 % (self.name, class_of(obj),
328 self.info(), repr_type(value))
330 self.info(), repr_type(value))
329 else:
331 else:
330 e = "The '%s' trait must be %s, but a value of %r was specified." \
332 e = "The '%s' trait must be %s, but a value of %r was specified." \
331 % (self.name, self.info(), repr_type(value))
333 % (self.name, self.info(), repr_type(value))
332 raise TraitError(e)
334 raise TraitError(e)
333
335
334 def get_metadata(self, key):
336 def get_metadata(self, key):
335 return getattr(self, '_metadata', {}).get(key, None)
337 return getattr(self, '_metadata', {}).get(key, None)
336
338
337 def set_metadata(self, key, value):
339 def set_metadata(self, key, value):
338 getattr(self, '_metadata', {})[key] = value
340 getattr(self, '_metadata', {})[key] = value
339
341
340
342
341 #-----------------------------------------------------------------------------
343 #-----------------------------------------------------------------------------
342 # The HasTraits implementation
344 # The HasTraits implementation
343 #-----------------------------------------------------------------------------
345 #-----------------------------------------------------------------------------
344
346
345
347
346 class MetaHasTraits(type):
348 class MetaHasTraits(type):
347 """A metaclass for HasTraits.
349 """A metaclass for HasTraits.
348
350
349 This metaclass makes sure that any TraitType class attributes are
351 This metaclass makes sure that any TraitType class attributes are
350 instantiated and sets their name attribute.
352 instantiated and sets their name attribute.
351 """
353 """
352
354
353 def __new__(mcls, name, bases, classdict):
355 def __new__(mcls, name, bases, classdict):
354 """Create the HasTraits class.
356 """Create the HasTraits class.
355
357
356 This instantiates all TraitTypes in the class dict and sets their
358 This instantiates all TraitTypes in the class dict and sets their
357 :attr:`name` attribute.
359 :attr:`name` attribute.
358 """
360 """
359 # print "MetaHasTraitlets (mcls, name): ", mcls, name
361 # print "MetaHasTraitlets (mcls, name): ", mcls, name
360 # print "MetaHasTraitlets (bases): ", bases
362 # print "MetaHasTraitlets (bases): ", bases
361 # print "MetaHasTraitlets (classdict): ", classdict
363 # print "MetaHasTraitlets (classdict): ", classdict
362 for k,v in classdict.iteritems():
364 for k,v in classdict.iteritems():
363 if isinstance(v, TraitType):
365 if isinstance(v, TraitType):
364 v.name = k
366 v.name = k
365 elif inspect.isclass(v):
367 elif inspect.isclass(v):
366 if issubclass(v, TraitType):
368 if issubclass(v, TraitType):
367 vinst = v()
369 vinst = v()
368 vinst.name = k
370 vinst.name = k
369 classdict[k] = vinst
371 classdict[k] = vinst
370 return super(MetaHasTraits, mcls).__new__(mcls, name, bases, classdict)
372 return super(MetaHasTraits, mcls).__new__(mcls, name, bases, classdict)
371
373
372 def __init__(cls, name, bases, classdict):
374 def __init__(cls, name, bases, classdict):
373 """Finish initializing the HasTraits class.
375 """Finish initializing the HasTraits class.
374
376
375 This sets the :attr:`this_class` attribute of each TraitType in the
377 This sets the :attr:`this_class` attribute of each TraitType in the
376 class dict to the newly created class ``cls``.
378 class dict to the newly created class ``cls``.
377 """
379 """
378 for k, v in classdict.iteritems():
380 for k, v in classdict.iteritems():
379 if isinstance(v, TraitType):
381 if isinstance(v, TraitType):
380 v.this_class = cls
382 v.this_class = cls
381 super(MetaHasTraits, cls).__init__(name, bases, classdict)
383 super(MetaHasTraits, cls).__init__(name, bases, classdict)
382
384
383 class HasTraits(object):
385 class HasTraits(object):
384
386
385 __metaclass__ = MetaHasTraits
387 __metaclass__ = MetaHasTraits
386
388
387 def __new__(cls, **kw):
389 def __new__(cls, **kw):
388 # This is needed because in Python 2.6 object.__new__ only accepts
390 # This is needed because in Python 2.6 object.__new__ only accepts
389 # the cls argument.
391 # the cls argument.
390 new_meth = super(HasTraits, cls).__new__
392 new_meth = super(HasTraits, cls).__new__
391 if new_meth is object.__new__:
393 if new_meth is object.__new__:
392 inst = new_meth(cls)
394 inst = new_meth(cls)
393 else:
395 else:
394 inst = new_meth(cls, **kw)
396 inst = new_meth(cls, **kw)
395 inst._trait_values = {}
397 inst._trait_values = {}
396 inst._trait_notifiers = {}
398 inst._trait_notifiers = {}
397 inst._trait_dyn_inits = {}
399 inst._trait_dyn_inits = {}
398 # Here we tell all the TraitType instances to set their default
400 # Here we tell all the TraitType instances to set their default
399 # values on the instance.
401 # values on the instance.
400 for key in dir(cls):
402 for key in dir(cls):
401 # Some descriptors raise AttributeError like zope.interface's
403 # Some descriptors raise AttributeError like zope.interface's
402 # __provides__ attributes even though they exist. This causes
404 # __provides__ attributes even though they exist. This causes
403 # AttributeErrors even though they are listed in dir(cls).
405 # AttributeErrors even though they are listed in dir(cls).
404 try:
406 try:
405 value = getattr(cls, key)
407 value = getattr(cls, key)
406 except AttributeError:
408 except AttributeError:
407 pass
409 pass
408 else:
410 else:
409 if isinstance(value, TraitType):
411 if isinstance(value, TraitType):
410 value.instance_init(inst)
412 value.instance_init(inst)
411
413
412 return inst
414 return inst
413
415
414 def __init__(self, **kw):
416 def __init__(self, **kw):
415 # Allow trait values to be set using keyword arguments.
417 # Allow trait values to be set using keyword arguments.
416 # We need to use setattr for this to trigger validation and
418 # We need to use setattr for this to trigger validation and
417 # notifications.
419 # notifications.
418 for key, value in kw.iteritems():
420 for key, value in kw.iteritems():
419 setattr(self, key, value)
421 setattr(self, key, value)
420
422
421 def _notify_trait(self, name, old_value, new_value):
423 def _notify_trait(self, name, old_value, new_value):
422
424
423 # First dynamic ones
425 # First dynamic ones
424 callables = self._trait_notifiers.get(name,[])
426 callables = self._trait_notifiers.get(name,[])
425 more_callables = self._trait_notifiers.get('anytrait',[])
427 more_callables = self._trait_notifiers.get('anytrait',[])
426 callables.extend(more_callables)
428 callables.extend(more_callables)
427
429
428 # Now static ones
430 # Now static ones
429 try:
431 try:
430 cb = getattr(self, '_%s_changed' % name)
432 cb = getattr(self, '_%s_changed' % name)
431 except:
433 except:
432 pass
434 pass
433 else:
435 else:
434 callables.append(cb)
436 callables.append(cb)
435
437
436 # Call them all now
438 # Call them all now
437 for c in callables:
439 for c in callables:
438 # Traits catches and logs errors here. I allow them to raise
440 # Traits catches and logs errors here. I allow them to raise
439 if callable(c):
441 if callable(c):
440 argspec = inspect.getargspec(c)
442 argspec = inspect.getargspec(c)
441 nargs = len(argspec[0])
443 nargs = len(argspec[0])
442 # Bound methods have an additional 'self' argument
444 # Bound methods have an additional 'self' argument
443 # I don't know how to treat unbound methods, but they
445 # I don't know how to treat unbound methods, but they
444 # can't really be used for callbacks.
446 # can't really be used for callbacks.
445 if isinstance(c, types.MethodType):
447 if isinstance(c, types.MethodType):
446 offset = -1
448 offset = -1
447 else:
449 else:
448 offset = 0
450 offset = 0
449 if nargs + offset == 0:
451 if nargs + offset == 0:
450 c()
452 c()
451 elif nargs + offset == 1:
453 elif nargs + offset == 1:
452 c(name)
454 c(name)
453 elif nargs + offset == 2:
455 elif nargs + offset == 2:
454 c(name, new_value)
456 c(name, new_value)
455 elif nargs + offset == 3:
457 elif nargs + offset == 3:
456 c(name, old_value, new_value)
458 c(name, old_value, new_value)
457 else:
459 else:
458 raise TraitError('a trait changed callback '
460 raise TraitError('a trait changed callback '
459 'must have 0-3 arguments.')
461 'must have 0-3 arguments.')
460 else:
462 else:
461 raise TraitError('a trait changed callback '
463 raise TraitError('a trait changed callback '
462 'must be callable.')
464 'must be callable.')
463
465
464
466
465 def _add_notifiers(self, handler, name):
467 def _add_notifiers(self, handler, name):
466 if not self._trait_notifiers.has_key(name):
468 if not self._trait_notifiers.has_key(name):
467 nlist = []
469 nlist = []
468 self._trait_notifiers[name] = nlist
470 self._trait_notifiers[name] = nlist
469 else:
471 else:
470 nlist = self._trait_notifiers[name]
472 nlist = self._trait_notifiers[name]
471 if handler not in nlist:
473 if handler not in nlist:
472 nlist.append(handler)
474 nlist.append(handler)
473
475
474 def _remove_notifiers(self, handler, name):
476 def _remove_notifiers(self, handler, name):
475 if self._trait_notifiers.has_key(name):
477 if self._trait_notifiers.has_key(name):
476 nlist = self._trait_notifiers[name]
478 nlist = self._trait_notifiers[name]
477 try:
479 try:
478 index = nlist.index(handler)
480 index = nlist.index(handler)
479 except ValueError:
481 except ValueError:
480 pass
482 pass
481 else:
483 else:
482 del nlist[index]
484 del nlist[index]
483
485
484 def on_trait_change(self, handler, name=None, remove=False):
486 def on_trait_change(self, handler, name=None, remove=False):
485 """Setup a handler to be called when a trait changes.
487 """Setup a handler to be called when a trait changes.
486
488
487 This is used to setup dynamic notifications of trait changes.
489 This is used to setup dynamic notifications of trait changes.
488
490
489 Static handlers can be created by creating methods on a HasTraits
491 Static handlers can be created by creating methods on a HasTraits
490 subclass with the naming convention '_[traitname]_changed'. Thus,
492 subclass with the naming convention '_[traitname]_changed'. Thus,
491 to create static handler for the trait 'a', create the method
493 to create static handler for the trait 'a', create the method
492 _a_changed(self, name, old, new) (fewer arguments can be used, see
494 _a_changed(self, name, old, new) (fewer arguments can be used, see
493 below).
495 below).
494
496
495 Parameters
497 Parameters
496 ----------
498 ----------
497 handler : callable
499 handler : callable
498 A callable that is called when a trait changes. Its
500 A callable that is called when a trait changes. Its
499 signature can be handler(), handler(name), handler(name, new)
501 signature can be handler(), handler(name), handler(name, new)
500 or handler(name, old, new).
502 or handler(name, old, new).
501 name : list, str, None
503 name : list, str, None
502 If None, the handler will apply to all traits. If a list
504 If None, the handler will apply to all traits. If a list
503 of str, handler will apply to all names in the list. If a
505 of str, handler will apply to all names in the list. If a
504 str, the handler will apply just to that name.
506 str, the handler will apply just to that name.
505 remove : bool
507 remove : bool
506 If False (the default), then install the handler. If True
508 If False (the default), then install the handler. If True
507 then unintall it.
509 then unintall it.
508 """
510 """
509 if remove:
511 if remove:
510 names = parse_notifier_name(name)
512 names = parse_notifier_name(name)
511 for n in names:
513 for n in names:
512 self._remove_notifiers(handler, n)
514 self._remove_notifiers(handler, n)
513 else:
515 else:
514 names = parse_notifier_name(name)
516 names = parse_notifier_name(name)
515 for n in names:
517 for n in names:
516 self._add_notifiers(handler, n)
518 self._add_notifiers(handler, n)
517
519
518 @classmethod
520 @classmethod
519 def class_trait_names(cls, **metadata):
521 def class_trait_names(cls, **metadata):
520 """Get a list of all the names of this classes traits.
522 """Get a list of all the names of this classes traits.
521
523
522 This method is just like the :meth:`trait_names` method, but is unbound.
524 This method is just like the :meth:`trait_names` method, but is unbound.
523 """
525 """
524 return cls.class_traits(**metadata).keys()
526 return cls.class_traits(**metadata).keys()
525
527
526 @classmethod
528 @classmethod
527 def class_traits(cls, **metadata):
529 def class_traits(cls, **metadata):
528 """Get a list of all the traits of this class.
530 """Get a list of all the traits of this class.
529
531
530 This method is just like the :meth:`traits` method, but is unbound.
532 This method is just like the :meth:`traits` method, but is unbound.
531
533
532 The TraitTypes returned don't know anything about the values
534 The TraitTypes returned don't know anything about the values
533 that the various HasTrait's instances are holding.
535 that the various HasTrait's instances are holding.
534
536
535 This follows the same algorithm as traits does and does not allow
537 This follows the same algorithm as traits does and does not allow
536 for any simple way of specifying merely that a metadata name
538 for any simple way of specifying merely that a metadata name
537 exists, but has any value. This is because get_metadata returns
539 exists, but has any value. This is because get_metadata returns
538 None if a metadata key doesn't exist.
540 None if a metadata key doesn't exist.
539 """
541 """
540 traits = dict([memb for memb in getmembers(cls) if \
542 traits = dict([memb for memb in getmembers(cls) if \
541 isinstance(memb[1], TraitType)])
543 isinstance(memb[1], TraitType)])
542
544
543 if len(metadata) == 0:
545 if len(metadata) == 0:
544 return traits
546 return traits
545
547
546 for meta_name, meta_eval in metadata.items():
548 for meta_name, meta_eval in metadata.items():
547 if type(meta_eval) is not FunctionType:
549 if type(meta_eval) is not FunctionType:
548 metadata[meta_name] = _SimpleTest(meta_eval)
550 metadata[meta_name] = _SimpleTest(meta_eval)
549
551
550 result = {}
552 result = {}
551 for name, trait in traits.items():
553 for name, trait in traits.items():
552 for meta_name, meta_eval in metadata.items():
554 for meta_name, meta_eval in metadata.items():
553 if not meta_eval(trait.get_metadata(meta_name)):
555 if not meta_eval(trait.get_metadata(meta_name)):
554 break
556 break
555 else:
557 else:
556 result[name] = trait
558 result[name] = trait
557
559
558 return result
560 return result
559
561
560 def trait_names(self, **metadata):
562 def trait_names(self, **metadata):
561 """Get a list of all the names of this classes traits."""
563 """Get a list of all the names of this classes traits."""
562 return self.traits(**metadata).keys()
564 return self.traits(**metadata).keys()
563
565
564 def traits(self, **metadata):
566 def traits(self, **metadata):
565 """Get a list of all the traits of this class.
567 """Get a list of all the traits of this class.
566
568
567 The TraitTypes returned don't know anything about the values
569 The TraitTypes returned don't know anything about the values
568 that the various HasTrait's instances are holding.
570 that the various HasTrait's instances are holding.
569
571
570 This follows the same algorithm as traits does and does not allow
572 This follows the same algorithm as traits does and does not allow
571 for any simple way of specifying merely that a metadata name
573 for any simple way of specifying merely that a metadata name
572 exists, but has any value. This is because get_metadata returns
574 exists, but has any value. This is because get_metadata returns
573 None if a metadata key doesn't exist.
575 None if a metadata key doesn't exist.
574 """
576 """
575 traits = dict([memb for memb in getmembers(self.__class__) if \
577 traits = dict([memb for memb in getmembers(self.__class__) if \
576 isinstance(memb[1], TraitType)])
578 isinstance(memb[1], TraitType)])
577
579
578 if len(metadata) == 0:
580 if len(metadata) == 0:
579 return traits
581 return traits
580
582
581 for meta_name, meta_eval in metadata.items():
583 for meta_name, meta_eval in metadata.items():
582 if type(meta_eval) is not FunctionType:
584 if type(meta_eval) is not FunctionType:
583 metadata[meta_name] = _SimpleTest(meta_eval)
585 metadata[meta_name] = _SimpleTest(meta_eval)
584
586
585 result = {}
587 result = {}
586 for name, trait in traits.items():
588 for name, trait in traits.items():
587 for meta_name, meta_eval in metadata.items():
589 for meta_name, meta_eval in metadata.items():
588 if not meta_eval(trait.get_metadata(meta_name)):
590 if not meta_eval(trait.get_metadata(meta_name)):
589 break
591 break
590 else:
592 else:
591 result[name] = trait
593 result[name] = trait
592
594
593 return result
595 return result
594
596
595 def trait_metadata(self, traitname, key):
597 def trait_metadata(self, traitname, key):
596 """Get metadata values for trait by key."""
598 """Get metadata values for trait by key."""
597 try:
599 try:
598 trait = getattr(self.__class__, traitname)
600 trait = getattr(self.__class__, traitname)
599 except AttributeError:
601 except AttributeError:
600 raise TraitError("Class %s does not have a trait named %s" %
602 raise TraitError("Class %s does not have a trait named %s" %
601 (self.__class__.__name__, traitname))
603 (self.__class__.__name__, traitname))
602 else:
604 else:
603 return trait.get_metadata(key)
605 return trait.get_metadata(key)
604
606
605 #-----------------------------------------------------------------------------
607 #-----------------------------------------------------------------------------
606 # Actual TraitTypes implementations/subclasses
608 # Actual TraitTypes implementations/subclasses
607 #-----------------------------------------------------------------------------
609 #-----------------------------------------------------------------------------
608
610
609 #-----------------------------------------------------------------------------
611 #-----------------------------------------------------------------------------
610 # TraitTypes subclasses for handling classes and instances of classes
612 # TraitTypes subclasses for handling classes and instances of classes
611 #-----------------------------------------------------------------------------
613 #-----------------------------------------------------------------------------
612
614
613
615
614 class ClassBasedTraitType(TraitType):
616 class ClassBasedTraitType(TraitType):
615 """A trait with error reporting for Type, Instance and This."""
617 """A trait with error reporting for Type, Instance and This."""
616
618
617 def error(self, obj, value):
619 def error(self, obj, value):
618 kind = type(value)
620 kind = type(value)
619 if kind is InstanceType:
621 if (not py3compat.PY3) and kind is InstanceType:
620 msg = 'class %s' % value.__class__.__name__
622 msg = 'class %s' % value.__class__.__name__
621 else:
623 else:
622 msg = '%s (i.e. %s)' % ( str( kind )[1:-1], repr( value ) )
624 msg = '%s (i.e. %s)' % ( str( kind )[1:-1], repr( value ) )
623
625
624 if obj is not None:
626 if obj is not None:
625 e = "The '%s' trait of %s instance must be %s, but a value of %s was specified." \
627 e = "The '%s' trait of %s instance must be %s, but a value of %s was specified." \
626 % (self.name, class_of(obj),
628 % (self.name, class_of(obj),
627 self.info(), msg)
629 self.info(), msg)
628 else:
630 else:
629 e = "The '%s' trait must be %s, but a value of %r was specified." \
631 e = "The '%s' trait must be %s, but a value of %r was specified." \
630 % (self.name, self.info(), msg)
632 % (self.name, self.info(), msg)
631
633
632 raise TraitError(e)
634 raise TraitError(e)
633
635
634
636
635 class Type(ClassBasedTraitType):
637 class Type(ClassBasedTraitType):
636 """A trait whose value must be a subclass of a specified class."""
638 """A trait whose value must be a subclass of a specified class."""
637
639
638 def __init__ (self, default_value=None, klass=None, allow_none=True, **metadata ):
640 def __init__ (self, default_value=None, klass=None, allow_none=True, **metadata ):
639 """Construct a Type trait
641 """Construct a Type trait
640
642
641 A Type trait specifies that its values must be subclasses of
643 A Type trait specifies that its values must be subclasses of
642 a particular class.
644 a particular class.
643
645
644 If only ``default_value`` is given, it is used for the ``klass`` as
646 If only ``default_value`` is given, it is used for the ``klass`` as
645 well.
647 well.
646
648
647 Parameters
649 Parameters
648 ----------
650 ----------
649 default_value : class, str or None
651 default_value : class, str or None
650 The default value must be a subclass of klass. If an str,
652 The default value must be a subclass of klass. If an str,
651 the str must be a fully specified class name, like 'foo.bar.Bah'.
653 the str must be a fully specified class name, like 'foo.bar.Bah'.
652 The string is resolved into real class, when the parent
654 The string is resolved into real class, when the parent
653 :class:`HasTraits` class is instantiated.
655 :class:`HasTraits` class is instantiated.
654 klass : class, str, None
656 klass : class, str, None
655 Values of this trait must be a subclass of klass. The klass
657 Values of this trait must be a subclass of klass. The klass
656 may be specified in a string like: 'foo.bar.MyClass'.
658 may be specified in a string like: 'foo.bar.MyClass'.
657 The string is resolved into real class, when the parent
659 The string is resolved into real class, when the parent
658 :class:`HasTraits` class is instantiated.
660 :class:`HasTraits` class is instantiated.
659 allow_none : boolean
661 allow_none : boolean
660 Indicates whether None is allowed as an assignable value. Even if
662 Indicates whether None is allowed as an assignable value. Even if
661 ``False``, the default value may be ``None``.
663 ``False``, the default value may be ``None``.
662 """
664 """
663 if default_value is None:
665 if default_value is None:
664 if klass is None:
666 if klass is None:
665 klass = object
667 klass = object
666 elif klass is None:
668 elif klass is None:
667 klass = default_value
669 klass = default_value
668
670
669 if not (inspect.isclass(klass) or isinstance(klass, basestring)):
671 if not (inspect.isclass(klass) or isinstance(klass, basestring)):
670 raise TraitError("A Type trait must specify a class.")
672 raise TraitError("A Type trait must specify a class.")
671
673
672 self.klass = klass
674 self.klass = klass
673 self._allow_none = allow_none
675 self._allow_none = allow_none
674
676
675 super(Type, self).__init__(default_value, **metadata)
677 super(Type, self).__init__(default_value, **metadata)
676
678
677 def validate(self, obj, value):
679 def validate(self, obj, value):
678 """Validates that the value is a valid object instance."""
680 """Validates that the value is a valid object instance."""
679 try:
681 try:
680 if issubclass(value, self.klass):
682 if issubclass(value, self.klass):
681 return value
683 return value
682 except:
684 except:
683 if (value is None) and (self._allow_none):
685 if (value is None) and (self._allow_none):
684 return value
686 return value
685
687
686 self.error(obj, value)
688 self.error(obj, value)
687
689
688 def info(self):
690 def info(self):
689 """ Returns a description of the trait."""
691 """ Returns a description of the trait."""
690 if isinstance(self.klass, basestring):
692 if isinstance(self.klass, basestring):
691 klass = self.klass
693 klass = self.klass
692 else:
694 else:
693 klass = self.klass.__name__
695 klass = self.klass.__name__
694 result = 'a subclass of ' + klass
696 result = 'a subclass of ' + klass
695 if self._allow_none:
697 if self._allow_none:
696 return result + ' or None'
698 return result + ' or None'
697 return result
699 return result
698
700
699 def instance_init(self, obj):
701 def instance_init(self, obj):
700 self._resolve_classes()
702 self._resolve_classes()
701 super(Type, self).instance_init(obj)
703 super(Type, self).instance_init(obj)
702
704
703 def _resolve_classes(self):
705 def _resolve_classes(self):
704 if isinstance(self.klass, basestring):
706 if isinstance(self.klass, basestring):
705 self.klass = import_item(self.klass)
707 self.klass = import_item(self.klass)
706 if isinstance(self.default_value, basestring):
708 if isinstance(self.default_value, basestring):
707 self.default_value = import_item(self.default_value)
709 self.default_value = import_item(self.default_value)
708
710
709 def get_default_value(self):
711 def get_default_value(self):
710 return self.default_value
712 return self.default_value
711
713
712
714
713 class DefaultValueGenerator(object):
715 class DefaultValueGenerator(object):
714 """A class for generating new default value instances."""
716 """A class for generating new default value instances."""
715
717
716 def __init__(self, *args, **kw):
718 def __init__(self, *args, **kw):
717 self.args = args
719 self.args = args
718 self.kw = kw
720 self.kw = kw
719
721
720 def generate(self, klass):
722 def generate(self, klass):
721 return klass(*self.args, **self.kw)
723 return klass(*self.args, **self.kw)
722
724
723
725
724 class Instance(ClassBasedTraitType):
726 class Instance(ClassBasedTraitType):
725 """A trait whose value must be an instance of a specified class.
727 """A trait whose value must be an instance of a specified class.
726
728
727 The value can also be an instance of a subclass of the specified class.
729 The value can also be an instance of a subclass of the specified class.
728 """
730 """
729
731
730 def __init__(self, klass=None, args=None, kw=None,
732 def __init__(self, klass=None, args=None, kw=None,
731 allow_none=True, **metadata ):
733 allow_none=True, **metadata ):
732 """Construct an Instance trait.
734 """Construct an Instance trait.
733
735
734 This trait allows values that are instances of a particular
736 This trait allows values that are instances of a particular
735 class or its sublclasses. Our implementation is quite different
737 class or its sublclasses. Our implementation is quite different
736 from that of enthough.traits as we don't allow instances to be used
738 from that of enthough.traits as we don't allow instances to be used
737 for klass and we handle the ``args`` and ``kw`` arguments differently.
739 for klass and we handle the ``args`` and ``kw`` arguments differently.
738
740
739 Parameters
741 Parameters
740 ----------
742 ----------
741 klass : class, str
743 klass : class, str
742 The class that forms the basis for the trait. Class names
744 The class that forms the basis for the trait. Class names
743 can also be specified as strings, like 'foo.bar.Bar'.
745 can also be specified as strings, like 'foo.bar.Bar'.
744 args : tuple
746 args : tuple
745 Positional arguments for generating the default value.
747 Positional arguments for generating the default value.
746 kw : dict
748 kw : dict
747 Keyword arguments for generating the default value.
749 Keyword arguments for generating the default value.
748 allow_none : bool
750 allow_none : bool
749 Indicates whether None is allowed as a value.
751 Indicates whether None is allowed as a value.
750
752
751 Default Value
753 Default Value
752 -------------
754 -------------
753 If both ``args`` and ``kw`` are None, then the default value is None.
755 If both ``args`` and ``kw`` are None, then the default value is None.
754 If ``args`` is a tuple and ``kw`` is a dict, then the default is
756 If ``args`` is a tuple and ``kw`` is a dict, then the default is
755 created as ``klass(*args, **kw)``. If either ``args`` or ``kw`` is
757 created as ``klass(*args, **kw)``. If either ``args`` or ``kw`` is
756 not (but not both), None is replace by ``()`` or ``{}``.
758 not (but not both), None is replace by ``()`` or ``{}``.
757 """
759 """
758
760
759 self._allow_none = allow_none
761 self._allow_none = allow_none
760
762
761 if (klass is None) or (not (inspect.isclass(klass) or isinstance(klass, basestring))):
763 if (klass is None) or (not (inspect.isclass(klass) or isinstance(klass, basestring))):
762 raise TraitError('The klass argument must be a class'
764 raise TraitError('The klass argument must be a class'
763 ' you gave: %r' % klass)
765 ' you gave: %r' % klass)
764 self.klass = klass
766 self.klass = klass
765
767
766 # self.klass is a class, so handle default_value
768 # self.klass is a class, so handle default_value
767 if args is None and kw is None:
769 if args is None and kw is None:
768 default_value = None
770 default_value = None
769 else:
771 else:
770 if args is None:
772 if args is None:
771 # kw is not None
773 # kw is not None
772 args = ()
774 args = ()
773 elif kw is None:
775 elif kw is None:
774 # args is not None
776 # args is not None
775 kw = {}
777 kw = {}
776
778
777 if not isinstance(kw, dict):
779 if not isinstance(kw, dict):
778 raise TraitError("The 'kw' argument must be a dict or None.")
780 raise TraitError("The 'kw' argument must be a dict or None.")
779 if not isinstance(args, tuple):
781 if not isinstance(args, tuple):
780 raise TraitError("The 'args' argument must be a tuple or None.")
782 raise TraitError("The 'args' argument must be a tuple or None.")
781
783
782 default_value = DefaultValueGenerator(*args, **kw)
784 default_value = DefaultValueGenerator(*args, **kw)
783
785
784 super(Instance, self).__init__(default_value, **metadata)
786 super(Instance, self).__init__(default_value, **metadata)
785
787
786 def validate(self, obj, value):
788 def validate(self, obj, value):
787 if value is None:
789 if value is None:
788 if self._allow_none:
790 if self._allow_none:
789 return value
791 return value
790 self.error(obj, value)
792 self.error(obj, value)
791
793
792 if isinstance(value, self.klass):
794 if isinstance(value, self.klass):
793 return value
795 return value
794 else:
796 else:
795 self.error(obj, value)
797 self.error(obj, value)
796
798
797 def info(self):
799 def info(self):
798 if isinstance(self.klass, basestring):
800 if isinstance(self.klass, basestring):
799 klass = self.klass
801 klass = self.klass
800 else:
802 else:
801 klass = self.klass.__name__
803 klass = self.klass.__name__
802 result = class_of(klass)
804 result = class_of(klass)
803 if self._allow_none:
805 if self._allow_none:
804 return result + ' or None'
806 return result + ' or None'
805
807
806 return result
808 return result
807
809
808 def instance_init(self, obj):
810 def instance_init(self, obj):
809 self._resolve_classes()
811 self._resolve_classes()
810 super(Instance, self).instance_init(obj)
812 super(Instance, self).instance_init(obj)
811
813
812 def _resolve_classes(self):
814 def _resolve_classes(self):
813 if isinstance(self.klass, basestring):
815 if isinstance(self.klass, basestring):
814 self.klass = import_item(self.klass)
816 self.klass = import_item(self.klass)
815
817
816 def get_default_value(self):
818 def get_default_value(self):
817 """Instantiate a default value instance.
819 """Instantiate a default value instance.
818
820
819 This is called when the containing HasTraits classes'
821 This is called when the containing HasTraits classes'
820 :meth:`__new__` method is called to ensure that a unique instance
822 :meth:`__new__` method is called to ensure that a unique instance
821 is created for each HasTraits instance.
823 is created for each HasTraits instance.
822 """
824 """
823 dv = self.default_value
825 dv = self.default_value
824 if isinstance(dv, DefaultValueGenerator):
826 if isinstance(dv, DefaultValueGenerator):
825 return dv.generate(self.klass)
827 return dv.generate(self.klass)
826 else:
828 else:
827 return dv
829 return dv
828
830
829
831
830 class This(ClassBasedTraitType):
832 class This(ClassBasedTraitType):
831 """A trait for instances of the class containing this trait.
833 """A trait for instances of the class containing this trait.
832
834
833 Because how how and when class bodies are executed, the ``This``
835 Because how how and when class bodies are executed, the ``This``
834 trait can only have a default value of None. This, and because we
836 trait can only have a default value of None. This, and because we
835 always validate default values, ``allow_none`` is *always* true.
837 always validate default values, ``allow_none`` is *always* true.
836 """
838 """
837
839
838 info_text = 'an instance of the same type as the receiver or None'
840 info_text = 'an instance of the same type as the receiver or None'
839
841
840 def __init__(self, **metadata):
842 def __init__(self, **metadata):
841 super(This, self).__init__(None, **metadata)
843 super(This, self).__init__(None, **metadata)
842
844
843 def validate(self, obj, value):
845 def validate(self, obj, value):
844 # What if value is a superclass of obj.__class__? This is
846 # What if value is a superclass of obj.__class__? This is
845 # complicated if it was the superclass that defined the This
847 # complicated if it was the superclass that defined the This
846 # trait.
848 # trait.
847 if isinstance(value, self.this_class) or (value is None):
849 if isinstance(value, self.this_class) or (value is None):
848 return value
850 return value
849 else:
851 else:
850 self.error(obj, value)
852 self.error(obj, value)
851
853
852
854
853 #-----------------------------------------------------------------------------
855 #-----------------------------------------------------------------------------
854 # Basic TraitTypes implementations/subclasses
856 # Basic TraitTypes implementations/subclasses
855 #-----------------------------------------------------------------------------
857 #-----------------------------------------------------------------------------
856
858
857
859
858 class Any(TraitType):
860 class Any(TraitType):
859 default_value = None
861 default_value = None
860 info_text = 'any value'
862 info_text = 'any value'
861
863
862
864
863 class Int(TraitType):
865 class Int(TraitType):
864 """A integer trait."""
866 """A integer trait."""
865
867
866 default_value = 0
868 default_value = 0
867 info_text = 'an integer'
869 info_text = 'an integer'
868
870
869 def validate(self, obj, value):
871 def validate(self, obj, value):
870 if isinstance(value, int):
872 if isinstance(value, int):
871 return value
873 return value
872 self.error(obj, value)
874 self.error(obj, value)
873
875
874 class CInt(Int):
876 class CInt(Int):
875 """A casting version of the int trait."""
877 """A casting version of the int trait."""
876
878
877 def validate(self, obj, value):
879 def validate(self, obj, value):
878 try:
880 try:
879 return int(value)
881 return int(value)
880 except:
882 except:
881 self.error(obj, value)
883 self.error(obj, value)
882
884
885 if not py3compat.PY3:
886 class Long(TraitType):
887 """A long integer trait."""
883
888
884 class Long(TraitType):
889 default_value = 0L
885 """A long integer trait."""
890 info_text = 'a long'
886
891
887 default_value = 0L
892 def validate(self, obj, value):
888 info_text = 'a long'
893 if isinstance(value, long):
889
894 return value
890 def validate(self, obj, value):
895 if isinstance(value, int):
891 if isinstance(value, long):
896 return long(value)
892 return value
897 self.error(obj, value)
893 if isinstance(value, int):
894 return long(value)
895 self.error(obj, value)
896
898
897
899
898 class CLong(Long):
900 class CLong(Long):
899 """A casting version of the long integer trait."""
901 """A casting version of the long integer trait."""
900
902
901 def validate(self, obj, value):
903 def validate(self, obj, value):
902 try:
904 try:
903 return long(value)
905 return long(value)
904 except:
906 except:
905 self.error(obj, value)
907 self.error(obj, value)
906
908
907
909
908 class Float(TraitType):
910 class Float(TraitType):
909 """A float trait."""
911 """A float trait."""
910
912
911 default_value = 0.0
913 default_value = 0.0
912 info_text = 'a float'
914 info_text = 'a float'
913
915
914 def validate(self, obj, value):
916 def validate(self, obj, value):
915 if isinstance(value, float):
917 if isinstance(value, float):
916 return value
918 return value
917 if isinstance(value, int):
919 if isinstance(value, int):
918 return float(value)
920 return float(value)
919 self.error(obj, value)
921 self.error(obj, value)
920
922
921
923
922 class CFloat(Float):
924 class CFloat(Float):
923 """A casting version of the float trait."""
925 """A casting version of the float trait."""
924
926
925 def validate(self, obj, value):
927 def validate(self, obj, value):
926 try:
928 try:
927 return float(value)
929 return float(value)
928 except:
930 except:
929 self.error(obj, value)
931 self.error(obj, value)
930
932
931 class Complex(TraitType):
933 class Complex(TraitType):
932 """A trait for complex numbers."""
934 """A trait for complex numbers."""
933
935
934 default_value = 0.0 + 0.0j
936 default_value = 0.0 + 0.0j
935 info_text = 'a complex number'
937 info_text = 'a complex number'
936
938
937 def validate(self, obj, value):
939 def validate(self, obj, value):
938 if isinstance(value, complex):
940 if isinstance(value, complex):
939 return value
941 return value
940 if isinstance(value, (float, int)):
942 if isinstance(value, (float, int)):
941 return complex(value)
943 return complex(value)
942 self.error(obj, value)
944 self.error(obj, value)
943
945
944
946
945 class CComplex(Complex):
947 class CComplex(Complex):
946 """A casting version of the complex number trait."""
948 """A casting version of the complex number trait."""
947
949
948 def validate (self, obj, value):
950 def validate (self, obj, value):
949 try:
951 try:
950 return complex(value)
952 return complex(value)
951 except:
953 except:
952 self.error(obj, value)
954 self.error(obj, value)
953
955
954 # We should always be explicit about whether we're using bytes or unicode, both
956 # We should always be explicit about whether we're using bytes or unicode, both
955 # for Python 3 conversion and for reliable unicode behaviour on Python 2. So
957 # for Python 3 conversion and for reliable unicode behaviour on Python 2. So
956 # we don't have a Str type.
958 # we don't have a Str type.
957 class Bytes(TraitType):
959 class Bytes(TraitType):
958 """A trait for strings."""
960 """A trait for byte strings."""
959
961
960 default_value = ''
962 default_value = ''
961 info_text = 'a string'
963 info_text = 'a string'
962
964
963 def validate(self, obj, value):
965 def validate(self, obj, value):
964 if isinstance(value, bytes):
966 if isinstance(value, bytes):
965 return value
967 return value
966 self.error(obj, value)
968 self.error(obj, value)
967
969
968
970
969 class CBytes(Bytes):
971 class CBytes(Bytes):
970 """A casting version of the string trait."""
972 """A casting version of the byte string trait."""
971
973
972 def validate(self, obj, value):
974 def validate(self, obj, value):
973 try:
975 try:
974 return bytes(value)
976 return bytes(value)
975 except:
977 except:
976 self.error(obj, value)
978 self.error(obj, value)
977
979
978
980
979 class Unicode(TraitType):
981 class Unicode(TraitType):
980 """A trait for unicode strings."""
982 """A trait for unicode strings."""
981
983
982 default_value = u''
984 default_value = u''
983 info_text = 'a unicode string'
985 info_text = 'a unicode string'
984
986
985 def validate(self, obj, value):
987 def validate(self, obj, value):
986 if isinstance(value, unicode):
988 if isinstance(value, unicode):
987 return value
989 return value
988 if isinstance(value, bytes):
990 if isinstance(value, bytes):
989 return unicode(value)
991 return unicode(value)
990 self.error(obj, value)
992 self.error(obj, value)
991
993
992
994
993 class CUnicode(Unicode):
995 class CUnicode(Unicode):
994 """A casting version of the unicode trait."""
996 """A casting version of the unicode trait."""
995
997
996 def validate(self, obj, value):
998 def validate(self, obj, value):
997 try:
999 try:
998 return unicode(value)
1000 return unicode(value)
999 except:
1001 except:
1000 self.error(obj, value)
1002 self.error(obj, value)
1001
1003
1002
1004
1003 class ObjectName(TraitType):
1005 class ObjectName(TraitType):
1004 """A string holding a valid object name in this version of Python.
1006 """A string holding a valid object name in this version of Python.
1005
1007
1006 This does not check that the name exists in any scope."""
1008 This does not check that the name exists in any scope."""
1007 info_text = "a valid object identifier in Python"
1009 info_text = "a valid object identifier in Python"
1008
1010
1009 if sys.version_info[0] < 3:
1011 if sys.version_info[0] < 3:
1010 # Python 2:
1012 # Python 2:
1011 _name_re = re.compile(r"[a-zA-Z_][a-zA-Z0-9_]*$")
1013 _name_re = re.compile(r"[a-zA-Z_][a-zA-Z0-9_]*$")
1012 def isidentifier(self, s):
1014 def isidentifier(self, s):
1013 return bool(self._name_re.match(s))
1015 return bool(self._name_re.match(s))
1014
1016
1015 def coerce_str(self, obj, value):
1017 def coerce_str(self, obj, value):
1016 "In Python 2, coerce ascii-only unicode to str"
1018 "In Python 2, coerce ascii-only unicode to str"
1017 if isinstance(value, unicode):
1019 if isinstance(value, unicode):
1018 try:
1020 try:
1019 return str(value)
1021 return str(value)
1020 except UnicodeEncodeError:
1022 except UnicodeEncodeError:
1021 self.error(obj, value)
1023 self.error(obj, value)
1022 return value
1024 return value
1023
1025
1024 else:
1026 else:
1025 # Python 3:
1027 # Python 3:
1026 isidentifier = staticmethod(lambda s: s.isidentifier())
1028 isidentifier = staticmethod(lambda s: s.isidentifier())
1027 coerce_str = staticmethod(lambda _,s: s)
1029 coerce_str = staticmethod(lambda _,s: s)
1028
1030
1029 def validate(self, obj, value):
1031 def validate(self, obj, value):
1030 value = self.coerce_str(obj, value)
1032 value = self.coerce_str(obj, value)
1031
1033
1032 if isinstance(value, str) and self.isidentifier(value):
1034 if isinstance(value, str) and self.isidentifier(value):
1033 return value
1035 return value
1034 self.error(obj, value)
1036 self.error(obj, value)
1035
1037
1036 class DottedObjectName(ObjectName):
1038 class DottedObjectName(ObjectName):
1037 """A string holding a valid dotted object name in Python, such as A.b3._c"""
1039 """A string holding a valid dotted object name in Python, such as A.b3._c"""
1038 def validate(self, obj, value):
1040 def validate(self, obj, value):
1039 value = self.coerce_str(obj, value)
1041 value = self.coerce_str(obj, value)
1040
1042
1041 if isinstance(value, str) and all(self.isidentifier(x) \
1043 if isinstance(value, str) and all(self.isidentifier(x) \
1042 for x in value.split('.')):
1044 for x in value.split('.')):
1043 return value
1045 return value
1044 self.error(obj, value)
1046 self.error(obj, value)
1045
1047
1046
1048
1047 class Bool(TraitType):
1049 class Bool(TraitType):
1048 """A boolean (True, False) trait."""
1050 """A boolean (True, False) trait."""
1049
1051
1050 default_value = False
1052 default_value = False
1051 info_text = 'a boolean'
1053 info_text = 'a boolean'
1052
1054
1053 def validate(self, obj, value):
1055 def validate(self, obj, value):
1054 if isinstance(value, bool):
1056 if isinstance(value, bool):
1055 return value
1057 return value
1056 self.error(obj, value)
1058 self.error(obj, value)
1057
1059
1058
1060
1059 class CBool(Bool):
1061 class CBool(Bool):
1060 """A casting version of the boolean trait."""
1062 """A casting version of the boolean trait."""
1061
1063
1062 def validate(self, obj, value):
1064 def validate(self, obj, value):
1063 try:
1065 try:
1064 return bool(value)
1066 return bool(value)
1065 except:
1067 except:
1066 self.error(obj, value)
1068 self.error(obj, value)
1067
1069
1068
1070
1069 class Enum(TraitType):
1071 class Enum(TraitType):
1070 """An enum that whose value must be in a given sequence."""
1072 """An enum that whose value must be in a given sequence."""
1071
1073
1072 def __init__(self, values, default_value=None, allow_none=True, **metadata):
1074 def __init__(self, values, default_value=None, allow_none=True, **metadata):
1073 self.values = values
1075 self.values = values
1074 self._allow_none = allow_none
1076 self._allow_none = allow_none
1075 super(Enum, self).__init__(default_value, **metadata)
1077 super(Enum, self).__init__(default_value, **metadata)
1076
1078
1077 def validate(self, obj, value):
1079 def validate(self, obj, value):
1078 if value is None:
1080 if value is None:
1079 if self._allow_none:
1081 if self._allow_none:
1080 return value
1082 return value
1081
1083
1082 if value in self.values:
1084 if value in self.values:
1083 return value
1085 return value
1084 self.error(obj, value)
1086 self.error(obj, value)
1085
1087
1086 def info(self):
1088 def info(self):
1087 """ Returns a description of the trait."""
1089 """ Returns a description of the trait."""
1088 result = 'any of ' + repr(self.values)
1090 result = 'any of ' + repr(self.values)
1089 if self._allow_none:
1091 if self._allow_none:
1090 return result + ' or None'
1092 return result + ' or None'
1091 return result
1093 return result
1092
1094
1093 class CaselessStrEnum(Enum):
1095 class CaselessStrEnum(Enum):
1094 """An enum of strings that are caseless in validate."""
1096 """An enum of strings that are caseless in validate."""
1095
1097
1096 def validate(self, obj, value):
1098 def validate(self, obj, value):
1097 if value is None:
1099 if value is None:
1098 if self._allow_none:
1100 if self._allow_none:
1099 return value
1101 return value
1100
1102
1101 if not isinstance(value, basestring):
1103 if not isinstance(value, basestring):
1102 self.error(obj, value)
1104 self.error(obj, value)
1103
1105
1104 for v in self.values:
1106 for v in self.values:
1105 if v.lower() == value.lower():
1107 if v.lower() == value.lower():
1106 return v
1108 return v
1107 self.error(obj, value)
1109 self.error(obj, value)
1108
1110
1109 class Container(Instance):
1111 class Container(Instance):
1110 """An instance of a container (list, set, etc.)
1112 """An instance of a container (list, set, etc.)
1111
1113
1112 To be subclassed by overriding klass.
1114 To be subclassed by overriding klass.
1113 """
1115 """
1114 klass = None
1116 klass = None
1115 _valid_defaults = SequenceTypes
1117 _valid_defaults = SequenceTypes
1116 _trait = None
1118 _trait = None
1117
1119
1118 def __init__(self, trait=None, default_value=None, allow_none=True,
1120 def __init__(self, trait=None, default_value=None, allow_none=True,
1119 **metadata):
1121 **metadata):
1120 """Create a container trait type from a list, set, or tuple.
1122 """Create a container trait type from a list, set, or tuple.
1121
1123
1122 The default value is created by doing ``List(default_value)``,
1124 The default value is created by doing ``List(default_value)``,
1123 which creates a copy of the ``default_value``.
1125 which creates a copy of the ``default_value``.
1124
1126
1125 ``trait`` can be specified, which restricts the type of elements
1127 ``trait`` can be specified, which restricts the type of elements
1126 in the container to that TraitType.
1128 in the container to that TraitType.
1127
1129
1128 If only one arg is given and it is not a Trait, it is taken as
1130 If only one arg is given and it is not a Trait, it is taken as
1129 ``default_value``:
1131 ``default_value``:
1130
1132
1131 ``c = List([1,2,3])``
1133 ``c = List([1,2,3])``
1132
1134
1133 Parameters
1135 Parameters
1134 ----------
1136 ----------
1135
1137
1136 trait : TraitType [ optional ]
1138 trait : TraitType [ optional ]
1137 the type for restricting the contents of the Container. If unspecified,
1139 the type for restricting the contents of the Container. If unspecified,
1138 types are not checked.
1140 types are not checked.
1139
1141
1140 default_value : SequenceType [ optional ]
1142 default_value : SequenceType [ optional ]
1141 The default value for the Trait. Must be list/tuple/set, and
1143 The default value for the Trait. Must be list/tuple/set, and
1142 will be cast to the container type.
1144 will be cast to the container type.
1143
1145
1144 allow_none : Bool [ default True ]
1146 allow_none : Bool [ default True ]
1145 Whether to allow the value to be None
1147 Whether to allow the value to be None
1146
1148
1147 **metadata : any
1149 **metadata : any
1148 further keys for extensions to the Trait (e.g. config)
1150 further keys for extensions to the Trait (e.g. config)
1149
1151
1150 """
1152 """
1151 istrait = lambda t: isinstance(t, type) and issubclass(t, TraitType)
1153 istrait = lambda t: isinstance(t, type) and issubclass(t, TraitType)
1152
1154
1153 # allow List([values]):
1155 # allow List([values]):
1154 if default_value is None and not istrait(trait):
1156 if default_value is None and not istrait(trait):
1155 default_value = trait
1157 default_value = trait
1156 trait = None
1158 trait = None
1157
1159
1158 if default_value is None:
1160 if default_value is None:
1159 args = ()
1161 args = ()
1160 elif isinstance(default_value, self._valid_defaults):
1162 elif isinstance(default_value, self._valid_defaults):
1161 args = (default_value,)
1163 args = (default_value,)
1162 else:
1164 else:
1163 raise TypeError('default value of %s was %s' %(self.__class__.__name__, default_value))
1165 raise TypeError('default value of %s was %s' %(self.__class__.__name__, default_value))
1164
1166
1165 if istrait(trait):
1167 if istrait(trait):
1166 self._trait = trait()
1168 self._trait = trait()
1167 self._trait.name = 'element'
1169 self._trait.name = 'element'
1168 elif trait is not None:
1170 elif trait is not None:
1169 raise TypeError("`trait` must be a Trait or None, got %s"%repr_type(trait))
1171 raise TypeError("`trait` must be a Trait or None, got %s"%repr_type(trait))
1170
1172
1171 super(Container,self).__init__(klass=self.klass, args=args,
1173 super(Container,self).__init__(klass=self.klass, args=args,
1172 allow_none=allow_none, **metadata)
1174 allow_none=allow_none, **metadata)
1173
1175
1174 def element_error(self, obj, element, validator):
1176 def element_error(self, obj, element, validator):
1175 e = "Element of the '%s' trait of %s instance must be %s, but a value of %s was specified." \
1177 e = "Element of the '%s' trait of %s instance must be %s, but a value of %s was specified." \
1176 % (self.name, class_of(obj), validator.info(), repr_type(element))
1178 % (self.name, class_of(obj), validator.info(), repr_type(element))
1177 raise TraitError(e)
1179 raise TraitError(e)
1178
1180
1179 def validate(self, obj, value):
1181 def validate(self, obj, value):
1180 value = super(Container, self).validate(obj, value)
1182 value = super(Container, self).validate(obj, value)
1181 if value is None:
1183 if value is None:
1182 return value
1184 return value
1183
1185
1184 value = self.validate_elements(obj, value)
1186 value = self.validate_elements(obj, value)
1185
1187
1186 return value
1188 return value
1187
1189
1188 def validate_elements(self, obj, value):
1190 def validate_elements(self, obj, value):
1189 validated = []
1191 validated = []
1190 if self._trait is None or isinstance(self._trait, Any):
1192 if self._trait is None or isinstance(self._trait, Any):
1191 return value
1193 return value
1192 for v in value:
1194 for v in value:
1193 try:
1195 try:
1194 v = self._trait.validate(obj, v)
1196 v = self._trait.validate(obj, v)
1195 except TraitError:
1197 except TraitError:
1196 self.element_error(obj, v, self._trait)
1198 self.element_error(obj, v, self._trait)
1197 else:
1199 else:
1198 validated.append(v)
1200 validated.append(v)
1199 return self.klass(validated)
1201 return self.klass(validated)
1200
1202
1201
1203
1202 class List(Container):
1204 class List(Container):
1203 """An instance of a Python list."""
1205 """An instance of a Python list."""
1204 klass = list
1206 klass = list
1205
1207
1206 def __init__(self, trait=None, default_value=None, minlen=0, maxlen=sys.maxint,
1208 def __init__(self, trait=None, default_value=None, minlen=0, maxlen=sys.maxint,
1207 allow_none=True, **metadata):
1209 allow_none=True, **metadata):
1208 """Create a List trait type from a list, set, or tuple.
1210 """Create a List trait type from a list, set, or tuple.
1209
1211
1210 The default value is created by doing ``List(default_value)``,
1212 The default value is created by doing ``List(default_value)``,
1211 which creates a copy of the ``default_value``.
1213 which creates a copy of the ``default_value``.
1212
1214
1213 ``trait`` can be specified, which restricts the type of elements
1215 ``trait`` can be specified, which restricts the type of elements
1214 in the container to that TraitType.
1216 in the container to that TraitType.
1215
1217
1216 If only one arg is given and it is not a Trait, it is taken as
1218 If only one arg is given and it is not a Trait, it is taken as
1217 ``default_value``:
1219 ``default_value``:
1218
1220
1219 ``c = List([1,2,3])``
1221 ``c = List([1,2,3])``
1220
1222
1221 Parameters
1223 Parameters
1222 ----------
1224 ----------
1223
1225
1224 trait : TraitType [ optional ]
1226 trait : TraitType [ optional ]
1225 the type for restricting the contents of the Container. If unspecified,
1227 the type for restricting the contents of the Container. If unspecified,
1226 types are not checked.
1228 types are not checked.
1227
1229
1228 default_value : SequenceType [ optional ]
1230 default_value : SequenceType [ optional ]
1229 The default value for the Trait. Must be list/tuple/set, and
1231 The default value for the Trait. Must be list/tuple/set, and
1230 will be cast to the container type.
1232 will be cast to the container type.
1231
1233
1232 minlen : Int [ default 0 ]
1234 minlen : Int [ default 0 ]
1233 The minimum length of the input list
1235 The minimum length of the input list
1234
1236
1235 maxlen : Int [ default sys.maxint ]
1237 maxlen : Int [ default sys.maxint ]
1236 The maximum length of the input list
1238 The maximum length of the input list
1237
1239
1238 allow_none : Bool [ default True ]
1240 allow_none : Bool [ default True ]
1239 Whether to allow the value to be None
1241 Whether to allow the value to be None
1240
1242
1241 **metadata : any
1243 **metadata : any
1242 further keys for extensions to the Trait (e.g. config)
1244 further keys for extensions to the Trait (e.g. config)
1243
1245
1244 """
1246 """
1245 self._minlen = minlen
1247 self._minlen = minlen
1246 self._maxlen = maxlen
1248 self._maxlen = maxlen
1247 super(List, self).__init__(trait=trait, default_value=default_value,
1249 super(List, self).__init__(trait=trait, default_value=default_value,
1248 allow_none=allow_none, **metadata)
1250 allow_none=allow_none, **metadata)
1249
1251
1250 def length_error(self, obj, value):
1252 def length_error(self, obj, value):
1251 e = "The '%s' trait of %s instance must be of length %i <= L <= %i, but a value of %s was specified." \
1253 e = "The '%s' trait of %s instance must be of length %i <= L <= %i, but a value of %s was specified." \
1252 % (self.name, class_of(obj), self._minlen, self._maxlen, value)
1254 % (self.name, class_of(obj), self._minlen, self._maxlen, value)
1253 raise TraitError(e)
1255 raise TraitError(e)
1254
1256
1255 def validate_elements(self, obj, value):
1257 def validate_elements(self, obj, value):
1256 length = len(value)
1258 length = len(value)
1257 if length < self._minlen or length > self._maxlen:
1259 if length < self._minlen or length > self._maxlen:
1258 self.length_error(obj, value)
1260 self.length_error(obj, value)
1259
1261
1260 return super(List, self).validate_elements(obj, value)
1262 return super(List, self).validate_elements(obj, value)
1261
1263
1262
1264
1263 class Set(Container):
1265 class Set(Container):
1264 """An instance of a Python set."""
1266 """An instance of a Python set."""
1265 klass = set
1267 klass = set
1266
1268
1267 class Tuple(Container):
1269 class Tuple(Container):
1268 """An instance of a Python tuple."""
1270 """An instance of a Python tuple."""
1269 klass = tuple
1271 klass = tuple
1270
1272
1271 def __init__(self, *traits, **metadata):
1273 def __init__(self, *traits, **metadata):
1272 """Tuple(*traits, default_value=None, allow_none=True, **medatata)
1274 """Tuple(*traits, default_value=None, allow_none=True, **medatata)
1273
1275
1274 Create a tuple from a list, set, or tuple.
1276 Create a tuple from a list, set, or tuple.
1275
1277
1276 Create a fixed-type tuple with Traits:
1278 Create a fixed-type tuple with Traits:
1277
1279
1278 ``t = Tuple(Int, Str, CStr)``
1280 ``t = Tuple(Int, Str, CStr)``
1279
1281
1280 would be length 3, with Int,Str,CStr for each element.
1282 would be length 3, with Int,Str,CStr for each element.
1281
1283
1282 If only one arg is given and it is not a Trait, it is taken as
1284 If only one arg is given and it is not a Trait, it is taken as
1283 default_value:
1285 default_value:
1284
1286
1285 ``t = Tuple((1,2,3))``
1287 ``t = Tuple((1,2,3))``
1286
1288
1287 Otherwise, ``default_value`` *must* be specified by keyword.
1289 Otherwise, ``default_value`` *must* be specified by keyword.
1288
1290
1289 Parameters
1291 Parameters
1290 ----------
1292 ----------
1291
1293
1292 *traits : TraitTypes [ optional ]
1294 *traits : TraitTypes [ optional ]
1293 the tsype for restricting the contents of the Tuple. If unspecified,
1295 the tsype for restricting the contents of the Tuple. If unspecified,
1294 types are not checked. If specified, then each positional argument
1296 types are not checked. If specified, then each positional argument
1295 corresponds to an element of the tuple. Tuples defined with traits
1297 corresponds to an element of the tuple. Tuples defined with traits
1296 are of fixed length.
1298 are of fixed length.
1297
1299
1298 default_value : SequenceType [ optional ]
1300 default_value : SequenceType [ optional ]
1299 The default value for the Tuple. Must be list/tuple/set, and
1301 The default value for the Tuple. Must be list/tuple/set, and
1300 will be cast to a tuple. If `traits` are specified, the
1302 will be cast to a tuple. If `traits` are specified, the
1301 `default_value` must conform to the shape and type they specify.
1303 `default_value` must conform to the shape and type they specify.
1302
1304
1303 allow_none : Bool [ default True ]
1305 allow_none : Bool [ default True ]
1304 Whether to allow the value to be None
1306 Whether to allow the value to be None
1305
1307
1306 **metadata : any
1308 **metadata : any
1307 further keys for extensions to the Trait (e.g. config)
1309 further keys for extensions to the Trait (e.g. config)
1308
1310
1309 """
1311 """
1310 default_value = metadata.pop('default_value', None)
1312 default_value = metadata.pop('default_value', None)
1311 allow_none = metadata.pop('allow_none', True)
1313 allow_none = metadata.pop('allow_none', True)
1312
1314
1313 istrait = lambda t: isinstance(t, type) and issubclass(t, TraitType)
1315 istrait = lambda t: isinstance(t, type) and issubclass(t, TraitType)
1314
1316
1315 # allow Tuple((values,)):
1317 # allow Tuple((values,)):
1316 if len(traits) == 1 and default_value is None and not istrait(traits[0]):
1318 if len(traits) == 1 and default_value is None and not istrait(traits[0]):
1317 default_value = traits[0]
1319 default_value = traits[0]
1318 traits = ()
1320 traits = ()
1319
1321
1320 if default_value is None:
1322 if default_value is None:
1321 args = ()
1323 args = ()
1322 elif isinstance(default_value, self._valid_defaults):
1324 elif isinstance(default_value, self._valid_defaults):
1323 args = (default_value,)
1325 args = (default_value,)
1324 else:
1326 else:
1325 raise TypeError('default value of %s was %s' %(self.__class__.__name__, default_value))
1327 raise TypeError('default value of %s was %s' %(self.__class__.__name__, default_value))
1326
1328
1327 self._traits = []
1329 self._traits = []
1328 for trait in traits:
1330 for trait in traits:
1329 t = trait()
1331 t = trait()
1330 t.name = 'element'
1332 t.name = 'element'
1331 self._traits.append(t)
1333 self._traits.append(t)
1332
1334
1333 if self._traits and default_value is None:
1335 if self._traits and default_value is None:
1334 # don't allow default to be an empty container if length is specified
1336 # don't allow default to be an empty container if length is specified
1335 args = None
1337 args = None
1336 super(Container,self).__init__(klass=self.klass, args=args,
1338 super(Container,self).__init__(klass=self.klass, args=args,
1337 allow_none=allow_none, **metadata)
1339 allow_none=allow_none, **metadata)
1338
1340
1339 def validate_elements(self, obj, value):
1341 def validate_elements(self, obj, value):
1340 if not self._traits:
1342 if not self._traits:
1341 # nothing to validate
1343 # nothing to validate
1342 return value
1344 return value
1343 if len(value) != len(self._traits):
1345 if len(value) != len(self._traits):
1344 e = "The '%s' trait of %s instance requires %i elements, but a value of %s was specified." \
1346 e = "The '%s' trait of %s instance requires %i elements, but a value of %s was specified." \
1345 % (self.name, class_of(obj), len(self._traits), repr_type(value))
1347 % (self.name, class_of(obj), len(self._traits), repr_type(value))
1346 raise TraitError(e)
1348 raise TraitError(e)
1347
1349
1348 validated = []
1350 validated = []
1349 for t,v in zip(self._traits, value):
1351 for t,v in zip(self._traits, value):
1350 try:
1352 try:
1351 v = t.validate(obj, v)
1353 v = t.validate(obj, v)
1352 except TraitError:
1354 except TraitError:
1353 self.element_error(obj, v, t)
1355 self.element_error(obj, v, t)
1354 else:
1356 else:
1355 validated.append(v)
1357 validated.append(v)
1356 return tuple(validated)
1358 return tuple(validated)
1357
1359
1358
1360
1359 class Dict(Instance):
1361 class Dict(Instance):
1360 """An instance of a Python dict."""
1362 """An instance of a Python dict."""
1361
1363
1362 def __init__(self, default_value=None, allow_none=True, **metadata):
1364 def __init__(self, default_value=None, allow_none=True, **metadata):
1363 """Create a dict trait type from a dict.
1365 """Create a dict trait type from a dict.
1364
1366
1365 The default value is created by doing ``dict(default_value)``,
1367 The default value is created by doing ``dict(default_value)``,
1366 which creates a copy of the ``default_value``.
1368 which creates a copy of the ``default_value``.
1367 """
1369 """
1368 if default_value is None:
1370 if default_value is None:
1369 args = ((),)
1371 args = ((),)
1370 elif isinstance(default_value, dict):
1372 elif isinstance(default_value, dict):
1371 args = (default_value,)
1373 args = (default_value,)
1372 elif isinstance(default_value, SequenceTypes):
1374 elif isinstance(default_value, SequenceTypes):
1373 args = (default_value,)
1375 args = (default_value,)
1374 else:
1376 else:
1375 raise TypeError('default value of Dict was %s' % default_value)
1377 raise TypeError('default value of Dict was %s' % default_value)
1376
1378
1377 super(Dict,self).__init__(klass=dict, args=args,
1379 super(Dict,self).__init__(klass=dict, args=args,
1378 allow_none=allow_none, **metadata)
1380 allow_none=allow_none, **metadata)
1379
1381
1380 class TCPAddress(TraitType):
1382 class TCPAddress(TraitType):
1381 """A trait for an (ip, port) tuple.
1383 """A trait for an (ip, port) tuple.
1382
1384
1383 This allows for both IPv4 IP addresses as well as hostnames.
1385 This allows for both IPv4 IP addresses as well as hostnames.
1384 """
1386 """
1385
1387
1386 default_value = ('127.0.0.1', 0)
1388 default_value = ('127.0.0.1', 0)
1387 info_text = 'an (ip, port) tuple'
1389 info_text = 'an (ip, port) tuple'
1388
1390
1389 def validate(self, obj, value):
1391 def validate(self, obj, value):
1390 if isinstance(value, tuple):
1392 if isinstance(value, tuple):
1391 if len(value) == 2:
1393 if len(value) == 2:
1392 if isinstance(value[0], basestring) and isinstance(value[1], int):
1394 if isinstance(value[0], basestring) and isinstance(value[1], int):
1393 port = value[1]
1395 port = value[1]
1394 if port >= 0 and port <= 65535:
1396 if port >= 0 and port <= 65535:
1395 return value
1397 return value
1396 self.error(obj, value)
1398 self.error(obj, value)
@@ -1,703 +1,703 b''
1 """Session object for building, serializing, sending, and receiving messages in
1 """Session object for building, serializing, sending, and receiving messages in
2 IPython. The Session object supports serialization, HMAC signatures, and
2 IPython. The Session object supports serialization, HMAC signatures, and
3 metadata on messages.
3 metadata on messages.
4
4
5 Also defined here are utilities for working with Sessions:
5 Also defined here are utilities for working with Sessions:
6 * A SessionFactory to be used as a base class for configurables that work with
6 * A SessionFactory to be used as a base class for configurables that work with
7 Sessions.
7 Sessions.
8 * A Message object for convenience that allows attribute-access to the msg dict.
8 * A Message object for convenience that allows attribute-access to the msg dict.
9
9
10 Authors:
10 Authors:
11
11
12 * Min RK
12 * Min RK
13 * Brian Granger
13 * Brian Granger
14 * Fernando Perez
14 * Fernando Perez
15 """
15 """
16 #-----------------------------------------------------------------------------
16 #-----------------------------------------------------------------------------
17 # Copyright (C) 2010-2011 The IPython Development Team
17 # Copyright (C) 2010-2011 The IPython Development Team
18 #
18 #
19 # Distributed under the terms of the BSD License. The full license is in
19 # Distributed under the terms of the BSD License. The full license is in
20 # the file COPYING, distributed as part of this software.
20 # the file COPYING, distributed as part of this software.
21 #-----------------------------------------------------------------------------
21 #-----------------------------------------------------------------------------
22
22
23 #-----------------------------------------------------------------------------
23 #-----------------------------------------------------------------------------
24 # Imports
24 # Imports
25 #-----------------------------------------------------------------------------
25 #-----------------------------------------------------------------------------
26
26
27 import hmac
27 import hmac
28 import logging
28 import logging
29 import os
29 import os
30 import pprint
30 import pprint
31 import uuid
31 import uuid
32 from datetime import datetime
32 from datetime import datetime
33
33
34 try:
34 try:
35 import cPickle
35 import cPickle
36 pickle = cPickle
36 pickle = cPickle
37 except:
37 except:
38 cPickle = None
38 cPickle = None
39 import pickle
39 import pickle
40
40
41 import zmq
41 import zmq
42 from zmq.utils import jsonapi
42 from zmq.utils import jsonapi
43 from zmq.eventloop.ioloop import IOLoop
43 from zmq.eventloop.ioloop import IOLoop
44 from zmq.eventloop.zmqstream import ZMQStream
44 from zmq.eventloop.zmqstream import ZMQStream
45
45
46 from IPython.config.configurable import Configurable, LoggingConfigurable
46 from IPython.config.configurable import Configurable, LoggingConfigurable
47 from IPython.utils.importstring import import_item
47 from IPython.utils.importstring import import_item
48 from IPython.utils.jsonutil import extract_dates, squash_dates, date_default
48 from IPython.utils.jsonutil import extract_dates, squash_dates, date_default
49 from IPython.utils.traitlets import (CBytes, Unicode, Bool, Any, Instance, Set,
49 from IPython.utils.traitlets import (CBytes, Unicode, Bool, Any, Instance, Set,
50 DottedObjectName)
50 DottedObjectName)
51
51
52 #-----------------------------------------------------------------------------
52 #-----------------------------------------------------------------------------
53 # utility functions
53 # utility functions
54 #-----------------------------------------------------------------------------
54 #-----------------------------------------------------------------------------
55
55
56 def squash_unicode(obj):
56 def squash_unicode(obj):
57 """coerce unicode back to bytestrings."""
57 """coerce unicode back to bytestrings."""
58 if isinstance(obj,dict):
58 if isinstance(obj,dict):
59 for key in obj.keys():
59 for key in obj.keys():
60 obj[key] = squash_unicode(obj[key])
60 obj[key] = squash_unicode(obj[key])
61 if isinstance(key, unicode):
61 if isinstance(key, unicode):
62 obj[squash_unicode(key)] = obj.pop(key)
62 obj[squash_unicode(key)] = obj.pop(key)
63 elif isinstance(obj, list):
63 elif isinstance(obj, list):
64 for i,v in enumerate(obj):
64 for i,v in enumerate(obj):
65 obj[i] = squash_unicode(v)
65 obj[i] = squash_unicode(v)
66 elif isinstance(obj, unicode):
66 elif isinstance(obj, unicode):
67 obj = obj.encode('utf8')
67 obj = obj.encode('utf8')
68 return obj
68 return obj
69
69
70 #-----------------------------------------------------------------------------
70 #-----------------------------------------------------------------------------
71 # globals and defaults
71 # globals and defaults
72 #-----------------------------------------------------------------------------
72 #-----------------------------------------------------------------------------
73 key = 'on_unknown' if jsonapi.jsonmod.__name__ == 'jsonlib' else 'default'
73 key = 'on_unknown' if jsonapi.jsonmod.__name__ == 'jsonlib' else 'default'
74 json_packer = lambda obj: jsonapi.dumps(obj, **{key:date_default})
74 json_packer = lambda obj: jsonapi.dumps(obj, **{key:date_default})
75 json_unpacker = lambda s: extract_dates(jsonapi.loads(s))
75 json_unpacker = lambda s: extract_dates(jsonapi.loads(s))
76
76
77 pickle_packer = lambda o: pickle.dumps(o,-1)
77 pickle_packer = lambda o: pickle.dumps(o,-1)
78 pickle_unpacker = pickle.loads
78 pickle_unpacker = pickle.loads
79
79
80 default_packer = json_packer
80 default_packer = json_packer
81 default_unpacker = json_unpacker
81 default_unpacker = json_unpacker
82
82
83
83
84 DELIM=b"<IDS|MSG>"
84 DELIM=b"<IDS|MSG>"
85
85
86 #-----------------------------------------------------------------------------
86 #-----------------------------------------------------------------------------
87 # Classes
87 # Classes
88 #-----------------------------------------------------------------------------
88 #-----------------------------------------------------------------------------
89
89
90 class SessionFactory(LoggingConfigurable):
90 class SessionFactory(LoggingConfigurable):
91 """The Base class for configurables that have a Session, Context, logger,
91 """The Base class for configurables that have a Session, Context, logger,
92 and IOLoop.
92 and IOLoop.
93 """
93 """
94
94
95 logname = Unicode('')
95 logname = Unicode('')
96 def _logname_changed(self, name, old, new):
96 def _logname_changed(self, name, old, new):
97 self.log = logging.getLogger(new)
97 self.log = logging.getLogger(new)
98
98
99 # not configurable:
99 # not configurable:
100 context = Instance('zmq.Context')
100 context = Instance('zmq.Context')
101 def _context_default(self):
101 def _context_default(self):
102 return zmq.Context.instance()
102 return zmq.Context.instance()
103
103
104 session = Instance('IPython.zmq.session.Session')
104 session = Instance('IPython.zmq.session.Session')
105
105
106 loop = Instance('zmq.eventloop.ioloop.IOLoop', allow_none=False)
106 loop = Instance('zmq.eventloop.ioloop.IOLoop', allow_none=False)
107 def _loop_default(self):
107 def _loop_default(self):
108 return IOLoop.instance()
108 return IOLoop.instance()
109
109
110 def __init__(self, **kwargs):
110 def __init__(self, **kwargs):
111 super(SessionFactory, self).__init__(**kwargs)
111 super(SessionFactory, self).__init__(**kwargs)
112
112
113 if self.session is None:
113 if self.session is None:
114 # construct the session
114 # construct the session
115 self.session = Session(**kwargs)
115 self.session = Session(**kwargs)
116
116
117
117
118 class Message(object):
118 class Message(object):
119 """A simple message object that maps dict keys to attributes.
119 """A simple message object that maps dict keys to attributes.
120
120
121 A Message can be created from a dict and a dict from a Message instance
121 A Message can be created from a dict and a dict from a Message instance
122 simply by calling dict(msg_obj)."""
122 simply by calling dict(msg_obj)."""
123
123
124 def __init__(self, msg_dict):
124 def __init__(self, msg_dict):
125 dct = self.__dict__
125 dct = self.__dict__
126 for k, v in dict(msg_dict).iteritems():
126 for k, v in dict(msg_dict).iteritems():
127 if isinstance(v, dict):
127 if isinstance(v, dict):
128 v = Message(v)
128 v = Message(v)
129 dct[k] = v
129 dct[k] = v
130
130
131 # Having this iterator lets dict(msg_obj) work out of the box.
131 # Having this iterator lets dict(msg_obj) work out of the box.
132 def __iter__(self):
132 def __iter__(self):
133 return iter(self.__dict__.iteritems())
133 return iter(self.__dict__.iteritems())
134
134
135 def __repr__(self):
135 def __repr__(self):
136 return repr(self.__dict__)
136 return repr(self.__dict__)
137
137
138 def __str__(self):
138 def __str__(self):
139 return pprint.pformat(self.__dict__)
139 return pprint.pformat(self.__dict__)
140
140
141 def __contains__(self, k):
141 def __contains__(self, k):
142 return k in self.__dict__
142 return k in self.__dict__
143
143
144 def __getitem__(self, k):
144 def __getitem__(self, k):
145 return self.__dict__[k]
145 return self.__dict__[k]
146
146
147
147
148 def msg_header(msg_id, msg_type, username, session):
148 def msg_header(msg_id, msg_type, username, session):
149 date = datetime.now()
149 date = datetime.now()
150 return locals()
150 return locals()
151
151
152 def extract_header(msg_or_header):
152 def extract_header(msg_or_header):
153 """Given a message or header, return the header."""
153 """Given a message or header, return the header."""
154 if not msg_or_header:
154 if not msg_or_header:
155 return {}
155 return {}
156 try:
156 try:
157 # See if msg_or_header is the entire message.
157 # See if msg_or_header is the entire message.
158 h = msg_or_header['header']
158 h = msg_or_header['header']
159 except KeyError:
159 except KeyError:
160 try:
160 try:
161 # See if msg_or_header is just the header
161 # See if msg_or_header is just the header
162 h = msg_or_header['msg_id']
162 h = msg_or_header['msg_id']
163 except KeyError:
163 except KeyError:
164 raise
164 raise
165 else:
165 else:
166 h = msg_or_header
166 h = msg_or_header
167 if not isinstance(h, dict):
167 if not isinstance(h, dict):
168 h = dict(h)
168 h = dict(h)
169 return h
169 return h
170
170
171 class Session(Configurable):
171 class Session(Configurable):
172 """Object for handling serialization and sending of messages.
172 """Object for handling serialization and sending of messages.
173
173
174 The Session object handles building messages and sending them
174 The Session object handles building messages and sending them
175 with ZMQ sockets or ZMQStream objects. Objects can communicate with each
175 with ZMQ sockets or ZMQStream objects. Objects can communicate with each
176 other over the network via Session objects, and only need to work with the
176 other over the network via Session objects, and only need to work with the
177 dict-based IPython message spec. The Session will handle
177 dict-based IPython message spec. The Session will handle
178 serialization/deserialization, security, and metadata.
178 serialization/deserialization, security, and metadata.
179
179
180 Sessions support configurable serialiization via packer/unpacker traits,
180 Sessions support configurable serialiization via packer/unpacker traits,
181 and signing with HMAC digests via the key/keyfile traits.
181 and signing with HMAC digests via the key/keyfile traits.
182
182
183 Parameters
183 Parameters
184 ----------
184 ----------
185
185
186 debug : bool
186 debug : bool
187 whether to trigger extra debugging statements
187 whether to trigger extra debugging statements
188 packer/unpacker : str : 'json', 'pickle' or import_string
188 packer/unpacker : str : 'json', 'pickle' or import_string
189 importstrings for methods to serialize message parts. If just
189 importstrings for methods to serialize message parts. If just
190 'json' or 'pickle', predefined JSON and pickle packers will be used.
190 'json' or 'pickle', predefined JSON and pickle packers will be used.
191 Otherwise, the entire importstring must be used.
191 Otherwise, the entire importstring must be used.
192
192
193 The functions must accept at least valid JSON input, and output *bytes*.
193 The functions must accept at least valid JSON input, and output *bytes*.
194
194
195 For example, to use msgpack:
195 For example, to use msgpack:
196 packer = 'msgpack.packb', unpacker='msgpack.unpackb'
196 packer = 'msgpack.packb', unpacker='msgpack.unpackb'
197 pack/unpack : callables
197 pack/unpack : callables
198 You can also set the pack/unpack callables for serialization directly.
198 You can also set the pack/unpack callables for serialization directly.
199 session : bytes
199 session : bytes
200 the ID of this Session object. The default is to generate a new UUID.
200 the ID of this Session object. The default is to generate a new UUID.
201 username : unicode
201 username : unicode
202 username added to message headers. The default is to ask the OS.
202 username added to message headers. The default is to ask the OS.
203 key : bytes
203 key : bytes
204 The key used to initialize an HMAC signature. If unset, messages
204 The key used to initialize an HMAC signature. If unset, messages
205 will not be signed or checked.
205 will not be signed or checked.
206 keyfile : filepath
206 keyfile : filepath
207 The file containing a key. If this is set, `key` will be initialized
207 The file containing a key. If this is set, `key` will be initialized
208 to the contents of the file.
208 to the contents of the file.
209
209
210 """
210 """
211
211
212 debug=Bool(False, config=True, help="""Debug output in the Session""")
212 debug=Bool(False, config=True, help="""Debug output in the Session""")
213
213
214 packer = DottedObjectName('json',config=True,
214 packer = DottedObjectName('json',config=True,
215 help="""The name of the packer for serializing messages.
215 help="""The name of the packer for serializing messages.
216 Should be one of 'json', 'pickle', or an import name
216 Should be one of 'json', 'pickle', or an import name
217 for a custom callable serializer.""")
217 for a custom callable serializer.""")
218 def _packer_changed(self, name, old, new):
218 def _packer_changed(self, name, old, new):
219 if new.lower() == 'json':
219 if new.lower() == 'json':
220 self.pack = json_packer
220 self.pack = json_packer
221 self.unpack = json_unpacker
221 self.unpack = json_unpacker
222 elif new.lower() == 'pickle':
222 elif new.lower() == 'pickle':
223 self.pack = pickle_packer
223 self.pack = pickle_packer
224 self.unpack = pickle_unpacker
224 self.unpack = pickle_unpacker
225 else:
225 else:
226 self.pack = import_item(str(new))
226 self.pack = import_item(str(new))
227
227
228 unpacker = DottedObjectName('json', config=True,
228 unpacker = DottedObjectName('json', config=True,
229 help="""The name of the unpacker for unserializing messages.
229 help="""The name of the unpacker for unserializing messages.
230 Only used with custom functions for `packer`.""")
230 Only used with custom functions for `packer`.""")
231 def _unpacker_changed(self, name, old, new):
231 def _unpacker_changed(self, name, old, new):
232 if new.lower() == 'json':
232 if new.lower() == 'json':
233 self.pack = json_packer
233 self.pack = json_packer
234 self.unpack = json_unpacker
234 self.unpack = json_unpacker
235 elif new.lower() == 'pickle':
235 elif new.lower() == 'pickle':
236 self.pack = pickle_packer
236 self.pack = pickle_packer
237 self.unpack = pickle_unpacker
237 self.unpack = pickle_unpacker
238 else:
238 else:
239 self.unpack = import_item(str(new))
239 self.unpack = import_item(str(new))
240
240
241 session = CBytes(b'', config=True,
241 session = CBytes(b'', config=True,
242 help="""The UUID identifying this session.""")
242 help="""The UUID identifying this session.""")
243 def _session_default(self):
243 def _session_default(self):
244 return bytes(uuid.uuid4())
244 return bytes(uuid.uuid4())
245
245
246 username = Unicode(os.environ.get('USER',u'username'), config=True,
246 username = Unicode(os.environ.get('USER',u'username'), config=True,
247 help="""Username for the Session. Default is your system username.""")
247 help="""Username for the Session. Default is your system username.""")
248
248
249 # message signature related traits:
249 # message signature related traits:
250 key = CBytes(b'', config=True,
250 key = CBytes(b'', config=True,
251 help="""execution key, for extra authentication.""")
251 help="""execution key, for extra authentication.""")
252 def _key_changed(self, name, old, new):
252 def _key_changed(self, name, old, new):
253 if new:
253 if new:
254 self.auth = hmac.HMAC(new)
254 self.auth = hmac.HMAC(new)
255 else:
255 else:
256 self.auth = None
256 self.auth = None
257 auth = Instance(hmac.HMAC)
257 auth = Instance(hmac.HMAC)
258 digest_history = Set()
258 digest_history = Set()
259
259
260 keyfile = Unicode('', config=True,
260 keyfile = Unicode('', config=True,
261 help="""path to file containing execution key.""")
261 help="""path to file containing execution key.""")
262 def _keyfile_changed(self, name, old, new):
262 def _keyfile_changed(self, name, old, new):
263 with open(new, 'rb') as f:
263 with open(new, 'rb') as f:
264 self.key = f.read().strip()
264 self.key = f.read().strip()
265
265
266 pack = Any(default_packer) # the actual packer function
266 pack = Any(default_packer) # the actual packer function
267 def _pack_changed(self, name, old, new):
267 def _pack_changed(self, name, old, new):
268 if not callable(new):
268 if not callable(new):
269 raise TypeError("packer must be callable, not %s"%type(new))
269 raise TypeError("packer must be callable, not %s"%type(new))
270
270
271 unpack = Any(default_unpacker) # the actual packer function
271 unpack = Any(default_unpacker) # the actual packer function
272 def _unpack_changed(self, name, old, new):
272 def _unpack_changed(self, name, old, new):
273 # unpacker is not checked - it is assumed to be
273 # unpacker is not checked - it is assumed to be
274 if not callable(new):
274 if not callable(new):
275 raise TypeError("unpacker must be callable, not %s"%type(new))
275 raise TypeError("unpacker must be callable, not %s"%type(new))
276
276
277 def __init__(self, **kwargs):
277 def __init__(self, **kwargs):
278 """create a Session object
278 """create a Session object
279
279
280 Parameters
280 Parameters
281 ----------
281 ----------
282
282
283 debug : bool
283 debug : bool
284 whether to trigger extra debugging statements
284 whether to trigger extra debugging statements
285 packer/unpacker : str : 'json', 'pickle' or import_string
285 packer/unpacker : str : 'json', 'pickle' or import_string
286 importstrings for methods to serialize message parts. If just
286 importstrings for methods to serialize message parts. If just
287 'json' or 'pickle', predefined JSON and pickle packers will be used.
287 'json' or 'pickle', predefined JSON and pickle packers will be used.
288 Otherwise, the entire importstring must be used.
288 Otherwise, the entire importstring must be used.
289
289
290 The functions must accept at least valid JSON input, and output
290 The functions must accept at least valid JSON input, and output
291 *bytes*.
291 *bytes*.
292
292
293 For example, to use msgpack:
293 For example, to use msgpack:
294 packer = 'msgpack.packb', unpacker='msgpack.unpackb'
294 packer = 'msgpack.packb', unpacker='msgpack.unpackb'
295 pack/unpack : callables
295 pack/unpack : callables
296 You can also set the pack/unpack callables for serialization
296 You can also set the pack/unpack callables for serialization
297 directly.
297 directly.
298 session : bytes
298 session : bytes
299 the ID of this Session object. The default is to generate a new
299 the ID of this Session object. The default is to generate a new
300 UUID.
300 UUID.
301 username : unicode
301 username : unicode
302 username added to message headers. The default is to ask the OS.
302 username added to message headers. The default is to ask the OS.
303 key : bytes
303 key : bytes
304 The key used to initialize an HMAC signature. If unset, messages
304 The key used to initialize an HMAC signature. If unset, messages
305 will not be signed or checked.
305 will not be signed or checked.
306 keyfile : filepath
306 keyfile : filepath
307 The file containing a key. If this is set, `key` will be
307 The file containing a key. If this is set, `key` will be
308 initialized to the contents of the file.
308 initialized to the contents of the file.
309 """
309 """
310 super(Session, self).__init__(**kwargs)
310 super(Session, self).__init__(**kwargs)
311 self._check_packers()
311 self._check_packers()
312 self.none = self.pack({})
312 self.none = self.pack({})
313
313
314 @property
314 @property
315 def msg_id(self):
315 def msg_id(self):
316 """always return new uuid"""
316 """always return new uuid"""
317 return str(uuid.uuid4())
317 return str(uuid.uuid4())
318
318
319 def _check_packers(self):
319 def _check_packers(self):
320 """check packers for binary data and datetime support."""
320 """check packers for binary data and datetime support."""
321 pack = self.pack
321 pack = self.pack
322 unpack = self.unpack
322 unpack = self.unpack
323
323
324 # check simple serialization
324 # check simple serialization
325 msg = dict(a=[1,'hi'])
325 msg = dict(a=[1,'hi'])
326 try:
326 try:
327 packed = pack(msg)
327 packed = pack(msg)
328 except Exception:
328 except Exception:
329 raise ValueError("packer could not serialize a simple message")
329 raise ValueError("packer could not serialize a simple message")
330
330
331 # ensure packed message is bytes
331 # ensure packed message is bytes
332 if not isinstance(packed, bytes):
332 if not isinstance(packed, bytes):
333 raise ValueError("message packed to %r, but bytes are required"%type(packed))
333 raise ValueError("message packed to %r, but bytes are required"%type(packed))
334
334
335 # check that unpack is pack's inverse
335 # check that unpack is pack's inverse
336 try:
336 try:
337 unpacked = unpack(packed)
337 unpacked = unpack(packed)
338 except Exception:
338 except Exception:
339 raise ValueError("unpacker could not handle the packer's output")
339 raise ValueError("unpacker could not handle the packer's output")
340
340
341 # check datetime support
341 # check datetime support
342 msg = dict(t=datetime.now())
342 msg = dict(t=datetime.now())
343 try:
343 try:
344 unpacked = unpack(pack(msg))
344 unpacked = unpack(pack(msg))
345 except Exception:
345 except Exception:
346 self.pack = lambda o: pack(squash_dates(o))
346 self.pack = lambda o: pack(squash_dates(o))
347 self.unpack = lambda s: extract_dates(unpack(s))
347 self.unpack = lambda s: extract_dates(unpack(s))
348
348
349 def msg_header(self, msg_type):
349 def msg_header(self, msg_type):
350 return msg_header(self.msg_id, msg_type, self.username, self.session)
350 return msg_header(self.msg_id, msg_type, self.username, self.session)
351
351
352 def msg(self, msg_type, content=None, parent=None, subheader=None, header=None):
352 def msg(self, msg_type, content=None, parent=None, subheader=None, header=None):
353 """Return the nested message dict.
353 """Return the nested message dict.
354
354
355 This format is different from what is sent over the wire. The
355 This format is different from what is sent over the wire. The
356 serialize/unserialize methods converts this nested message dict to the wire
356 serialize/unserialize methods converts this nested message dict to the wire
357 format, which is a list of message parts.
357 format, which is a list of message parts.
358 """
358 """
359 msg = {}
359 msg = {}
360 header = self.msg_header(msg_type) if header is None else header
360 header = self.msg_header(msg_type) if header is None else header
361 msg['header'] = header
361 msg['header'] = header
362 msg['msg_id'] = header['msg_id']
362 msg['msg_id'] = header['msg_id']
363 msg['msg_type'] = header['msg_type']
363 msg['msg_type'] = header['msg_type']
364 msg['parent_header'] = {} if parent is None else extract_header(parent)
364 msg['parent_header'] = {} if parent is None else extract_header(parent)
365 msg['content'] = {} if content is None else content
365 msg['content'] = {} if content is None else content
366 sub = {} if subheader is None else subheader
366 sub = {} if subheader is None else subheader
367 msg['header'].update(sub)
367 msg['header'].update(sub)
368 return msg
368 return msg
369
369
370 def sign(self, msg_list):
370 def sign(self, msg_list):
371 """Sign a message with HMAC digest. If no auth, return b''.
371 """Sign a message with HMAC digest. If no auth, return b''.
372
372
373 Parameters
373 Parameters
374 ----------
374 ----------
375 msg_list : list
375 msg_list : list
376 The [p_header,p_parent,p_content] part of the message list.
376 The [p_header,p_parent,p_content] part of the message list.
377 """
377 """
378 if self.auth is None:
378 if self.auth is None:
379 return b''
379 return b''
380 h = self.auth.copy()
380 h = self.auth.copy()
381 for m in msg_list:
381 for m in msg_list:
382 h.update(m)
382 h.update(m)
383 return h.hexdigest()
383 return str_to_bytes(h.hexdigest())
384
384
385 def serialize(self, msg, ident=None):
385 def serialize(self, msg, ident=None):
386 """Serialize the message components to bytes.
386 """Serialize the message components to bytes.
387
387
388 This is roughly the inverse of unserialize. The serialize/unserialize
388 This is roughly the inverse of unserialize. The serialize/unserialize
389 methods work with full message lists, whereas pack/unpack work with
389 methods work with full message lists, whereas pack/unpack work with
390 the individual message parts in the message list.
390 the individual message parts in the message list.
391
391
392 Parameters
392 Parameters
393 ----------
393 ----------
394 msg : dict or Message
394 msg : dict or Message
395 The nexted message dict as returned by the self.msg method.
395 The nexted message dict as returned by the self.msg method.
396
396
397 Returns
397 Returns
398 -------
398 -------
399 msg_list : list
399 msg_list : list
400 The list of bytes objects to be sent with the format:
400 The list of bytes objects to be sent with the format:
401 [ident1,ident2,...,DELIM,HMAC,p_header,p_parent,p_content,
401 [ident1,ident2,...,DELIM,HMAC,p_header,p_parent,p_content,
402 buffer1,buffer2,...]. In this list, the p_* entities are
402 buffer1,buffer2,...]. In this list, the p_* entities are
403 the packed or serialized versions, so if JSON is used, these
403 the packed or serialized versions, so if JSON is used, these
404 are uft8 encoded JSON strings.
404 are uft8 encoded JSON strings.
405 """
405 """
406 content = msg.get('content', {})
406 content = msg.get('content', {})
407 if content is None:
407 if content is None:
408 content = self.none
408 content = self.none
409 elif isinstance(content, dict):
409 elif isinstance(content, dict):
410 content = self.pack(content)
410 content = self.pack(content)
411 elif isinstance(content, bytes):
411 elif isinstance(content, bytes):
412 # content is already packed, as in a relayed message
412 # content is already packed, as in a relayed message
413 pass
413 pass
414 elif isinstance(content, unicode):
414 elif isinstance(content, unicode):
415 # should be bytes, but JSON often spits out unicode
415 # should be bytes, but JSON often spits out unicode
416 content = content.encode('utf8')
416 content = content.encode('utf8')
417 else:
417 else:
418 raise TypeError("Content incorrect type: %s"%type(content))
418 raise TypeError("Content incorrect type: %s"%type(content))
419
419
420 real_message = [self.pack(msg['header']),
420 real_message = [self.pack(msg['header']),
421 self.pack(msg['parent_header']),
421 self.pack(msg['parent_header']),
422 content
422 content
423 ]
423 ]
424
424
425 to_send = []
425 to_send = []
426
426
427 if isinstance(ident, list):
427 if isinstance(ident, list):
428 # accept list of idents
428 # accept list of idents
429 to_send.extend(ident)
429 to_send.extend(ident)
430 elif ident is not None:
430 elif ident is not None:
431 to_send.append(ident)
431 to_send.append(ident)
432 to_send.append(DELIM)
432 to_send.append(DELIM)
433
433
434 signature = self.sign(real_message)
434 signature = self.sign(real_message)
435 to_send.append(signature)
435 to_send.append(signature)
436
436
437 to_send.extend(real_message)
437 to_send.extend(real_message)
438
438
439 return to_send
439 return to_send
440
440
441 def send(self, stream, msg_or_type, content=None, parent=None, ident=None,
441 def send(self, stream, msg_or_type, content=None, parent=None, ident=None,
442 buffers=None, subheader=None, track=False, header=None):
442 buffers=None, subheader=None, track=False, header=None):
443 """Build and send a message via stream or socket.
443 """Build and send a message via stream or socket.
444
444
445 The message format used by this function internally is as follows:
445 The message format used by this function internally is as follows:
446
446
447 [ident1,ident2,...,DELIM,HMAC,p_header,p_parent,p_content,
447 [ident1,ident2,...,DELIM,HMAC,p_header,p_parent,p_content,
448 buffer1,buffer2,...]
448 buffer1,buffer2,...]
449
449
450 The serialize/unserialize methods convert the nested message dict into this
450 The serialize/unserialize methods convert the nested message dict into this
451 format.
451 format.
452
452
453 Parameters
453 Parameters
454 ----------
454 ----------
455
455
456 stream : zmq.Socket or ZMQStream
456 stream : zmq.Socket or ZMQStream
457 The socket-like object used to send the data.
457 The socket-like object used to send the data.
458 msg_or_type : str or Message/dict
458 msg_or_type : str or Message/dict
459 Normally, msg_or_type will be a msg_type unless a message is being
459 Normally, msg_or_type will be a msg_type unless a message is being
460 sent more than once. If a header is supplied, this can be set to
460 sent more than once. If a header is supplied, this can be set to
461 None and the msg_type will be pulled from the header.
461 None and the msg_type will be pulled from the header.
462
462
463 content : dict or None
463 content : dict or None
464 The content of the message (ignored if msg_or_type is a message).
464 The content of the message (ignored if msg_or_type is a message).
465 header : dict or None
465 header : dict or None
466 The header dict for the message (ignores if msg_to_type is a message).
466 The header dict for the message (ignores if msg_to_type is a message).
467 parent : Message or dict or None
467 parent : Message or dict or None
468 The parent or parent header describing the parent of this message
468 The parent or parent header describing the parent of this message
469 (ignored if msg_or_type is a message).
469 (ignored if msg_or_type is a message).
470 ident : bytes or list of bytes
470 ident : bytes or list of bytes
471 The zmq.IDENTITY routing path.
471 The zmq.IDENTITY routing path.
472 subheader : dict or None
472 subheader : dict or None
473 Extra header keys for this message's header (ignored if msg_or_type
473 Extra header keys for this message's header (ignored if msg_or_type
474 is a message).
474 is a message).
475 buffers : list or None
475 buffers : list or None
476 The already-serialized buffers to be appended to the message.
476 The already-serialized buffers to be appended to the message.
477 track : bool
477 track : bool
478 Whether to track. Only for use with Sockets, because ZMQStream
478 Whether to track. Only for use with Sockets, because ZMQStream
479 objects cannot track messages.
479 objects cannot track messages.
480
480
481 Returns
481 Returns
482 -------
482 -------
483 msg : dict
483 msg : dict
484 The constructed message.
484 The constructed message.
485 (msg,tracker) : (dict, MessageTracker)
485 (msg,tracker) : (dict, MessageTracker)
486 if track=True, then a 2-tuple will be returned,
486 if track=True, then a 2-tuple will be returned,
487 the first element being the constructed
487 the first element being the constructed
488 message, and the second being the MessageTracker
488 message, and the second being the MessageTracker
489
489
490 """
490 """
491
491
492 if not isinstance(stream, (zmq.Socket, ZMQStream)):
492 if not isinstance(stream, (zmq.Socket, ZMQStream)):
493 raise TypeError("stream must be Socket or ZMQStream, not %r"%type(stream))
493 raise TypeError("stream must be Socket or ZMQStream, not %r"%type(stream))
494 elif track and isinstance(stream, ZMQStream):
494 elif track and isinstance(stream, ZMQStream):
495 raise TypeError("ZMQStream cannot track messages")
495 raise TypeError("ZMQStream cannot track messages")
496
496
497 if isinstance(msg_or_type, (Message, dict)):
497 if isinstance(msg_or_type, (Message, dict)):
498 # We got a Message or message dict, not a msg_type so don't
498 # We got a Message or message dict, not a msg_type so don't
499 # build a new Message.
499 # build a new Message.
500 msg = msg_or_type
500 msg = msg_or_type
501 else:
501 else:
502 msg = self.msg(msg_or_type, content=content, parent=parent,
502 msg = self.msg(msg_or_type, content=content, parent=parent,
503 subheader=subheader, header=header)
503 subheader=subheader, header=header)
504
504
505 buffers = [] if buffers is None else buffers
505 buffers = [] if buffers is None else buffers
506 to_send = self.serialize(msg, ident)
506 to_send = self.serialize(msg, ident)
507 flag = 0
507 flag = 0
508 if buffers:
508 if buffers:
509 flag = zmq.SNDMORE
509 flag = zmq.SNDMORE
510 _track = False
510 _track = False
511 else:
511 else:
512 _track=track
512 _track=track
513 if track:
513 if track:
514 tracker = stream.send_multipart(to_send, flag, copy=False, track=_track)
514 tracker = stream.send_multipart(to_send, flag, copy=False, track=_track)
515 else:
515 else:
516 tracker = stream.send_multipart(to_send, flag, copy=False)
516 tracker = stream.send_multipart(to_send, flag, copy=False)
517 for b in buffers[:-1]:
517 for b in buffers[:-1]:
518 stream.send(b, flag, copy=False)
518 stream.send(b, flag, copy=False)
519 if buffers:
519 if buffers:
520 if track:
520 if track:
521 tracker = stream.send(buffers[-1], copy=False, track=track)
521 tracker = stream.send(buffers[-1], copy=False, track=track)
522 else:
522 else:
523 tracker = stream.send(buffers[-1], copy=False)
523 tracker = stream.send(buffers[-1], copy=False)
524
524
525 # omsg = Message(msg)
525 # omsg = Message(msg)
526 if self.debug:
526 if self.debug:
527 pprint.pprint(msg)
527 pprint.pprint(msg)
528 pprint.pprint(to_send)
528 pprint.pprint(to_send)
529 pprint.pprint(buffers)
529 pprint.pprint(buffers)
530
530
531 msg['tracker'] = tracker
531 msg['tracker'] = tracker
532
532
533 return msg
533 return msg
534
534
535 def send_raw(self, stream, msg_list, flags=0, copy=True, ident=None):
535 def send_raw(self, stream, msg_list, flags=0, copy=True, ident=None):
536 """Send a raw message via ident path.
536 """Send a raw message via ident path.
537
537
538 This method is used to send a already serialized message.
538 This method is used to send a already serialized message.
539
539
540 Parameters
540 Parameters
541 ----------
541 ----------
542 stream : ZMQStream or Socket
542 stream : ZMQStream or Socket
543 The ZMQ stream or socket to use for sending the message.
543 The ZMQ stream or socket to use for sending the message.
544 msg_list : list
544 msg_list : list
545 The serialized list of messages to send. This only includes the
545 The serialized list of messages to send. This only includes the
546 [p_header,p_parent,p_content,buffer1,buffer2,...] portion of
546 [p_header,p_parent,p_content,buffer1,buffer2,...] portion of
547 the message.
547 the message.
548 ident : ident or list
548 ident : ident or list
549 A single ident or a list of idents to use in sending.
549 A single ident or a list of idents to use in sending.
550 """
550 """
551 to_send = []
551 to_send = []
552 if isinstance(ident, bytes):
552 if isinstance(ident, bytes):
553 ident = [ident]
553 ident = [ident]
554 if ident is not None:
554 if ident is not None:
555 to_send.extend(ident)
555 to_send.extend(ident)
556
556
557 to_send.append(DELIM)
557 to_send.append(DELIM)
558 to_send.append(self.sign(msg_list))
558 to_send.append(self.sign(msg_list))
559 to_send.extend(msg_list)
559 to_send.extend(msg_list)
560 stream.send_multipart(msg_list, flags, copy=copy)
560 stream.send_multipart(msg_list, flags, copy=copy)
561
561
562 def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True):
562 def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True):
563 """Receive and unpack a message.
563 """Receive and unpack a message.
564
564
565 Parameters
565 Parameters
566 ----------
566 ----------
567 socket : ZMQStream or Socket
567 socket : ZMQStream or Socket
568 The socket or stream to use in receiving.
568 The socket or stream to use in receiving.
569
569
570 Returns
570 Returns
571 -------
571 -------
572 [idents], msg
572 [idents], msg
573 [idents] is a list of idents and msg is a nested message dict of
573 [idents] is a list of idents and msg is a nested message dict of
574 same format as self.msg returns.
574 same format as self.msg returns.
575 """
575 """
576 if isinstance(socket, ZMQStream):
576 if isinstance(socket, ZMQStream):
577 socket = socket.socket
577 socket = socket.socket
578 try:
578 try:
579 msg_list = socket.recv_multipart(mode)
579 msg_list = socket.recv_multipart(mode)
580 except zmq.ZMQError as e:
580 except zmq.ZMQError as e:
581 if e.errno == zmq.EAGAIN:
581 if e.errno == zmq.EAGAIN:
582 # We can convert EAGAIN to None as we know in this case
582 # We can convert EAGAIN to None as we know in this case
583 # recv_multipart won't return None.
583 # recv_multipart won't return None.
584 return None,None
584 return None,None
585 else:
585 else:
586 raise
586 raise
587 # split multipart message into identity list and message dict
587 # split multipart message into identity list and message dict
588 # invalid large messages can cause very expensive string comparisons
588 # invalid large messages can cause very expensive string comparisons
589 idents, msg_list = self.feed_identities(msg_list, copy)
589 idents, msg_list = self.feed_identities(msg_list, copy)
590 try:
590 try:
591 return idents, self.unserialize(msg_list, content=content, copy=copy)
591 return idents, self.unserialize(msg_list, content=content, copy=copy)
592 except Exception as e:
592 except Exception as e:
593 # TODO: handle it
593 # TODO: handle it
594 raise e
594 raise e
595
595
596 def feed_identities(self, msg_list, copy=True):
596 def feed_identities(self, msg_list, copy=True):
597 """Split the identities from the rest of the message.
597 """Split the identities from the rest of the message.
598
598
599 Feed until DELIM is reached, then return the prefix as idents and
599 Feed until DELIM is reached, then return the prefix as idents and
600 remainder as msg_list. This is easily broken by setting an IDENT to DELIM,
600 remainder as msg_list. This is easily broken by setting an IDENT to DELIM,
601 but that would be silly.
601 but that would be silly.
602
602
603 Parameters
603 Parameters
604 ----------
604 ----------
605 msg_list : a list of Message or bytes objects
605 msg_list : a list of Message or bytes objects
606 The message to be split.
606 The message to be split.
607 copy : bool
607 copy : bool
608 flag determining whether the arguments are bytes or Messages
608 flag determining whether the arguments are bytes or Messages
609
609
610 Returns
610 Returns
611 -------
611 -------
612 (idents, msg_list) : two lists
612 (idents, msg_list) : two lists
613 idents will always be a list of bytes, each of which is a ZMQ
613 idents will always be a list of bytes, each of which is a ZMQ
614 identity. msg_list will be a list of bytes or zmq.Messages of the
614 identity. msg_list will be a list of bytes or zmq.Messages of the
615 form [HMAC,p_header,p_parent,p_content,buffer1,buffer2,...] and
615 form [HMAC,p_header,p_parent,p_content,buffer1,buffer2,...] and
616 should be unpackable/unserializable via self.unserialize at this
616 should be unpackable/unserializable via self.unserialize at this
617 point.
617 point.
618 """
618 """
619 if copy:
619 if copy:
620 idx = msg_list.index(DELIM)
620 idx = msg_list.index(DELIM)
621 return msg_list[:idx], msg_list[idx+1:]
621 return msg_list[:idx], msg_list[idx+1:]
622 else:
622 else:
623 failed = True
623 failed = True
624 for idx,m in enumerate(msg_list):
624 for idx,m in enumerate(msg_list):
625 if m.bytes == DELIM:
625 if m.bytes == DELIM:
626 failed = False
626 failed = False
627 break
627 break
628 if failed:
628 if failed:
629 raise ValueError("DELIM not in msg_list")
629 raise ValueError("DELIM not in msg_list")
630 idents, msg_list = msg_list[:idx], msg_list[idx+1:]
630 idents, msg_list = msg_list[:idx], msg_list[idx+1:]
631 return [m.bytes for m in idents], msg_list
631 return [m.bytes for m in idents], msg_list
632
632
633 def unserialize(self, msg_list, content=True, copy=True):
633 def unserialize(self, msg_list, content=True, copy=True):
634 """Unserialize a msg_list to a nested message dict.
634 """Unserialize a msg_list to a nested message dict.
635
635
636 This is roughly the inverse of serialize. The serialize/unserialize
636 This is roughly the inverse of serialize. The serialize/unserialize
637 methods work with full message lists, whereas pack/unpack work with
637 methods work with full message lists, whereas pack/unpack work with
638 the individual message parts in the message list.
638 the individual message parts in the message list.
639
639
640 Parameters:
640 Parameters:
641 -----------
641 -----------
642 msg_list : list of bytes or Message objects
642 msg_list : list of bytes or Message objects
643 The list of message parts of the form [HMAC,p_header,p_parent,
643 The list of message parts of the form [HMAC,p_header,p_parent,
644 p_content,buffer1,buffer2,...].
644 p_content,buffer1,buffer2,...].
645 content : bool (True)
645 content : bool (True)
646 Whether to unpack the content dict (True), or leave it packed
646 Whether to unpack the content dict (True), or leave it packed
647 (False).
647 (False).
648 copy : bool (True)
648 copy : bool (True)
649 Whether to return the bytes (True), or the non-copying Message
649 Whether to return the bytes (True), or the non-copying Message
650 object in each place (False).
650 object in each place (False).
651
651
652 Returns
652 Returns
653 -------
653 -------
654 msg : dict
654 msg : dict
655 The nested message dict with top-level keys [header, parent_header,
655 The nested message dict with top-level keys [header, parent_header,
656 content, buffers].
656 content, buffers].
657 """
657 """
658 minlen = 4
658 minlen = 4
659 message = {}
659 message = {}
660 if not copy:
660 if not copy:
661 for i in range(minlen):
661 for i in range(minlen):
662 msg_list[i] = msg_list[i].bytes
662 msg_list[i] = msg_list[i].bytes
663 if self.auth is not None:
663 if self.auth is not None:
664 signature = msg_list[0]
664 signature = msg_list[0]
665 if not signature:
665 if not signature:
666 raise ValueError("Unsigned Message")
666 raise ValueError("Unsigned Message")
667 if signature in self.digest_history:
667 if signature in self.digest_history:
668 raise ValueError("Duplicate Signature: %r"%signature)
668 raise ValueError("Duplicate Signature: %r"%signature)
669 self.digest_history.add(signature)
669 self.digest_history.add(signature)
670 check = self.sign(msg_list[1:4])
670 check = self.sign(msg_list[1:4])
671 if not signature == check:
671 if not signature == check:
672 raise ValueError("Invalid Signature: %r"%signature)
672 raise ValueError("Invalid Signature: %r"%signature)
673 if not len(msg_list) >= minlen:
673 if not len(msg_list) >= minlen:
674 raise TypeError("malformed message, must have at least %i elements"%minlen)
674 raise TypeError("malformed message, must have at least %i elements"%minlen)
675 header = self.unpack(msg_list[1])
675 header = self.unpack(msg_list[1])
676 message['header'] = header
676 message['header'] = header
677 message['msg_id'] = header['msg_id']
677 message['msg_id'] = header['msg_id']
678 message['msg_type'] = header['msg_type']
678 message['msg_type'] = header['msg_type']
679 message['parent_header'] = self.unpack(msg_list[2])
679 message['parent_header'] = self.unpack(msg_list[2])
680 if content:
680 if content:
681 message['content'] = self.unpack(msg_list[3])
681 message['content'] = self.unpack(msg_list[3])
682 else:
682 else:
683 message['content'] = msg_list[3]
683 message['content'] = msg_list[3]
684
684
685 message['buffers'] = msg_list[4:]
685 message['buffers'] = msg_list[4:]
686 return message
686 return message
687
687
688 def test_msg2obj():
688 def test_msg2obj():
689 am = dict(x=1)
689 am = dict(x=1)
690 ao = Message(am)
690 ao = Message(am)
691 assert ao.x == am['x']
691 assert ao.x == am['x']
692
692
693 am['y'] = dict(z=1)
693 am['y'] = dict(z=1)
694 ao = Message(am)
694 ao = Message(am)
695 assert ao.y.z == am['y']['z']
695 assert ao.y.z == am['y']['z']
696
696
697 k1, k2 = 'y', 'z'
697 k1, k2 = 'y', 'z'
698 assert ao[k1][k2] == am[k1][k2]
698 assert ao[k1][k2] == am[k1][k2]
699
699
700 am2 = dict(ao)
700 am2 = dict(ao)
701 assert am['x'] == am2['x']
701 assert am['x'] == am2['x']
702 assert am['y']['z'] == am2['y']['z']
702 assert am['y']['z'] == am2['y']['z']
703
703
General Comments 0
You need to be logged in to leave comments. Login now