##// END OF EJS Templates
More Python 3 compatibility fixes.
Thomas Kluyver -
Show More
@@ -1,715 +1,715 b''
1 1 # encoding: utf-8
2 2 """
3 3 Utilities for working with strings and text.
4 4 """
5 5
6 6 #-----------------------------------------------------------------------------
7 7 # Copyright (C) 2008-2009 The IPython Development Team
8 8 #
9 9 # Distributed under the terms of the BSD License. The full license is in
10 10 # the file COPYING, distributed as part of this software.
11 11 #-----------------------------------------------------------------------------
12 12
13 13 #-----------------------------------------------------------------------------
14 14 # Imports
15 15 #-----------------------------------------------------------------------------
16 16
17 17 import __main__
18 18
19 19 import os
20 20 import re
21 21 import shutil
22 22 import textwrap
23 23 from string import Formatter
24 24
25 25 from IPython.external.path import path
26
26 from IPython.utils import py3compat
27 27 from IPython.utils.io import nlprint
28 28 from IPython.utils.data import flatten
29 29
30 30 #-----------------------------------------------------------------------------
31 31 # Code
32 32 #-----------------------------------------------------------------------------
33 33
34 34
35 35 def unquote_ends(istr):
36 36 """Remove a single pair of quotes from the endpoints of a string."""
37 37
38 38 if not istr:
39 39 return istr
40 40 if (istr[0]=="'" and istr[-1]=="'") or \
41 41 (istr[0]=='"' and istr[-1]=='"'):
42 42 return istr[1:-1]
43 43 else:
44 44 return istr
45 45
46 46
47 47 class LSString(str):
48 48 """String derivative with a special access attributes.
49 49
50 50 These are normal strings, but with the special attributes:
51 51
52 52 .l (or .list) : value as list (split on newlines).
53 53 .n (or .nlstr): original value (the string itself).
54 54 .s (or .spstr): value as whitespace-separated string.
55 55 .p (or .paths): list of path objects
56 56
57 57 Any values which require transformations are computed only once and
58 58 cached.
59 59
60 60 Such strings are very useful to efficiently interact with the shell, which
61 61 typically only understands whitespace-separated options for commands."""
62 62
63 63 def get_list(self):
64 64 try:
65 65 return self.__list
66 66 except AttributeError:
67 67 self.__list = self.split('\n')
68 68 return self.__list
69 69
70 70 l = list = property(get_list)
71 71
72 72 def get_spstr(self):
73 73 try:
74 74 return self.__spstr
75 75 except AttributeError:
76 76 self.__spstr = self.replace('\n',' ')
77 77 return self.__spstr
78 78
79 79 s = spstr = property(get_spstr)
80 80
81 81 def get_nlstr(self):
82 82 return self
83 83
84 84 n = nlstr = property(get_nlstr)
85 85
86 86 def get_paths(self):
87 87 try:
88 88 return self.__paths
89 89 except AttributeError:
90 90 self.__paths = [path(p) for p in self.split('\n') if os.path.exists(p)]
91 91 return self.__paths
92 92
93 93 p = paths = property(get_paths)
94 94
95 95 # FIXME: We need to reimplement type specific displayhook and then add this
96 96 # back as a custom printer. This should also be moved outside utils into the
97 97 # core.
98 98
99 99 # def print_lsstring(arg):
100 100 # """ Prettier (non-repr-like) and more informative printer for LSString """
101 101 # print "LSString (.p, .n, .l, .s available). Value:"
102 102 # print arg
103 103 #
104 104 #
105 105 # print_lsstring = result_display.when_type(LSString)(print_lsstring)
106 106
107 107
108 108 class SList(list):
109 109 """List derivative with a special access attributes.
110 110
111 111 These are normal lists, but with the special attributes:
112 112
113 113 .l (or .list) : value as list (the list itself).
114 114 .n (or .nlstr): value as a string, joined on newlines.
115 115 .s (or .spstr): value as a string, joined on spaces.
116 116 .p (or .paths): list of path objects
117 117
118 118 Any values which require transformations are computed only once and
119 119 cached."""
120 120
121 121 def get_list(self):
122 122 return self
123 123
124 124 l = list = property(get_list)
125 125
126 126 def get_spstr(self):
127 127 try:
128 128 return self.__spstr
129 129 except AttributeError:
130 130 self.__spstr = ' '.join(self)
131 131 return self.__spstr
132 132
133 133 s = spstr = property(get_spstr)
134 134
135 135 def get_nlstr(self):
136 136 try:
137 137 return self.__nlstr
138 138 except AttributeError:
139 139 self.__nlstr = '\n'.join(self)
140 140 return self.__nlstr
141 141
142 142 n = nlstr = property(get_nlstr)
143 143
144 144 def get_paths(self):
145 145 try:
146 146 return self.__paths
147 147 except AttributeError:
148 148 self.__paths = [path(p) for p in self if os.path.exists(p)]
149 149 return self.__paths
150 150
151 151 p = paths = property(get_paths)
152 152
153 153 def grep(self, pattern, prune = False, field = None):
154 154 """ Return all strings matching 'pattern' (a regex or callable)
155 155
156 156 This is case-insensitive. If prune is true, return all items
157 157 NOT matching the pattern.
158 158
159 159 If field is specified, the match must occur in the specified
160 160 whitespace-separated field.
161 161
162 162 Examples::
163 163
164 164 a.grep( lambda x: x.startswith('C') )
165 165 a.grep('Cha.*log', prune=1)
166 166 a.grep('chm', field=-1)
167 167 """
168 168
169 169 def match_target(s):
170 170 if field is None:
171 171 return s
172 172 parts = s.split()
173 173 try:
174 174 tgt = parts[field]
175 175 return tgt
176 176 except IndexError:
177 177 return ""
178 178
179 179 if isinstance(pattern, basestring):
180 180 pred = lambda x : re.search(pattern, x, re.IGNORECASE)
181 181 else:
182 182 pred = pattern
183 183 if not prune:
184 184 return SList([el for el in self if pred(match_target(el))])
185 185 else:
186 186 return SList([el for el in self if not pred(match_target(el))])
187 187
188 188 def fields(self, *fields):
189 189 """ Collect whitespace-separated fields from string list
190 190
191 191 Allows quick awk-like usage of string lists.
192 192
193 193 Example data (in var a, created by 'a = !ls -l')::
194 194 -rwxrwxrwx 1 ville None 18 Dec 14 2006 ChangeLog
195 195 drwxrwxrwx+ 6 ville None 0 Oct 24 18:05 IPython
196 196
197 197 a.fields(0) is ['-rwxrwxrwx', 'drwxrwxrwx+']
198 198 a.fields(1,0) is ['1 -rwxrwxrwx', '6 drwxrwxrwx+']
199 199 (note the joining by space).
200 200 a.fields(-1) is ['ChangeLog', 'IPython']
201 201
202 202 IndexErrors are ignored.
203 203
204 204 Without args, fields() just split()'s the strings.
205 205 """
206 206 if len(fields) == 0:
207 207 return [el.split() for el in self]
208 208
209 209 res = SList()
210 210 for el in [f.split() for f in self]:
211 211 lineparts = []
212 212
213 213 for fd in fields:
214 214 try:
215 215 lineparts.append(el[fd])
216 216 except IndexError:
217 217 pass
218 218 if lineparts:
219 219 res.append(" ".join(lineparts))
220 220
221 221 return res
222 222
223 223 def sort(self,field= None, nums = False):
224 224 """ sort by specified fields (see fields())
225 225
226 226 Example::
227 227 a.sort(1, nums = True)
228 228
229 229 Sorts a by second field, in numerical order (so that 21 > 3)
230 230
231 231 """
232 232
233 233 #decorate, sort, undecorate
234 234 if field is not None:
235 235 dsu = [[SList([line]).fields(field), line] for line in self]
236 236 else:
237 237 dsu = [[line, line] for line in self]
238 238 if nums:
239 239 for i in range(len(dsu)):
240 240 numstr = "".join([ch for ch in dsu[i][0] if ch.isdigit()])
241 241 try:
242 242 n = int(numstr)
243 243 except ValueError:
244 244 n = 0;
245 245 dsu[i][0] = n
246 246
247 247
248 248 dsu.sort()
249 249 return SList([t[1] for t in dsu])
250 250
251 251
252 252 # FIXME: We need to reimplement type specific displayhook and then add this
253 253 # back as a custom printer. This should also be moved outside utils into the
254 254 # core.
255 255
256 256 # def print_slist(arg):
257 257 # """ Prettier (non-repr-like) and more informative printer for SList """
258 258 # print "SList (.p, .n, .l, .s, .grep(), .fields(), sort() available):"
259 259 # if hasattr(arg, 'hideonce') and arg.hideonce:
260 260 # arg.hideonce = False
261 261 # return
262 262 #
263 263 # nlprint(arg)
264 264 #
265 265 # print_slist = result_display.when_type(SList)(print_slist)
266 266
267 267
268 268 def esc_quotes(strng):
269 269 """Return the input string with single and double quotes escaped out"""
270 270
271 271 return strng.replace('"','\\"').replace("'","\\'")
272 272
273 273
274 274 def make_quoted_expr(s):
275 275 """Return string s in appropriate quotes, using raw string if possible.
276 276
277 277 XXX - example removed because it caused encoding errors in documentation
278 278 generation. We need a new example that doesn't contain invalid chars.
279 279
280 280 Note the use of raw string and padding at the end to allow trailing
281 281 backslash.
282 282 """
283 283
284 284 tail = ''
285 285 tailpadding = ''
286 286 raw = ''
287 ucode = 'u'
287 ucode = '' if py3compat.PY3 else 'u'
288 288 if "\\" in s:
289 289 raw = 'r'
290 290 if s.endswith('\\'):
291 291 tail = '[:-1]'
292 292 tailpadding = '_'
293 293 if '"' not in s:
294 294 quote = '"'
295 295 elif "'" not in s:
296 296 quote = "'"
297 297 elif '"""' not in s and not s.endswith('"'):
298 298 quote = '"""'
299 299 elif "'''" not in s and not s.endswith("'"):
300 300 quote = "'''"
301 301 else:
302 302 # give up, backslash-escaped string will do
303 303 return '"%s"' % esc_quotes(s)
304 304 res = ucode + raw + quote + s + tailpadding + quote + tail
305 305 return res
306 306
307 307
308 308 def qw(words,flat=0,sep=None,maxsplit=-1):
309 309 """Similar to Perl's qw() operator, but with some more options.
310 310
311 311 qw(words,flat=0,sep=' ',maxsplit=-1) -> words.split(sep,maxsplit)
312 312
313 313 words can also be a list itself, and with flat=1, the output will be
314 314 recursively flattened.
315 315
316 316 Examples:
317 317
318 318 >>> qw('1 2')
319 319 ['1', '2']
320 320
321 321 >>> qw(['a b','1 2',['m n','p q']])
322 322 [['a', 'b'], ['1', '2'], [['m', 'n'], ['p', 'q']]]
323 323
324 324 >>> qw(['a b','1 2',['m n','p q']],flat=1)
325 325 ['a', 'b', '1', '2', 'm', 'n', 'p', 'q']
326 326 """
327 327
328 328 if isinstance(words, basestring):
329 329 return [word.strip() for word in words.split(sep,maxsplit)
330 330 if word and not word.isspace() ]
331 331 if flat:
332 332 return flatten(map(qw,words,[1]*len(words)))
333 333 return map(qw,words)
334 334
335 335
336 336 def qwflat(words,sep=None,maxsplit=-1):
337 337 """Calls qw(words) in flat mode. It's just a convenient shorthand."""
338 338 return qw(words,1,sep,maxsplit)
339 339
340 340
341 341 def qw_lol(indata):
342 342 """qw_lol('a b') -> [['a','b']],
343 343 otherwise it's just a call to qw().
344 344
345 345 We need this to make sure the modules_some keys *always* end up as a
346 346 list of lists."""
347 347
348 348 if isinstance(indata, basestring):
349 349 return [qw(indata)]
350 350 else:
351 351 return qw(indata)
352 352
353 353
354 354 def grep(pat,list,case=1):
355 355 """Simple minded grep-like function.
356 356 grep(pat,list) returns occurrences of pat in list, None on failure.
357 357
358 358 It only does simple string matching, with no support for regexps. Use the
359 359 option case=0 for case-insensitive matching."""
360 360
361 361 # This is pretty crude. At least it should implement copying only references
362 362 # to the original data in case it's big. Now it copies the data for output.
363 363 out=[]
364 364 if case:
365 365 for term in list:
366 366 if term.find(pat)>-1: out.append(term)
367 367 else:
368 368 lpat=pat.lower()
369 369 for term in list:
370 370 if term.lower().find(lpat)>-1: out.append(term)
371 371
372 372 if len(out): return out
373 373 else: return None
374 374
375 375
376 376 def dgrep(pat,*opts):
377 377 """Return grep() on dir()+dir(__builtins__).
378 378
379 379 A very common use of grep() when working interactively."""
380 380
381 381 return grep(pat,dir(__main__)+dir(__main__.__builtins__),*opts)
382 382
383 383
384 384 def idgrep(pat):
385 385 """Case-insensitive dgrep()"""
386 386
387 387 return dgrep(pat,0)
388 388
389 389
390 390 def igrep(pat,list):
391 391 """Synonym for case-insensitive grep."""
392 392
393 393 return grep(pat,list,case=0)
394 394
395 395
396 396 def indent(instr,nspaces=4, ntabs=0, flatten=False):
397 397 """Indent a string a given number of spaces or tabstops.
398 398
399 399 indent(str,nspaces=4,ntabs=0) -> indent str by ntabs+nspaces.
400 400
401 401 Parameters
402 402 ----------
403 403
404 404 instr : basestring
405 405 The string to be indented.
406 406 nspaces : int (default: 4)
407 407 The number of spaces to be indented.
408 408 ntabs : int (default: 0)
409 409 The number of tabs to be indented.
410 410 flatten : bool (default: False)
411 411 Whether to scrub existing indentation. If True, all lines will be
412 412 aligned to the same indentation. If False, existing indentation will
413 413 be strictly increased.
414 414
415 415 Returns
416 416 -------
417 417
418 418 str|unicode : string indented by ntabs and nspaces.
419 419
420 420 """
421 421 if instr is None:
422 422 return
423 423 ind = '\t'*ntabs+' '*nspaces
424 424 if flatten:
425 425 pat = re.compile(r'^\s*', re.MULTILINE)
426 426 else:
427 427 pat = re.compile(r'^', re.MULTILINE)
428 428 outstr = re.sub(pat, ind, instr)
429 429 if outstr.endswith(os.linesep+ind):
430 430 return outstr[:-len(ind)]
431 431 else:
432 432 return outstr
433 433
434 434 def native_line_ends(filename,backup=1):
435 435 """Convert (in-place) a file to line-ends native to the current OS.
436 436
437 437 If the optional backup argument is given as false, no backup of the
438 438 original file is left. """
439 439
440 440 backup_suffixes = {'posix':'~','dos':'.bak','nt':'.bak','mac':'.bak'}
441 441
442 442 bak_filename = filename + backup_suffixes[os.name]
443 443
444 444 original = open(filename).read()
445 445 shutil.copy2(filename,bak_filename)
446 446 try:
447 447 new = open(filename,'wb')
448 448 new.write(os.linesep.join(original.splitlines()))
449 449 new.write(os.linesep) # ALWAYS put an eol at the end of the file
450 450 new.close()
451 451 except:
452 452 os.rename(bak_filename,filename)
453 453 if not backup:
454 454 try:
455 455 os.remove(bak_filename)
456 456 except:
457 457 pass
458 458
459 459
460 460 def list_strings(arg):
461 461 """Always return a list of strings, given a string or list of strings
462 462 as input.
463 463
464 464 :Examples:
465 465
466 466 In [7]: list_strings('A single string')
467 467 Out[7]: ['A single string']
468 468
469 469 In [8]: list_strings(['A single string in a list'])
470 470 Out[8]: ['A single string in a list']
471 471
472 472 In [9]: list_strings(['A','list','of','strings'])
473 473 Out[9]: ['A', 'list', 'of', 'strings']
474 474 """
475 475
476 476 if isinstance(arg,basestring): return [arg]
477 477 else: return arg
478 478
479 479
480 480 def marquee(txt='',width=78,mark='*'):
481 481 """Return the input string centered in a 'marquee'.
482 482
483 483 :Examples:
484 484
485 485 In [16]: marquee('A test',40)
486 486 Out[16]: '**************** A test ****************'
487 487
488 488 In [17]: marquee('A test',40,'-')
489 489 Out[17]: '---------------- A test ----------------'
490 490
491 491 In [18]: marquee('A test',40,' ')
492 492 Out[18]: ' A test '
493 493
494 494 """
495 495 if not txt:
496 496 return (mark*width)[:width]
497 nmark = (width-len(txt)-2)/len(mark)/2
497 nmark = (width-len(txt)-2)//len(mark)//2
498 498 if nmark < 0: nmark =0
499 499 marks = mark*nmark
500 500 return '%s %s %s' % (marks,txt,marks)
501 501
502 502
503 503 ini_spaces_re = re.compile(r'^(\s+)')
504 504
505 505 def num_ini_spaces(strng):
506 506 """Return the number of initial spaces in a string"""
507 507
508 508 ini_spaces = ini_spaces_re.match(strng)
509 509 if ini_spaces:
510 510 return ini_spaces.end()
511 511 else:
512 512 return 0
513 513
514 514
515 515 def format_screen(strng):
516 516 """Format a string for screen printing.
517 517
518 518 This removes some latex-type format codes."""
519 519 # Paragraph continue
520 520 par_re = re.compile(r'\\$',re.MULTILINE)
521 521 strng = par_re.sub('',strng)
522 522 return strng
523 523
524 524 def dedent(text):
525 525 """Equivalent of textwrap.dedent that ignores unindented first line.
526 526
527 527 This means it will still dedent strings like:
528 528 '''foo
529 529 is a bar
530 530 '''
531 531
532 532 For use in wrap_paragraphs.
533 533 """
534 534
535 535 if text.startswith('\n'):
536 536 # text starts with blank line, don't ignore the first line
537 537 return textwrap.dedent(text)
538 538
539 539 # split first line
540 540 splits = text.split('\n',1)
541 541 if len(splits) == 1:
542 542 # only one line
543 543 return textwrap.dedent(text)
544 544
545 545 first, rest = splits
546 546 # dedent everything but the first line
547 547 rest = textwrap.dedent(rest)
548 548 return '\n'.join([first, rest])
549 549
550 550 def wrap_paragraphs(text, ncols=80):
551 551 """Wrap multiple paragraphs to fit a specified width.
552 552
553 553 This is equivalent to textwrap.wrap, but with support for multiple
554 554 paragraphs, as separated by empty lines.
555 555
556 556 Returns
557 557 -------
558 558
559 559 list of complete paragraphs, wrapped to fill `ncols` columns.
560 560 """
561 561 paragraph_re = re.compile(r'\n(\s*\n)+', re.MULTILINE)
562 562 text = dedent(text).strip()
563 563 paragraphs = paragraph_re.split(text)[::2] # every other entry is space
564 564 out_ps = []
565 565 indent_re = re.compile(r'\n\s+', re.MULTILINE)
566 566 for p in paragraphs:
567 567 # presume indentation that survives dedent is meaningful formatting,
568 568 # so don't fill unless text is flush.
569 569 if indent_re.search(p) is None:
570 570 # wrap paragraph
571 571 p = textwrap.fill(p, ncols)
572 572 out_ps.append(p)
573 573 return out_ps
574 574
575 575
576 576
577 577 class EvalFormatter(Formatter):
578 578 """A String Formatter that allows evaluation of simple expressions.
579 579
580 580 Any time a format key is not found in the kwargs,
581 581 it will be tried as an expression in the kwargs namespace.
582 582
583 583 This is to be used in templating cases, such as the parallel batch
584 584 script templates, where simple arithmetic on arguments is useful.
585 585
586 586 Examples
587 587 --------
588 588
589 589 In [1]: f = EvalFormatter()
590 590 In [2]: f.format('{n/4}', n=8)
591 591 Out[2]: '2'
592 592
593 593 In [3]: f.format('{range(3)}')
594 594 Out[3]: '[0, 1, 2]'
595 595
596 596 In [4]: f.format('{3*2}')
597 597 Out[4]: '6'
598 598 """
599 599
600 600 # should we allow slicing by disabling the format_spec feature?
601 601 allow_slicing = True
602 602
603 603 # copied from Formatter._vformat with minor changes to allow eval
604 604 # and replace the format_spec code with slicing
605 605 def _vformat(self, format_string, args, kwargs, used_args, recursion_depth):
606 606 if recursion_depth < 0:
607 607 raise ValueError('Max string recursion exceeded')
608 608 result = []
609 609 for literal_text, field_name, format_spec, conversion in \
610 610 self.parse(format_string):
611 611
612 612 # output the literal text
613 613 if literal_text:
614 614 result.append(literal_text)
615 615
616 616 # if there's a field, output it
617 617 if field_name is not None:
618 618 # this is some markup, find the object and do
619 619 # the formatting
620 620
621 621 if self.allow_slicing and format_spec:
622 622 # override format spec, to allow slicing:
623 623 field_name = ':'.join([field_name, format_spec])
624 624 format_spec = ''
625 625
626 626 # eval the contents of the field for the object
627 627 # to be formatted
628 628 obj = eval(field_name, kwargs)
629 629
630 630 # do any conversion on the resulting object
631 631 obj = self.convert_field(obj, conversion)
632 632
633 633 # expand the format spec, if needed
634 634 format_spec = self._vformat(format_spec, args, kwargs,
635 635 used_args, recursion_depth-1)
636 636
637 637 # format the object and append to the result
638 638 result.append(self.format_field(obj, format_spec))
639 639
640 640 return ''.join(result)
641 641
642 642
643 643 def columnize(items, separator=' ', displaywidth=80):
644 644 """ Transform a list of strings into a single string with columns.
645 645
646 646 Parameters
647 647 ----------
648 648 items : sequence of strings
649 649 The strings to process.
650 650
651 651 separator : str, optional [default is two spaces]
652 652 The string that separates columns.
653 653
654 654 displaywidth : int, optional [default is 80]
655 655 Width of the display in number of characters.
656 656
657 657 Returns
658 658 -------
659 659 The formatted string.
660 660 """
661 661 # Note: this code is adapted from columnize 0.3.2.
662 662 # See http://code.google.com/p/pycolumnize/
663 663
664 664 # Some degenerate cases.
665 665 size = len(items)
666 666 if size == 0:
667 667 return '\n'
668 668 elif size == 1:
669 669 return '%s\n' % items[0]
670 670
671 671 # Special case: if any item is longer than the maximum width, there's no
672 672 # point in triggering the logic below...
673 673 item_len = map(len, items) # save these, we can reuse them below
674 674 longest = max(item_len)
675 675 if longest >= displaywidth:
676 676 return '\n'.join(items+[''])
677 677
678 678 # Try every row count from 1 upwards
679 679 array_index = lambda nrows, row, col: nrows*col + row
680 680 for nrows in range(1, size):
681 681 ncols = (size + nrows - 1) // nrows
682 682 colwidths = []
683 683 totwidth = -len(separator)
684 684 for col in range(ncols):
685 685 # Get max column width for this column
686 686 colwidth = 0
687 687 for row in range(nrows):
688 688 i = array_index(nrows, row, col)
689 689 if i >= size: break
690 690 x, len_x = items[i], item_len[i]
691 691 colwidth = max(colwidth, len_x)
692 692 colwidths.append(colwidth)
693 693 totwidth += colwidth + len(separator)
694 694 if totwidth > displaywidth:
695 695 break
696 696 if totwidth <= displaywidth:
697 697 break
698 698
699 699 # The smallest number of rows computed and the max widths for each
700 700 # column has been obtained. Now we just have to format each of the rows.
701 701 string = ''
702 702 for row in range(nrows):
703 703 texts = []
704 704 for col in range(ncols):
705 705 i = row + nrows*col
706 706 if i >= size:
707 707 texts.append('')
708 708 else:
709 709 texts.append(items[i])
710 710 while texts and not texts[-1]:
711 711 del texts[-1]
712 712 for col in range(len(texts)):
713 713 texts[col] = texts[col].ljust(colwidths[col])
714 714 string += '%s\n' % separator.join(texts)
715 715 return string
@@ -1,1396 +1,1398 b''
1 1 # encoding: utf-8
2 2 """
3 3 A lightweight Traits like module.
4 4
5 5 This is designed to provide a lightweight, simple, pure Python version of
6 6 many of the capabilities of enthought.traits. This includes:
7 7
8 8 * Validation
9 9 * Type specification with defaults
10 10 * Static and dynamic notification
11 11 * Basic predefined types
12 12 * An API that is similar to enthought.traits
13 13
14 14 We don't support:
15 15
16 16 * Delegation
17 17 * Automatic GUI generation
18 18 * A full set of trait types. Most importantly, we don't provide container
19 19 traits (list, dict, tuple) that can trigger notifications if their
20 20 contents change.
21 21 * API compatibility with enthought.traits
22 22
23 23 There are also some important difference in our design:
24 24
25 25 * enthought.traits does not validate default values. We do.
26 26
27 27 We choose to create this module because we need these capabilities, but
28 28 we need them to be pure Python so they work in all Python implementations,
29 29 including Jython and IronPython.
30 30
31 31 Authors:
32 32
33 33 * Brian Granger
34 34 * Enthought, Inc. Some of the code in this file comes from enthought.traits
35 35 and is licensed under the BSD license. Also, many of the ideas also come
36 36 from enthought.traits even though our implementation is very different.
37 37 """
38 38
39 39 #-----------------------------------------------------------------------------
40 40 # Copyright (C) 2008-2009 The IPython Development Team
41 41 #
42 42 # Distributed under the terms of the BSD License. The full license is in
43 43 # the file COPYING, distributed as part of this software.
44 44 #-----------------------------------------------------------------------------
45 45
46 46 #-----------------------------------------------------------------------------
47 47 # Imports
48 48 #-----------------------------------------------------------------------------
49 49
50 50
51 51 import inspect
52 52 import re
53 53 import sys
54 54 import types
55 from types import (
56 InstanceType, ClassType, FunctionType,
57 ListType, TupleType
58 )
55 from types import FunctionType
56 try:
57 from types import ClassType, InstanceType
58 ClassTypes = (ClassType, type)
59 except:
60 ClassTypes = (type,)
61
59 62 from .importstring import import_item
63 from IPython.utils import py3compat
60 64
61 ClassTypes = (ClassType, type)
62
63 SequenceTypes = (ListType, TupleType, set, frozenset)
65 SequenceTypes = (list, tuple, set, frozenset)
64 66
65 67 #-----------------------------------------------------------------------------
66 68 # Basic classes
67 69 #-----------------------------------------------------------------------------
68 70
69 71
70 72 class NoDefaultSpecified ( object ): pass
71 73 NoDefaultSpecified = NoDefaultSpecified()
72 74
73 75
74 76 class Undefined ( object ): pass
75 77 Undefined = Undefined()
76 78
77 79 class TraitError(Exception):
78 80 pass
79 81
80 82 #-----------------------------------------------------------------------------
81 83 # Utilities
82 84 #-----------------------------------------------------------------------------
83 85
84 86
85 87 def class_of ( object ):
86 88 """ Returns a string containing the class name of an object with the
87 89 correct indefinite article ('a' or 'an') preceding it (e.g., 'an Image',
88 90 'a PlotValue').
89 91 """
90 92 if isinstance( object, basestring ):
91 93 return add_article( object )
92 94
93 95 return add_article( object.__class__.__name__ )
94 96
95 97
96 98 def add_article ( name ):
97 99 """ Returns a string containing the correct indefinite article ('a' or 'an')
98 100 prefixed to the specified string.
99 101 """
100 102 if name[:1].lower() in 'aeiou':
101 103 return 'an ' + name
102 104
103 105 return 'a ' + name
104 106
105 107
106 108 def repr_type(obj):
107 109 """ Return a string representation of a value and its type for readable
108 110 error messages.
109 111 """
110 112 the_type = type(obj)
111 if the_type is InstanceType:
113 if (not py3compat.PY3) and the_type is InstanceType:
112 114 # Old-style class.
113 115 the_type = obj.__class__
114 116 msg = '%r %r' % (obj, the_type)
115 117 return msg
116 118
117 119
118 120 def parse_notifier_name(name):
119 121 """Convert the name argument to a list of names.
120 122
121 123 Examples
122 124 --------
123 125
124 126 >>> parse_notifier_name('a')
125 127 ['a']
126 128 >>> parse_notifier_name(['a','b'])
127 129 ['a', 'b']
128 130 >>> parse_notifier_name(None)
129 131 ['anytrait']
130 132 """
131 133 if isinstance(name, str):
132 134 return [name]
133 135 elif name is None:
134 136 return ['anytrait']
135 137 elif isinstance(name, (list, tuple)):
136 138 for n in name:
137 139 assert isinstance(n, str), "names must be strings"
138 140 return name
139 141
140 142
141 143 class _SimpleTest:
142 144 def __init__ ( self, value ): self.value = value
143 145 def __call__ ( self, test ):
144 146 return test == self.value
145 147 def __repr__(self):
146 148 return "<SimpleTest(%r)" % self.value
147 149 def __str__(self):
148 150 return self.__repr__()
149 151
150 152
151 153 def getmembers(object, predicate=None):
152 154 """A safe version of inspect.getmembers that handles missing attributes.
153 155
154 156 This is useful when there are descriptor based attributes that for
155 157 some reason raise AttributeError even though they exist. This happens
156 158 in zope.inteface with the __provides__ attribute.
157 159 """
158 160 results = []
159 161 for key in dir(object):
160 162 try:
161 163 value = getattr(object, key)
162 164 except AttributeError:
163 165 pass
164 166 else:
165 167 if not predicate or predicate(value):
166 168 results.append((key, value))
167 169 results.sort()
168 170 return results
169 171
170 172
171 173 #-----------------------------------------------------------------------------
172 174 # Base TraitType for all traits
173 175 #-----------------------------------------------------------------------------
174 176
175 177
176 178 class TraitType(object):
177 179 """A base class for all trait descriptors.
178 180
179 181 Notes
180 182 -----
181 183 Our implementation of traits is based on Python's descriptor
182 184 prototol. This class is the base class for all such descriptors. The
183 185 only magic we use is a custom metaclass for the main :class:`HasTraits`
184 186 class that does the following:
185 187
186 188 1. Sets the :attr:`name` attribute of every :class:`TraitType`
187 189 instance in the class dict to the name of the attribute.
188 190 2. Sets the :attr:`this_class` attribute of every :class:`TraitType`
189 191 instance in the class dict to the *class* that declared the trait.
190 192 This is used by the :class:`This` trait to allow subclasses to
191 193 accept superclasses for :class:`This` values.
192 194 """
193 195
194 196
195 197 metadata = {}
196 198 default_value = Undefined
197 199 info_text = 'any value'
198 200
199 201 def __init__(self, default_value=NoDefaultSpecified, **metadata):
200 202 """Create a TraitType.
201 203 """
202 204 if default_value is not NoDefaultSpecified:
203 205 self.default_value = default_value
204 206
205 207 if len(metadata) > 0:
206 208 if len(self.metadata) > 0:
207 209 self._metadata = self.metadata.copy()
208 210 self._metadata.update(metadata)
209 211 else:
210 212 self._metadata = metadata
211 213 else:
212 214 self._metadata = self.metadata
213 215
214 216 self.init()
215 217
216 218 def init(self):
217 219 pass
218 220
219 221 def get_default_value(self):
220 222 """Create a new instance of the default value."""
221 223 return self.default_value
222 224
223 225 def instance_init(self, obj):
224 226 """This is called by :meth:`HasTraits.__new__` to finish init'ing.
225 227
226 228 Some stages of initialization must be delayed until the parent
227 229 :class:`HasTraits` instance has been created. This method is
228 230 called in :meth:`HasTraits.__new__` after the instance has been
229 231 created.
230 232
231 233 This method trigger the creation and validation of default values
232 234 and also things like the resolution of str given class names in
233 235 :class:`Type` and :class`Instance`.
234 236
235 237 Parameters
236 238 ----------
237 239 obj : :class:`HasTraits` instance
238 240 The parent :class:`HasTraits` instance that has just been
239 241 created.
240 242 """
241 243 self.set_default_value(obj)
242 244
243 245 def set_default_value(self, obj):
244 246 """Set the default value on a per instance basis.
245 247
246 248 This method is called by :meth:`instance_init` to create and
247 249 validate the default value. The creation and validation of
248 250 default values must be delayed until the parent :class:`HasTraits`
249 251 class has been instantiated.
250 252 """
251 253 # Check for a deferred initializer defined in the same class as the
252 254 # trait declaration or above.
253 255 mro = type(obj).mro()
254 256 meth_name = '_%s_default' % self.name
255 257 for cls in mro[:mro.index(self.this_class)+1]:
256 258 if meth_name in cls.__dict__:
257 259 break
258 260 else:
259 261 # We didn't find one. Do static initialization.
260 262 dv = self.get_default_value()
261 263 newdv = self._validate(obj, dv)
262 264 obj._trait_values[self.name] = newdv
263 265 return
264 266 # Complete the dynamic initialization.
265 267 obj._trait_dyn_inits[self.name] = cls.__dict__[meth_name]
266 268
267 269 def __get__(self, obj, cls=None):
268 270 """Get the value of the trait by self.name for the instance.
269 271
270 272 Default values are instantiated when :meth:`HasTraits.__new__`
271 273 is called. Thus by the time this method gets called either the
272 274 default value or a user defined value (they called :meth:`__set__`)
273 275 is in the :class:`HasTraits` instance.
274 276 """
275 277 if obj is None:
276 278 return self
277 279 else:
278 280 try:
279 281 value = obj._trait_values[self.name]
280 282 except KeyError:
281 283 # Check for a dynamic initializer.
282 284 if self.name in obj._trait_dyn_inits:
283 285 value = obj._trait_dyn_inits[self.name](obj)
284 286 # FIXME: Do we really validate here?
285 287 value = self._validate(obj, value)
286 288 obj._trait_values[self.name] = value
287 289 return value
288 290 else:
289 291 raise TraitError('Unexpected error in TraitType: '
290 292 'both default value and dynamic initializer are '
291 293 'absent.')
292 294 except Exception:
293 295 # HasTraits should call set_default_value to populate
294 296 # this. So this should never be reached.
295 297 raise TraitError('Unexpected error in TraitType: '
296 298 'default value not set properly')
297 299 else:
298 300 return value
299 301
300 302 def __set__(self, obj, value):
301 303 new_value = self._validate(obj, value)
302 304 old_value = self.__get__(obj)
303 305 if old_value != new_value:
304 306 obj._trait_values[self.name] = new_value
305 307 obj._notify_trait(self.name, old_value, new_value)
306 308
307 309 def _validate(self, obj, value):
308 310 if hasattr(self, 'validate'):
309 311 return self.validate(obj, value)
310 312 elif hasattr(self, 'is_valid_for'):
311 313 valid = self.is_valid_for(value)
312 314 if valid:
313 315 return value
314 316 else:
315 317 raise TraitError('invalid value for type: %r' % value)
316 318 elif hasattr(self, 'value_for'):
317 319 return self.value_for(value)
318 320 else:
319 321 return value
320 322
321 323 def info(self):
322 324 return self.info_text
323 325
324 326 def error(self, obj, value):
325 327 if obj is not None:
326 328 e = "The '%s' trait of %s instance must be %s, but a value of %s was specified." \
327 329 % (self.name, class_of(obj),
328 330 self.info(), repr_type(value))
329 331 else:
330 332 e = "The '%s' trait must be %s, but a value of %r was specified." \
331 333 % (self.name, self.info(), repr_type(value))
332 334 raise TraitError(e)
333 335
334 336 def get_metadata(self, key):
335 337 return getattr(self, '_metadata', {}).get(key, None)
336 338
337 339 def set_metadata(self, key, value):
338 340 getattr(self, '_metadata', {})[key] = value
339 341
340 342
341 343 #-----------------------------------------------------------------------------
342 344 # The HasTraits implementation
343 345 #-----------------------------------------------------------------------------
344 346
345 347
346 348 class MetaHasTraits(type):
347 349 """A metaclass for HasTraits.
348 350
349 351 This metaclass makes sure that any TraitType class attributes are
350 352 instantiated and sets their name attribute.
351 353 """
352 354
353 355 def __new__(mcls, name, bases, classdict):
354 356 """Create the HasTraits class.
355 357
356 358 This instantiates all TraitTypes in the class dict and sets their
357 359 :attr:`name` attribute.
358 360 """
359 361 # print "MetaHasTraitlets (mcls, name): ", mcls, name
360 362 # print "MetaHasTraitlets (bases): ", bases
361 363 # print "MetaHasTraitlets (classdict): ", classdict
362 364 for k,v in classdict.iteritems():
363 365 if isinstance(v, TraitType):
364 366 v.name = k
365 367 elif inspect.isclass(v):
366 368 if issubclass(v, TraitType):
367 369 vinst = v()
368 370 vinst.name = k
369 371 classdict[k] = vinst
370 372 return super(MetaHasTraits, mcls).__new__(mcls, name, bases, classdict)
371 373
372 374 def __init__(cls, name, bases, classdict):
373 375 """Finish initializing the HasTraits class.
374 376
375 377 This sets the :attr:`this_class` attribute of each TraitType in the
376 378 class dict to the newly created class ``cls``.
377 379 """
378 380 for k, v in classdict.iteritems():
379 381 if isinstance(v, TraitType):
380 382 v.this_class = cls
381 383 super(MetaHasTraits, cls).__init__(name, bases, classdict)
382 384
383 385 class HasTraits(object):
384 386
385 387 __metaclass__ = MetaHasTraits
386 388
387 389 def __new__(cls, **kw):
388 390 # This is needed because in Python 2.6 object.__new__ only accepts
389 391 # the cls argument.
390 392 new_meth = super(HasTraits, cls).__new__
391 393 if new_meth is object.__new__:
392 394 inst = new_meth(cls)
393 395 else:
394 396 inst = new_meth(cls, **kw)
395 397 inst._trait_values = {}
396 398 inst._trait_notifiers = {}
397 399 inst._trait_dyn_inits = {}
398 400 # Here we tell all the TraitType instances to set their default
399 401 # values on the instance.
400 402 for key in dir(cls):
401 403 # Some descriptors raise AttributeError like zope.interface's
402 404 # __provides__ attributes even though they exist. This causes
403 405 # AttributeErrors even though they are listed in dir(cls).
404 406 try:
405 407 value = getattr(cls, key)
406 408 except AttributeError:
407 409 pass
408 410 else:
409 411 if isinstance(value, TraitType):
410 412 value.instance_init(inst)
411 413
412 414 return inst
413 415
414 416 def __init__(self, **kw):
415 417 # Allow trait values to be set using keyword arguments.
416 418 # We need to use setattr for this to trigger validation and
417 419 # notifications.
418 420 for key, value in kw.iteritems():
419 421 setattr(self, key, value)
420 422
421 423 def _notify_trait(self, name, old_value, new_value):
422 424
423 425 # First dynamic ones
424 426 callables = self._trait_notifiers.get(name,[])
425 427 more_callables = self._trait_notifiers.get('anytrait',[])
426 428 callables.extend(more_callables)
427 429
428 430 # Now static ones
429 431 try:
430 432 cb = getattr(self, '_%s_changed' % name)
431 433 except:
432 434 pass
433 435 else:
434 436 callables.append(cb)
435 437
436 438 # Call them all now
437 439 for c in callables:
438 440 # Traits catches and logs errors here. I allow them to raise
439 441 if callable(c):
440 442 argspec = inspect.getargspec(c)
441 443 nargs = len(argspec[0])
442 444 # Bound methods have an additional 'self' argument
443 445 # I don't know how to treat unbound methods, but they
444 446 # can't really be used for callbacks.
445 447 if isinstance(c, types.MethodType):
446 448 offset = -1
447 449 else:
448 450 offset = 0
449 451 if nargs + offset == 0:
450 452 c()
451 453 elif nargs + offset == 1:
452 454 c(name)
453 455 elif nargs + offset == 2:
454 456 c(name, new_value)
455 457 elif nargs + offset == 3:
456 458 c(name, old_value, new_value)
457 459 else:
458 460 raise TraitError('a trait changed callback '
459 461 'must have 0-3 arguments.')
460 462 else:
461 463 raise TraitError('a trait changed callback '
462 464 'must be callable.')
463 465
464 466
465 467 def _add_notifiers(self, handler, name):
466 468 if not self._trait_notifiers.has_key(name):
467 469 nlist = []
468 470 self._trait_notifiers[name] = nlist
469 471 else:
470 472 nlist = self._trait_notifiers[name]
471 473 if handler not in nlist:
472 474 nlist.append(handler)
473 475
474 476 def _remove_notifiers(self, handler, name):
475 477 if self._trait_notifiers.has_key(name):
476 478 nlist = self._trait_notifiers[name]
477 479 try:
478 480 index = nlist.index(handler)
479 481 except ValueError:
480 482 pass
481 483 else:
482 484 del nlist[index]
483 485
484 486 def on_trait_change(self, handler, name=None, remove=False):
485 487 """Setup a handler to be called when a trait changes.
486 488
487 489 This is used to setup dynamic notifications of trait changes.
488 490
489 491 Static handlers can be created by creating methods on a HasTraits
490 492 subclass with the naming convention '_[traitname]_changed'. Thus,
491 493 to create static handler for the trait 'a', create the method
492 494 _a_changed(self, name, old, new) (fewer arguments can be used, see
493 495 below).
494 496
495 497 Parameters
496 498 ----------
497 499 handler : callable
498 500 A callable that is called when a trait changes. Its
499 501 signature can be handler(), handler(name), handler(name, new)
500 502 or handler(name, old, new).
501 503 name : list, str, None
502 504 If None, the handler will apply to all traits. If a list
503 505 of str, handler will apply to all names in the list. If a
504 506 str, the handler will apply just to that name.
505 507 remove : bool
506 508 If False (the default), then install the handler. If True
507 509 then unintall it.
508 510 """
509 511 if remove:
510 512 names = parse_notifier_name(name)
511 513 for n in names:
512 514 self._remove_notifiers(handler, n)
513 515 else:
514 516 names = parse_notifier_name(name)
515 517 for n in names:
516 518 self._add_notifiers(handler, n)
517 519
518 520 @classmethod
519 521 def class_trait_names(cls, **metadata):
520 522 """Get a list of all the names of this classes traits.
521 523
522 524 This method is just like the :meth:`trait_names` method, but is unbound.
523 525 """
524 526 return cls.class_traits(**metadata).keys()
525 527
526 528 @classmethod
527 529 def class_traits(cls, **metadata):
528 530 """Get a list of all the traits of this class.
529 531
530 532 This method is just like the :meth:`traits` method, but is unbound.
531 533
532 534 The TraitTypes returned don't know anything about the values
533 535 that the various HasTrait's instances are holding.
534 536
535 537 This follows the same algorithm as traits does and does not allow
536 538 for any simple way of specifying merely that a metadata name
537 539 exists, but has any value. This is because get_metadata returns
538 540 None if a metadata key doesn't exist.
539 541 """
540 542 traits = dict([memb for memb in getmembers(cls) if \
541 543 isinstance(memb[1], TraitType)])
542 544
543 545 if len(metadata) == 0:
544 546 return traits
545 547
546 548 for meta_name, meta_eval in metadata.items():
547 549 if type(meta_eval) is not FunctionType:
548 550 metadata[meta_name] = _SimpleTest(meta_eval)
549 551
550 552 result = {}
551 553 for name, trait in traits.items():
552 554 for meta_name, meta_eval in metadata.items():
553 555 if not meta_eval(trait.get_metadata(meta_name)):
554 556 break
555 557 else:
556 558 result[name] = trait
557 559
558 560 return result
559 561
560 562 def trait_names(self, **metadata):
561 563 """Get a list of all the names of this classes traits."""
562 564 return self.traits(**metadata).keys()
563 565
564 566 def traits(self, **metadata):
565 567 """Get a list of all the traits of this class.
566 568
567 569 The TraitTypes returned don't know anything about the values
568 570 that the various HasTrait's instances are holding.
569 571
570 572 This follows the same algorithm as traits does and does not allow
571 573 for any simple way of specifying merely that a metadata name
572 574 exists, but has any value. This is because get_metadata returns
573 575 None if a metadata key doesn't exist.
574 576 """
575 577 traits = dict([memb for memb in getmembers(self.__class__) if \
576 578 isinstance(memb[1], TraitType)])
577 579
578 580 if len(metadata) == 0:
579 581 return traits
580 582
581 583 for meta_name, meta_eval in metadata.items():
582 584 if type(meta_eval) is not FunctionType:
583 585 metadata[meta_name] = _SimpleTest(meta_eval)
584 586
585 587 result = {}
586 588 for name, trait in traits.items():
587 589 for meta_name, meta_eval in metadata.items():
588 590 if not meta_eval(trait.get_metadata(meta_name)):
589 591 break
590 592 else:
591 593 result[name] = trait
592 594
593 595 return result
594 596
595 597 def trait_metadata(self, traitname, key):
596 598 """Get metadata values for trait by key."""
597 599 try:
598 600 trait = getattr(self.__class__, traitname)
599 601 except AttributeError:
600 602 raise TraitError("Class %s does not have a trait named %s" %
601 603 (self.__class__.__name__, traitname))
602 604 else:
603 605 return trait.get_metadata(key)
604 606
605 607 #-----------------------------------------------------------------------------
606 608 # Actual TraitTypes implementations/subclasses
607 609 #-----------------------------------------------------------------------------
608 610
609 611 #-----------------------------------------------------------------------------
610 612 # TraitTypes subclasses for handling classes and instances of classes
611 613 #-----------------------------------------------------------------------------
612 614
613 615
614 616 class ClassBasedTraitType(TraitType):
615 617 """A trait with error reporting for Type, Instance and This."""
616 618
617 619 def error(self, obj, value):
618 620 kind = type(value)
619 if kind is InstanceType:
621 if (not py3compat.PY3) and kind is InstanceType:
620 622 msg = 'class %s' % value.__class__.__name__
621 623 else:
622 624 msg = '%s (i.e. %s)' % ( str( kind )[1:-1], repr( value ) )
623 625
624 626 if obj is not None:
625 627 e = "The '%s' trait of %s instance must be %s, but a value of %s was specified." \
626 628 % (self.name, class_of(obj),
627 629 self.info(), msg)
628 630 else:
629 631 e = "The '%s' trait must be %s, but a value of %r was specified." \
630 632 % (self.name, self.info(), msg)
631 633
632 634 raise TraitError(e)
633 635
634 636
635 637 class Type(ClassBasedTraitType):
636 638 """A trait whose value must be a subclass of a specified class."""
637 639
638 640 def __init__ (self, default_value=None, klass=None, allow_none=True, **metadata ):
639 641 """Construct a Type trait
640 642
641 643 A Type trait specifies that its values must be subclasses of
642 644 a particular class.
643 645
644 646 If only ``default_value`` is given, it is used for the ``klass`` as
645 647 well.
646 648
647 649 Parameters
648 650 ----------
649 651 default_value : class, str or None
650 652 The default value must be a subclass of klass. If an str,
651 653 the str must be a fully specified class name, like 'foo.bar.Bah'.
652 654 The string is resolved into real class, when the parent
653 655 :class:`HasTraits` class is instantiated.
654 656 klass : class, str, None
655 657 Values of this trait must be a subclass of klass. The klass
656 658 may be specified in a string like: 'foo.bar.MyClass'.
657 659 The string is resolved into real class, when the parent
658 660 :class:`HasTraits` class is instantiated.
659 661 allow_none : boolean
660 662 Indicates whether None is allowed as an assignable value. Even if
661 663 ``False``, the default value may be ``None``.
662 664 """
663 665 if default_value is None:
664 666 if klass is None:
665 667 klass = object
666 668 elif klass is None:
667 669 klass = default_value
668 670
669 671 if not (inspect.isclass(klass) or isinstance(klass, basestring)):
670 672 raise TraitError("A Type trait must specify a class.")
671 673
672 674 self.klass = klass
673 675 self._allow_none = allow_none
674 676
675 677 super(Type, self).__init__(default_value, **metadata)
676 678
677 679 def validate(self, obj, value):
678 680 """Validates that the value is a valid object instance."""
679 681 try:
680 682 if issubclass(value, self.klass):
681 683 return value
682 684 except:
683 685 if (value is None) and (self._allow_none):
684 686 return value
685 687
686 688 self.error(obj, value)
687 689
688 690 def info(self):
689 691 """ Returns a description of the trait."""
690 692 if isinstance(self.klass, basestring):
691 693 klass = self.klass
692 694 else:
693 695 klass = self.klass.__name__
694 696 result = 'a subclass of ' + klass
695 697 if self._allow_none:
696 698 return result + ' or None'
697 699 return result
698 700
699 701 def instance_init(self, obj):
700 702 self._resolve_classes()
701 703 super(Type, self).instance_init(obj)
702 704
703 705 def _resolve_classes(self):
704 706 if isinstance(self.klass, basestring):
705 707 self.klass = import_item(self.klass)
706 708 if isinstance(self.default_value, basestring):
707 709 self.default_value = import_item(self.default_value)
708 710
709 711 def get_default_value(self):
710 712 return self.default_value
711 713
712 714
713 715 class DefaultValueGenerator(object):
714 716 """A class for generating new default value instances."""
715 717
716 718 def __init__(self, *args, **kw):
717 719 self.args = args
718 720 self.kw = kw
719 721
720 722 def generate(self, klass):
721 723 return klass(*self.args, **self.kw)
722 724
723 725
724 726 class Instance(ClassBasedTraitType):
725 727 """A trait whose value must be an instance of a specified class.
726 728
727 729 The value can also be an instance of a subclass of the specified class.
728 730 """
729 731
730 732 def __init__(self, klass=None, args=None, kw=None,
731 733 allow_none=True, **metadata ):
732 734 """Construct an Instance trait.
733 735
734 736 This trait allows values that are instances of a particular
735 737 class or its sublclasses. Our implementation is quite different
736 738 from that of enthough.traits as we don't allow instances to be used
737 739 for klass and we handle the ``args`` and ``kw`` arguments differently.
738 740
739 741 Parameters
740 742 ----------
741 743 klass : class, str
742 744 The class that forms the basis for the trait. Class names
743 745 can also be specified as strings, like 'foo.bar.Bar'.
744 746 args : tuple
745 747 Positional arguments for generating the default value.
746 748 kw : dict
747 749 Keyword arguments for generating the default value.
748 750 allow_none : bool
749 751 Indicates whether None is allowed as a value.
750 752
751 753 Default Value
752 754 -------------
753 755 If both ``args`` and ``kw`` are None, then the default value is None.
754 756 If ``args`` is a tuple and ``kw`` is a dict, then the default is
755 757 created as ``klass(*args, **kw)``. If either ``args`` or ``kw`` is
756 758 not (but not both), None is replace by ``()`` or ``{}``.
757 759 """
758 760
759 761 self._allow_none = allow_none
760 762
761 763 if (klass is None) or (not (inspect.isclass(klass) or isinstance(klass, basestring))):
762 764 raise TraitError('The klass argument must be a class'
763 765 ' you gave: %r' % klass)
764 766 self.klass = klass
765 767
766 768 # self.klass is a class, so handle default_value
767 769 if args is None and kw is None:
768 770 default_value = None
769 771 else:
770 772 if args is None:
771 773 # kw is not None
772 774 args = ()
773 775 elif kw is None:
774 776 # args is not None
775 777 kw = {}
776 778
777 779 if not isinstance(kw, dict):
778 780 raise TraitError("The 'kw' argument must be a dict or None.")
779 781 if not isinstance(args, tuple):
780 782 raise TraitError("The 'args' argument must be a tuple or None.")
781 783
782 784 default_value = DefaultValueGenerator(*args, **kw)
783 785
784 786 super(Instance, self).__init__(default_value, **metadata)
785 787
786 788 def validate(self, obj, value):
787 789 if value is None:
788 790 if self._allow_none:
789 791 return value
790 792 self.error(obj, value)
791 793
792 794 if isinstance(value, self.klass):
793 795 return value
794 796 else:
795 797 self.error(obj, value)
796 798
797 799 def info(self):
798 800 if isinstance(self.klass, basestring):
799 801 klass = self.klass
800 802 else:
801 803 klass = self.klass.__name__
802 804 result = class_of(klass)
803 805 if self._allow_none:
804 806 return result + ' or None'
805 807
806 808 return result
807 809
808 810 def instance_init(self, obj):
809 811 self._resolve_classes()
810 812 super(Instance, self).instance_init(obj)
811 813
812 814 def _resolve_classes(self):
813 815 if isinstance(self.klass, basestring):
814 816 self.klass = import_item(self.klass)
815 817
816 818 def get_default_value(self):
817 819 """Instantiate a default value instance.
818 820
819 821 This is called when the containing HasTraits classes'
820 822 :meth:`__new__` method is called to ensure that a unique instance
821 823 is created for each HasTraits instance.
822 824 """
823 825 dv = self.default_value
824 826 if isinstance(dv, DefaultValueGenerator):
825 827 return dv.generate(self.klass)
826 828 else:
827 829 return dv
828 830
829 831
830 832 class This(ClassBasedTraitType):
831 833 """A trait for instances of the class containing this trait.
832 834
833 835 Because how how and when class bodies are executed, the ``This``
834 836 trait can only have a default value of None. This, and because we
835 837 always validate default values, ``allow_none`` is *always* true.
836 838 """
837 839
838 840 info_text = 'an instance of the same type as the receiver or None'
839 841
840 842 def __init__(self, **metadata):
841 843 super(This, self).__init__(None, **metadata)
842 844
843 845 def validate(self, obj, value):
844 846 # What if value is a superclass of obj.__class__? This is
845 847 # complicated if it was the superclass that defined the This
846 848 # trait.
847 849 if isinstance(value, self.this_class) or (value is None):
848 850 return value
849 851 else:
850 852 self.error(obj, value)
851 853
852 854
853 855 #-----------------------------------------------------------------------------
854 856 # Basic TraitTypes implementations/subclasses
855 857 #-----------------------------------------------------------------------------
856 858
857 859
858 860 class Any(TraitType):
859 861 default_value = None
860 862 info_text = 'any value'
861 863
862 864
863 865 class Int(TraitType):
864 866 """A integer trait."""
865 867
866 868 default_value = 0
867 869 info_text = 'an integer'
868 870
869 871 def validate(self, obj, value):
870 872 if isinstance(value, int):
871 873 return value
872 874 self.error(obj, value)
873 875
874 876 class CInt(Int):
875 877 """A casting version of the int trait."""
876 878
877 879 def validate(self, obj, value):
878 880 try:
879 881 return int(value)
880 882 except:
881 883 self.error(obj, value)
882 884
885 if not py3compat.PY3:
886 class Long(TraitType):
887 """A long integer trait."""
883 888
884 class Long(TraitType):
885 """A long integer trait."""
889 default_value = 0L
890 info_text = 'a long'
886 891
887 default_value = 0L
888 info_text = 'a long'
889
890 def validate(self, obj, value):
891 if isinstance(value, long):
892 return value
893 if isinstance(value, int):
894 return long(value)
895 self.error(obj, value)
892 def validate(self, obj, value):
893 if isinstance(value, long):
894 return value
895 if isinstance(value, int):
896 return long(value)
897 self.error(obj, value)
896 898
897 899
898 class CLong(Long):
899 """A casting version of the long integer trait."""
900 class CLong(Long):
901 """A casting version of the long integer trait."""
900 902
901 def validate(self, obj, value):
902 try:
903 return long(value)
904 except:
905 self.error(obj, value)
903 def validate(self, obj, value):
904 try:
905 return long(value)
906 except:
907 self.error(obj, value)
906 908
907 909
908 910 class Float(TraitType):
909 911 """A float trait."""
910 912
911 913 default_value = 0.0
912 914 info_text = 'a float'
913 915
914 916 def validate(self, obj, value):
915 917 if isinstance(value, float):
916 918 return value
917 919 if isinstance(value, int):
918 920 return float(value)
919 921 self.error(obj, value)
920 922
921 923
922 924 class CFloat(Float):
923 925 """A casting version of the float trait."""
924 926
925 927 def validate(self, obj, value):
926 928 try:
927 929 return float(value)
928 930 except:
929 931 self.error(obj, value)
930 932
931 933 class Complex(TraitType):
932 934 """A trait for complex numbers."""
933 935
934 936 default_value = 0.0 + 0.0j
935 937 info_text = 'a complex number'
936 938
937 939 def validate(self, obj, value):
938 940 if isinstance(value, complex):
939 941 return value
940 942 if isinstance(value, (float, int)):
941 943 return complex(value)
942 944 self.error(obj, value)
943 945
944 946
945 947 class CComplex(Complex):
946 948 """A casting version of the complex number trait."""
947 949
948 950 def validate (self, obj, value):
949 951 try:
950 952 return complex(value)
951 953 except:
952 954 self.error(obj, value)
953 955
954 956 # We should always be explicit about whether we're using bytes or unicode, both
955 957 # for Python 3 conversion and for reliable unicode behaviour on Python 2. So
956 958 # we don't have a Str type.
957 959 class Bytes(TraitType):
958 """A trait for strings."""
960 """A trait for byte strings."""
959 961
960 962 default_value = ''
961 963 info_text = 'a string'
962 964
963 965 def validate(self, obj, value):
964 966 if isinstance(value, bytes):
965 967 return value
966 968 self.error(obj, value)
967 969
968 970
969 971 class CBytes(Bytes):
970 """A casting version of the string trait."""
972 """A casting version of the byte string trait."""
971 973
972 974 def validate(self, obj, value):
973 975 try:
974 976 return bytes(value)
975 977 except:
976 978 self.error(obj, value)
977 979
978 980
979 981 class Unicode(TraitType):
980 982 """A trait for unicode strings."""
981 983
982 984 default_value = u''
983 985 info_text = 'a unicode string'
984 986
985 987 def validate(self, obj, value):
986 988 if isinstance(value, unicode):
987 989 return value
988 990 if isinstance(value, bytes):
989 991 return unicode(value)
990 992 self.error(obj, value)
991 993
992 994
993 995 class CUnicode(Unicode):
994 996 """A casting version of the unicode trait."""
995 997
996 998 def validate(self, obj, value):
997 999 try:
998 1000 return unicode(value)
999 1001 except:
1000 1002 self.error(obj, value)
1001 1003
1002 1004
1003 1005 class ObjectName(TraitType):
1004 1006 """A string holding a valid object name in this version of Python.
1005 1007
1006 1008 This does not check that the name exists in any scope."""
1007 1009 info_text = "a valid object identifier in Python"
1008 1010
1009 1011 if sys.version_info[0] < 3:
1010 1012 # Python 2:
1011 1013 _name_re = re.compile(r"[a-zA-Z_][a-zA-Z0-9_]*$")
1012 1014 def isidentifier(self, s):
1013 1015 return bool(self._name_re.match(s))
1014 1016
1015 1017 def coerce_str(self, obj, value):
1016 1018 "In Python 2, coerce ascii-only unicode to str"
1017 1019 if isinstance(value, unicode):
1018 1020 try:
1019 1021 return str(value)
1020 1022 except UnicodeEncodeError:
1021 1023 self.error(obj, value)
1022 1024 return value
1023 1025
1024 1026 else:
1025 1027 # Python 3:
1026 1028 isidentifier = staticmethod(lambda s: s.isidentifier())
1027 1029 coerce_str = staticmethod(lambda _,s: s)
1028 1030
1029 1031 def validate(self, obj, value):
1030 1032 value = self.coerce_str(obj, value)
1031 1033
1032 1034 if isinstance(value, str) and self.isidentifier(value):
1033 1035 return value
1034 1036 self.error(obj, value)
1035 1037
1036 1038 class DottedObjectName(ObjectName):
1037 1039 """A string holding a valid dotted object name in Python, such as A.b3._c"""
1038 1040 def validate(self, obj, value):
1039 1041 value = self.coerce_str(obj, value)
1040 1042
1041 1043 if isinstance(value, str) and all(self.isidentifier(x) \
1042 1044 for x in value.split('.')):
1043 1045 return value
1044 1046 self.error(obj, value)
1045 1047
1046 1048
1047 1049 class Bool(TraitType):
1048 1050 """A boolean (True, False) trait."""
1049 1051
1050 1052 default_value = False
1051 1053 info_text = 'a boolean'
1052 1054
1053 1055 def validate(self, obj, value):
1054 1056 if isinstance(value, bool):
1055 1057 return value
1056 1058 self.error(obj, value)
1057 1059
1058 1060
1059 1061 class CBool(Bool):
1060 1062 """A casting version of the boolean trait."""
1061 1063
1062 1064 def validate(self, obj, value):
1063 1065 try:
1064 1066 return bool(value)
1065 1067 except:
1066 1068 self.error(obj, value)
1067 1069
1068 1070
1069 1071 class Enum(TraitType):
1070 1072 """An enum that whose value must be in a given sequence."""
1071 1073
1072 1074 def __init__(self, values, default_value=None, allow_none=True, **metadata):
1073 1075 self.values = values
1074 1076 self._allow_none = allow_none
1075 1077 super(Enum, self).__init__(default_value, **metadata)
1076 1078
1077 1079 def validate(self, obj, value):
1078 1080 if value is None:
1079 1081 if self._allow_none:
1080 1082 return value
1081 1083
1082 1084 if value in self.values:
1083 1085 return value
1084 1086 self.error(obj, value)
1085 1087
1086 1088 def info(self):
1087 1089 """ Returns a description of the trait."""
1088 1090 result = 'any of ' + repr(self.values)
1089 1091 if self._allow_none:
1090 1092 return result + ' or None'
1091 1093 return result
1092 1094
1093 1095 class CaselessStrEnum(Enum):
1094 1096 """An enum of strings that are caseless in validate."""
1095 1097
1096 1098 def validate(self, obj, value):
1097 1099 if value is None:
1098 1100 if self._allow_none:
1099 1101 return value
1100 1102
1101 1103 if not isinstance(value, basestring):
1102 1104 self.error(obj, value)
1103 1105
1104 1106 for v in self.values:
1105 1107 if v.lower() == value.lower():
1106 1108 return v
1107 1109 self.error(obj, value)
1108 1110
1109 1111 class Container(Instance):
1110 1112 """An instance of a container (list, set, etc.)
1111 1113
1112 1114 To be subclassed by overriding klass.
1113 1115 """
1114 1116 klass = None
1115 1117 _valid_defaults = SequenceTypes
1116 1118 _trait = None
1117 1119
1118 1120 def __init__(self, trait=None, default_value=None, allow_none=True,
1119 1121 **metadata):
1120 1122 """Create a container trait type from a list, set, or tuple.
1121 1123
1122 1124 The default value is created by doing ``List(default_value)``,
1123 1125 which creates a copy of the ``default_value``.
1124 1126
1125 1127 ``trait`` can be specified, which restricts the type of elements
1126 1128 in the container to that TraitType.
1127 1129
1128 1130 If only one arg is given and it is not a Trait, it is taken as
1129 1131 ``default_value``:
1130 1132
1131 1133 ``c = List([1,2,3])``
1132 1134
1133 1135 Parameters
1134 1136 ----------
1135 1137
1136 1138 trait : TraitType [ optional ]
1137 1139 the type for restricting the contents of the Container. If unspecified,
1138 1140 types are not checked.
1139 1141
1140 1142 default_value : SequenceType [ optional ]
1141 1143 The default value for the Trait. Must be list/tuple/set, and
1142 1144 will be cast to the container type.
1143 1145
1144 1146 allow_none : Bool [ default True ]
1145 1147 Whether to allow the value to be None
1146 1148
1147 1149 **metadata : any
1148 1150 further keys for extensions to the Trait (e.g. config)
1149 1151
1150 1152 """
1151 1153 istrait = lambda t: isinstance(t, type) and issubclass(t, TraitType)
1152 1154
1153 1155 # allow List([values]):
1154 1156 if default_value is None and not istrait(trait):
1155 1157 default_value = trait
1156 1158 trait = None
1157 1159
1158 1160 if default_value is None:
1159 1161 args = ()
1160 1162 elif isinstance(default_value, self._valid_defaults):
1161 1163 args = (default_value,)
1162 1164 else:
1163 1165 raise TypeError('default value of %s was %s' %(self.__class__.__name__, default_value))
1164 1166
1165 1167 if istrait(trait):
1166 1168 self._trait = trait()
1167 1169 self._trait.name = 'element'
1168 1170 elif trait is not None:
1169 1171 raise TypeError("`trait` must be a Trait or None, got %s"%repr_type(trait))
1170 1172
1171 1173 super(Container,self).__init__(klass=self.klass, args=args,
1172 1174 allow_none=allow_none, **metadata)
1173 1175
1174 1176 def element_error(self, obj, element, validator):
1175 1177 e = "Element of the '%s' trait of %s instance must be %s, but a value of %s was specified." \
1176 1178 % (self.name, class_of(obj), validator.info(), repr_type(element))
1177 1179 raise TraitError(e)
1178 1180
1179 1181 def validate(self, obj, value):
1180 1182 value = super(Container, self).validate(obj, value)
1181 1183 if value is None:
1182 1184 return value
1183 1185
1184 1186 value = self.validate_elements(obj, value)
1185 1187
1186 1188 return value
1187 1189
1188 1190 def validate_elements(self, obj, value):
1189 1191 validated = []
1190 1192 if self._trait is None or isinstance(self._trait, Any):
1191 1193 return value
1192 1194 for v in value:
1193 1195 try:
1194 1196 v = self._trait.validate(obj, v)
1195 1197 except TraitError:
1196 1198 self.element_error(obj, v, self._trait)
1197 1199 else:
1198 1200 validated.append(v)
1199 1201 return self.klass(validated)
1200 1202
1201 1203
1202 1204 class List(Container):
1203 1205 """An instance of a Python list."""
1204 1206 klass = list
1205 1207
1206 1208 def __init__(self, trait=None, default_value=None, minlen=0, maxlen=sys.maxint,
1207 1209 allow_none=True, **metadata):
1208 1210 """Create a List trait type from a list, set, or tuple.
1209 1211
1210 1212 The default value is created by doing ``List(default_value)``,
1211 1213 which creates a copy of the ``default_value``.
1212 1214
1213 1215 ``trait`` can be specified, which restricts the type of elements
1214 1216 in the container to that TraitType.
1215 1217
1216 1218 If only one arg is given and it is not a Trait, it is taken as
1217 1219 ``default_value``:
1218 1220
1219 1221 ``c = List([1,2,3])``
1220 1222
1221 1223 Parameters
1222 1224 ----------
1223 1225
1224 1226 trait : TraitType [ optional ]
1225 1227 the type for restricting the contents of the Container. If unspecified,
1226 1228 types are not checked.
1227 1229
1228 1230 default_value : SequenceType [ optional ]
1229 1231 The default value for the Trait. Must be list/tuple/set, and
1230 1232 will be cast to the container type.
1231 1233
1232 1234 minlen : Int [ default 0 ]
1233 1235 The minimum length of the input list
1234 1236
1235 1237 maxlen : Int [ default sys.maxint ]
1236 1238 The maximum length of the input list
1237 1239
1238 1240 allow_none : Bool [ default True ]
1239 1241 Whether to allow the value to be None
1240 1242
1241 1243 **metadata : any
1242 1244 further keys for extensions to the Trait (e.g. config)
1243 1245
1244 1246 """
1245 1247 self._minlen = minlen
1246 1248 self._maxlen = maxlen
1247 1249 super(List, self).__init__(trait=trait, default_value=default_value,
1248 1250 allow_none=allow_none, **metadata)
1249 1251
1250 1252 def length_error(self, obj, value):
1251 1253 e = "The '%s' trait of %s instance must be of length %i <= L <= %i, but a value of %s was specified." \
1252 1254 % (self.name, class_of(obj), self._minlen, self._maxlen, value)
1253 1255 raise TraitError(e)
1254 1256
1255 1257 def validate_elements(self, obj, value):
1256 1258 length = len(value)
1257 1259 if length < self._minlen or length > self._maxlen:
1258 1260 self.length_error(obj, value)
1259 1261
1260 1262 return super(List, self).validate_elements(obj, value)
1261 1263
1262 1264
1263 1265 class Set(Container):
1264 1266 """An instance of a Python set."""
1265 1267 klass = set
1266 1268
1267 1269 class Tuple(Container):
1268 1270 """An instance of a Python tuple."""
1269 1271 klass = tuple
1270 1272
1271 1273 def __init__(self, *traits, **metadata):
1272 1274 """Tuple(*traits, default_value=None, allow_none=True, **medatata)
1273 1275
1274 1276 Create a tuple from a list, set, or tuple.
1275 1277
1276 1278 Create a fixed-type tuple with Traits:
1277 1279
1278 1280 ``t = Tuple(Int, Str, CStr)``
1279 1281
1280 1282 would be length 3, with Int,Str,CStr for each element.
1281 1283
1282 1284 If only one arg is given and it is not a Trait, it is taken as
1283 1285 default_value:
1284 1286
1285 1287 ``t = Tuple((1,2,3))``
1286 1288
1287 1289 Otherwise, ``default_value`` *must* be specified by keyword.
1288 1290
1289 1291 Parameters
1290 1292 ----------
1291 1293
1292 1294 *traits : TraitTypes [ optional ]
1293 1295 the tsype for restricting the contents of the Tuple. If unspecified,
1294 1296 types are not checked. If specified, then each positional argument
1295 1297 corresponds to an element of the tuple. Tuples defined with traits
1296 1298 are of fixed length.
1297 1299
1298 1300 default_value : SequenceType [ optional ]
1299 1301 The default value for the Tuple. Must be list/tuple/set, and
1300 1302 will be cast to a tuple. If `traits` are specified, the
1301 1303 `default_value` must conform to the shape and type they specify.
1302 1304
1303 1305 allow_none : Bool [ default True ]
1304 1306 Whether to allow the value to be None
1305 1307
1306 1308 **metadata : any
1307 1309 further keys for extensions to the Trait (e.g. config)
1308 1310
1309 1311 """
1310 1312 default_value = metadata.pop('default_value', None)
1311 1313 allow_none = metadata.pop('allow_none', True)
1312 1314
1313 1315 istrait = lambda t: isinstance(t, type) and issubclass(t, TraitType)
1314 1316
1315 1317 # allow Tuple((values,)):
1316 1318 if len(traits) == 1 and default_value is None and not istrait(traits[0]):
1317 1319 default_value = traits[0]
1318 1320 traits = ()
1319 1321
1320 1322 if default_value is None:
1321 1323 args = ()
1322 1324 elif isinstance(default_value, self._valid_defaults):
1323 1325 args = (default_value,)
1324 1326 else:
1325 1327 raise TypeError('default value of %s was %s' %(self.__class__.__name__, default_value))
1326 1328
1327 1329 self._traits = []
1328 1330 for trait in traits:
1329 1331 t = trait()
1330 1332 t.name = 'element'
1331 1333 self._traits.append(t)
1332 1334
1333 1335 if self._traits and default_value is None:
1334 1336 # don't allow default to be an empty container if length is specified
1335 1337 args = None
1336 1338 super(Container,self).__init__(klass=self.klass, args=args,
1337 1339 allow_none=allow_none, **metadata)
1338 1340
1339 1341 def validate_elements(self, obj, value):
1340 1342 if not self._traits:
1341 1343 # nothing to validate
1342 1344 return value
1343 1345 if len(value) != len(self._traits):
1344 1346 e = "The '%s' trait of %s instance requires %i elements, but a value of %s was specified." \
1345 1347 % (self.name, class_of(obj), len(self._traits), repr_type(value))
1346 1348 raise TraitError(e)
1347 1349
1348 1350 validated = []
1349 1351 for t,v in zip(self._traits, value):
1350 1352 try:
1351 1353 v = t.validate(obj, v)
1352 1354 except TraitError:
1353 1355 self.element_error(obj, v, t)
1354 1356 else:
1355 1357 validated.append(v)
1356 1358 return tuple(validated)
1357 1359
1358 1360
1359 1361 class Dict(Instance):
1360 1362 """An instance of a Python dict."""
1361 1363
1362 1364 def __init__(self, default_value=None, allow_none=True, **metadata):
1363 1365 """Create a dict trait type from a dict.
1364 1366
1365 1367 The default value is created by doing ``dict(default_value)``,
1366 1368 which creates a copy of the ``default_value``.
1367 1369 """
1368 1370 if default_value is None:
1369 1371 args = ((),)
1370 1372 elif isinstance(default_value, dict):
1371 1373 args = (default_value,)
1372 1374 elif isinstance(default_value, SequenceTypes):
1373 1375 args = (default_value,)
1374 1376 else:
1375 1377 raise TypeError('default value of Dict was %s' % default_value)
1376 1378
1377 1379 super(Dict,self).__init__(klass=dict, args=args,
1378 1380 allow_none=allow_none, **metadata)
1379 1381
1380 1382 class TCPAddress(TraitType):
1381 1383 """A trait for an (ip, port) tuple.
1382 1384
1383 1385 This allows for both IPv4 IP addresses as well as hostnames.
1384 1386 """
1385 1387
1386 1388 default_value = ('127.0.0.1', 0)
1387 1389 info_text = 'an (ip, port) tuple'
1388 1390
1389 1391 def validate(self, obj, value):
1390 1392 if isinstance(value, tuple):
1391 1393 if len(value) == 2:
1392 1394 if isinstance(value[0], basestring) and isinstance(value[1], int):
1393 1395 port = value[1]
1394 1396 if port >= 0 and port <= 65535:
1395 1397 return value
1396 1398 self.error(obj, value)
@@ -1,703 +1,703 b''
1 1 """Session object for building, serializing, sending, and receiving messages in
2 2 IPython. The Session object supports serialization, HMAC signatures, and
3 3 metadata on messages.
4 4
5 5 Also defined here are utilities for working with Sessions:
6 6 * A SessionFactory to be used as a base class for configurables that work with
7 7 Sessions.
8 8 * A Message object for convenience that allows attribute-access to the msg dict.
9 9
10 10 Authors:
11 11
12 12 * Min RK
13 13 * Brian Granger
14 14 * Fernando Perez
15 15 """
16 16 #-----------------------------------------------------------------------------
17 17 # Copyright (C) 2010-2011 The IPython Development Team
18 18 #
19 19 # Distributed under the terms of the BSD License. The full license is in
20 20 # the file COPYING, distributed as part of this software.
21 21 #-----------------------------------------------------------------------------
22 22
23 23 #-----------------------------------------------------------------------------
24 24 # Imports
25 25 #-----------------------------------------------------------------------------
26 26
27 27 import hmac
28 28 import logging
29 29 import os
30 30 import pprint
31 31 import uuid
32 32 from datetime import datetime
33 33
34 34 try:
35 35 import cPickle
36 36 pickle = cPickle
37 37 except:
38 38 cPickle = None
39 39 import pickle
40 40
41 41 import zmq
42 42 from zmq.utils import jsonapi
43 43 from zmq.eventloop.ioloop import IOLoop
44 44 from zmq.eventloop.zmqstream import ZMQStream
45 45
46 46 from IPython.config.configurable import Configurable, LoggingConfigurable
47 47 from IPython.utils.importstring import import_item
48 48 from IPython.utils.jsonutil import extract_dates, squash_dates, date_default
49 49 from IPython.utils.traitlets import (CBytes, Unicode, Bool, Any, Instance, Set,
50 50 DottedObjectName)
51 51
52 52 #-----------------------------------------------------------------------------
53 53 # utility functions
54 54 #-----------------------------------------------------------------------------
55 55
56 56 def squash_unicode(obj):
57 57 """coerce unicode back to bytestrings."""
58 58 if isinstance(obj,dict):
59 59 for key in obj.keys():
60 60 obj[key] = squash_unicode(obj[key])
61 61 if isinstance(key, unicode):
62 62 obj[squash_unicode(key)] = obj.pop(key)
63 63 elif isinstance(obj, list):
64 64 for i,v in enumerate(obj):
65 65 obj[i] = squash_unicode(v)
66 66 elif isinstance(obj, unicode):
67 67 obj = obj.encode('utf8')
68 68 return obj
69 69
70 70 #-----------------------------------------------------------------------------
71 71 # globals and defaults
72 72 #-----------------------------------------------------------------------------
73 73 key = 'on_unknown' if jsonapi.jsonmod.__name__ == 'jsonlib' else 'default'
74 74 json_packer = lambda obj: jsonapi.dumps(obj, **{key:date_default})
75 75 json_unpacker = lambda s: extract_dates(jsonapi.loads(s))
76 76
77 77 pickle_packer = lambda o: pickle.dumps(o,-1)
78 78 pickle_unpacker = pickle.loads
79 79
80 80 default_packer = json_packer
81 81 default_unpacker = json_unpacker
82 82
83 83
84 84 DELIM=b"<IDS|MSG>"
85 85
86 86 #-----------------------------------------------------------------------------
87 87 # Classes
88 88 #-----------------------------------------------------------------------------
89 89
90 90 class SessionFactory(LoggingConfigurable):
91 91 """The Base class for configurables that have a Session, Context, logger,
92 92 and IOLoop.
93 93 """
94 94
95 95 logname = Unicode('')
96 96 def _logname_changed(self, name, old, new):
97 97 self.log = logging.getLogger(new)
98 98
99 99 # not configurable:
100 100 context = Instance('zmq.Context')
101 101 def _context_default(self):
102 102 return zmq.Context.instance()
103 103
104 104 session = Instance('IPython.zmq.session.Session')
105 105
106 106 loop = Instance('zmq.eventloop.ioloop.IOLoop', allow_none=False)
107 107 def _loop_default(self):
108 108 return IOLoop.instance()
109 109
110 110 def __init__(self, **kwargs):
111 111 super(SessionFactory, self).__init__(**kwargs)
112 112
113 113 if self.session is None:
114 114 # construct the session
115 115 self.session = Session(**kwargs)
116 116
117 117
118 118 class Message(object):
119 119 """A simple message object that maps dict keys to attributes.
120 120
121 121 A Message can be created from a dict and a dict from a Message instance
122 122 simply by calling dict(msg_obj)."""
123 123
124 124 def __init__(self, msg_dict):
125 125 dct = self.__dict__
126 126 for k, v in dict(msg_dict).iteritems():
127 127 if isinstance(v, dict):
128 128 v = Message(v)
129 129 dct[k] = v
130 130
131 131 # Having this iterator lets dict(msg_obj) work out of the box.
132 132 def __iter__(self):
133 133 return iter(self.__dict__.iteritems())
134 134
135 135 def __repr__(self):
136 136 return repr(self.__dict__)
137 137
138 138 def __str__(self):
139 139 return pprint.pformat(self.__dict__)
140 140
141 141 def __contains__(self, k):
142 142 return k in self.__dict__
143 143
144 144 def __getitem__(self, k):
145 145 return self.__dict__[k]
146 146
147 147
148 148 def msg_header(msg_id, msg_type, username, session):
149 149 date = datetime.now()
150 150 return locals()
151 151
152 152 def extract_header(msg_or_header):
153 153 """Given a message or header, return the header."""
154 154 if not msg_or_header:
155 155 return {}
156 156 try:
157 157 # See if msg_or_header is the entire message.
158 158 h = msg_or_header['header']
159 159 except KeyError:
160 160 try:
161 161 # See if msg_or_header is just the header
162 162 h = msg_or_header['msg_id']
163 163 except KeyError:
164 164 raise
165 165 else:
166 166 h = msg_or_header
167 167 if not isinstance(h, dict):
168 168 h = dict(h)
169 169 return h
170 170
171 171 class Session(Configurable):
172 172 """Object for handling serialization and sending of messages.
173 173
174 174 The Session object handles building messages and sending them
175 175 with ZMQ sockets or ZMQStream objects. Objects can communicate with each
176 176 other over the network via Session objects, and only need to work with the
177 177 dict-based IPython message spec. The Session will handle
178 178 serialization/deserialization, security, and metadata.
179 179
180 180 Sessions support configurable serialiization via packer/unpacker traits,
181 181 and signing with HMAC digests via the key/keyfile traits.
182 182
183 183 Parameters
184 184 ----------
185 185
186 186 debug : bool
187 187 whether to trigger extra debugging statements
188 188 packer/unpacker : str : 'json', 'pickle' or import_string
189 189 importstrings for methods to serialize message parts. If just
190 190 'json' or 'pickle', predefined JSON and pickle packers will be used.
191 191 Otherwise, the entire importstring must be used.
192 192
193 193 The functions must accept at least valid JSON input, and output *bytes*.
194 194
195 195 For example, to use msgpack:
196 196 packer = 'msgpack.packb', unpacker='msgpack.unpackb'
197 197 pack/unpack : callables
198 198 You can also set the pack/unpack callables for serialization directly.
199 199 session : bytes
200 200 the ID of this Session object. The default is to generate a new UUID.
201 201 username : unicode
202 202 username added to message headers. The default is to ask the OS.
203 203 key : bytes
204 204 The key used to initialize an HMAC signature. If unset, messages
205 205 will not be signed or checked.
206 206 keyfile : filepath
207 207 The file containing a key. If this is set, `key` will be initialized
208 208 to the contents of the file.
209 209
210 210 """
211 211
212 212 debug=Bool(False, config=True, help="""Debug output in the Session""")
213 213
214 214 packer = DottedObjectName('json',config=True,
215 215 help="""The name of the packer for serializing messages.
216 216 Should be one of 'json', 'pickle', or an import name
217 217 for a custom callable serializer.""")
218 218 def _packer_changed(self, name, old, new):
219 219 if new.lower() == 'json':
220 220 self.pack = json_packer
221 221 self.unpack = json_unpacker
222 222 elif new.lower() == 'pickle':
223 223 self.pack = pickle_packer
224 224 self.unpack = pickle_unpacker
225 225 else:
226 226 self.pack = import_item(str(new))
227 227
228 228 unpacker = DottedObjectName('json', config=True,
229 229 help="""The name of the unpacker for unserializing messages.
230 230 Only used with custom functions for `packer`.""")
231 231 def _unpacker_changed(self, name, old, new):
232 232 if new.lower() == 'json':
233 233 self.pack = json_packer
234 234 self.unpack = json_unpacker
235 235 elif new.lower() == 'pickle':
236 236 self.pack = pickle_packer
237 237 self.unpack = pickle_unpacker
238 238 else:
239 239 self.unpack = import_item(str(new))
240 240
241 241 session = CBytes(b'', config=True,
242 242 help="""The UUID identifying this session.""")
243 243 def _session_default(self):
244 244 return bytes(uuid.uuid4())
245 245
246 246 username = Unicode(os.environ.get('USER',u'username'), config=True,
247 247 help="""Username for the Session. Default is your system username.""")
248 248
249 249 # message signature related traits:
250 250 key = CBytes(b'', config=True,
251 251 help="""execution key, for extra authentication.""")
252 252 def _key_changed(self, name, old, new):
253 253 if new:
254 254 self.auth = hmac.HMAC(new)
255 255 else:
256 256 self.auth = None
257 257 auth = Instance(hmac.HMAC)
258 258 digest_history = Set()
259 259
260 260 keyfile = Unicode('', config=True,
261 261 help="""path to file containing execution key.""")
262 262 def _keyfile_changed(self, name, old, new):
263 263 with open(new, 'rb') as f:
264 264 self.key = f.read().strip()
265 265
266 266 pack = Any(default_packer) # the actual packer function
267 267 def _pack_changed(self, name, old, new):
268 268 if not callable(new):
269 269 raise TypeError("packer must be callable, not %s"%type(new))
270 270
271 271 unpack = Any(default_unpacker) # the actual packer function
272 272 def _unpack_changed(self, name, old, new):
273 273 # unpacker is not checked - it is assumed to be
274 274 if not callable(new):
275 275 raise TypeError("unpacker must be callable, not %s"%type(new))
276 276
277 277 def __init__(self, **kwargs):
278 278 """create a Session object
279 279
280 280 Parameters
281 281 ----------
282 282
283 283 debug : bool
284 284 whether to trigger extra debugging statements
285 285 packer/unpacker : str : 'json', 'pickle' or import_string
286 286 importstrings for methods to serialize message parts. If just
287 287 'json' or 'pickle', predefined JSON and pickle packers will be used.
288 288 Otherwise, the entire importstring must be used.
289 289
290 290 The functions must accept at least valid JSON input, and output
291 291 *bytes*.
292 292
293 293 For example, to use msgpack:
294 294 packer = 'msgpack.packb', unpacker='msgpack.unpackb'
295 295 pack/unpack : callables
296 296 You can also set the pack/unpack callables for serialization
297 297 directly.
298 298 session : bytes
299 299 the ID of this Session object. The default is to generate a new
300 300 UUID.
301 301 username : unicode
302 302 username added to message headers. The default is to ask the OS.
303 303 key : bytes
304 304 The key used to initialize an HMAC signature. If unset, messages
305 305 will not be signed or checked.
306 306 keyfile : filepath
307 307 The file containing a key. If this is set, `key` will be
308 308 initialized to the contents of the file.
309 309 """
310 310 super(Session, self).__init__(**kwargs)
311 311 self._check_packers()
312 312 self.none = self.pack({})
313 313
314 314 @property
315 315 def msg_id(self):
316 316 """always return new uuid"""
317 317 return str(uuid.uuid4())
318 318
319 319 def _check_packers(self):
320 320 """check packers for binary data and datetime support."""
321 321 pack = self.pack
322 322 unpack = self.unpack
323 323
324 324 # check simple serialization
325 325 msg = dict(a=[1,'hi'])
326 326 try:
327 327 packed = pack(msg)
328 328 except Exception:
329 329 raise ValueError("packer could not serialize a simple message")
330 330
331 331 # ensure packed message is bytes
332 332 if not isinstance(packed, bytes):
333 333 raise ValueError("message packed to %r, but bytes are required"%type(packed))
334 334
335 335 # check that unpack is pack's inverse
336 336 try:
337 337 unpacked = unpack(packed)
338 338 except Exception:
339 339 raise ValueError("unpacker could not handle the packer's output")
340 340
341 341 # check datetime support
342 342 msg = dict(t=datetime.now())
343 343 try:
344 344 unpacked = unpack(pack(msg))
345 345 except Exception:
346 346 self.pack = lambda o: pack(squash_dates(o))
347 347 self.unpack = lambda s: extract_dates(unpack(s))
348 348
349 349 def msg_header(self, msg_type):
350 350 return msg_header(self.msg_id, msg_type, self.username, self.session)
351 351
352 352 def msg(self, msg_type, content=None, parent=None, subheader=None, header=None):
353 353 """Return the nested message dict.
354 354
355 355 This format is different from what is sent over the wire. The
356 356 serialize/unserialize methods converts this nested message dict to the wire
357 357 format, which is a list of message parts.
358 358 """
359 359 msg = {}
360 360 header = self.msg_header(msg_type) if header is None else header
361 361 msg['header'] = header
362 362 msg['msg_id'] = header['msg_id']
363 363 msg['msg_type'] = header['msg_type']
364 364 msg['parent_header'] = {} if parent is None else extract_header(parent)
365 365 msg['content'] = {} if content is None else content
366 366 sub = {} if subheader is None else subheader
367 367 msg['header'].update(sub)
368 368 return msg
369 369
370 370 def sign(self, msg_list):
371 371 """Sign a message with HMAC digest. If no auth, return b''.
372 372
373 373 Parameters
374 374 ----------
375 375 msg_list : list
376 376 The [p_header,p_parent,p_content] part of the message list.
377 377 """
378 378 if self.auth is None:
379 379 return b''
380 380 h = self.auth.copy()
381 381 for m in msg_list:
382 382 h.update(m)
383 return h.hexdigest()
383 return str_to_bytes(h.hexdigest())
384 384
385 385 def serialize(self, msg, ident=None):
386 386 """Serialize the message components to bytes.
387 387
388 388 This is roughly the inverse of unserialize. The serialize/unserialize
389 389 methods work with full message lists, whereas pack/unpack work with
390 390 the individual message parts in the message list.
391 391
392 392 Parameters
393 393 ----------
394 394 msg : dict or Message
395 395 The nexted message dict as returned by the self.msg method.
396 396
397 397 Returns
398 398 -------
399 399 msg_list : list
400 400 The list of bytes objects to be sent with the format:
401 401 [ident1,ident2,...,DELIM,HMAC,p_header,p_parent,p_content,
402 402 buffer1,buffer2,...]. In this list, the p_* entities are
403 403 the packed or serialized versions, so if JSON is used, these
404 404 are uft8 encoded JSON strings.
405 405 """
406 406 content = msg.get('content', {})
407 407 if content is None:
408 408 content = self.none
409 409 elif isinstance(content, dict):
410 410 content = self.pack(content)
411 411 elif isinstance(content, bytes):
412 412 # content is already packed, as in a relayed message
413 413 pass
414 414 elif isinstance(content, unicode):
415 415 # should be bytes, but JSON often spits out unicode
416 416 content = content.encode('utf8')
417 417 else:
418 418 raise TypeError("Content incorrect type: %s"%type(content))
419 419
420 420 real_message = [self.pack(msg['header']),
421 421 self.pack(msg['parent_header']),
422 422 content
423 423 ]
424 424
425 425 to_send = []
426 426
427 427 if isinstance(ident, list):
428 428 # accept list of idents
429 429 to_send.extend(ident)
430 430 elif ident is not None:
431 431 to_send.append(ident)
432 432 to_send.append(DELIM)
433 433
434 434 signature = self.sign(real_message)
435 435 to_send.append(signature)
436 436
437 437 to_send.extend(real_message)
438 438
439 439 return to_send
440 440
441 441 def send(self, stream, msg_or_type, content=None, parent=None, ident=None,
442 442 buffers=None, subheader=None, track=False, header=None):
443 443 """Build and send a message via stream or socket.
444 444
445 445 The message format used by this function internally is as follows:
446 446
447 447 [ident1,ident2,...,DELIM,HMAC,p_header,p_parent,p_content,
448 448 buffer1,buffer2,...]
449 449
450 450 The serialize/unserialize methods convert the nested message dict into this
451 451 format.
452 452
453 453 Parameters
454 454 ----------
455 455
456 456 stream : zmq.Socket or ZMQStream
457 457 The socket-like object used to send the data.
458 458 msg_or_type : str or Message/dict
459 459 Normally, msg_or_type will be a msg_type unless a message is being
460 460 sent more than once. If a header is supplied, this can be set to
461 461 None and the msg_type will be pulled from the header.
462 462
463 463 content : dict or None
464 464 The content of the message (ignored if msg_or_type is a message).
465 465 header : dict or None
466 466 The header dict for the message (ignores if msg_to_type is a message).
467 467 parent : Message or dict or None
468 468 The parent or parent header describing the parent of this message
469 469 (ignored if msg_or_type is a message).
470 470 ident : bytes or list of bytes
471 471 The zmq.IDENTITY routing path.
472 472 subheader : dict or None
473 473 Extra header keys for this message's header (ignored if msg_or_type
474 474 is a message).
475 475 buffers : list or None
476 476 The already-serialized buffers to be appended to the message.
477 477 track : bool
478 478 Whether to track. Only for use with Sockets, because ZMQStream
479 479 objects cannot track messages.
480 480
481 481 Returns
482 482 -------
483 483 msg : dict
484 484 The constructed message.
485 485 (msg,tracker) : (dict, MessageTracker)
486 486 if track=True, then a 2-tuple will be returned,
487 487 the first element being the constructed
488 488 message, and the second being the MessageTracker
489 489
490 490 """
491 491
492 492 if not isinstance(stream, (zmq.Socket, ZMQStream)):
493 493 raise TypeError("stream must be Socket or ZMQStream, not %r"%type(stream))
494 494 elif track and isinstance(stream, ZMQStream):
495 495 raise TypeError("ZMQStream cannot track messages")
496 496
497 497 if isinstance(msg_or_type, (Message, dict)):
498 498 # We got a Message or message dict, not a msg_type so don't
499 499 # build a new Message.
500 500 msg = msg_or_type
501 501 else:
502 502 msg = self.msg(msg_or_type, content=content, parent=parent,
503 503 subheader=subheader, header=header)
504 504
505 505 buffers = [] if buffers is None else buffers
506 506 to_send = self.serialize(msg, ident)
507 507 flag = 0
508 508 if buffers:
509 509 flag = zmq.SNDMORE
510 510 _track = False
511 511 else:
512 512 _track=track
513 513 if track:
514 514 tracker = stream.send_multipart(to_send, flag, copy=False, track=_track)
515 515 else:
516 516 tracker = stream.send_multipart(to_send, flag, copy=False)
517 517 for b in buffers[:-1]:
518 518 stream.send(b, flag, copy=False)
519 519 if buffers:
520 520 if track:
521 521 tracker = stream.send(buffers[-1], copy=False, track=track)
522 522 else:
523 523 tracker = stream.send(buffers[-1], copy=False)
524 524
525 525 # omsg = Message(msg)
526 526 if self.debug:
527 527 pprint.pprint(msg)
528 528 pprint.pprint(to_send)
529 529 pprint.pprint(buffers)
530 530
531 531 msg['tracker'] = tracker
532 532
533 533 return msg
534 534
535 535 def send_raw(self, stream, msg_list, flags=0, copy=True, ident=None):
536 536 """Send a raw message via ident path.
537 537
538 538 This method is used to send a already serialized message.
539 539
540 540 Parameters
541 541 ----------
542 542 stream : ZMQStream or Socket
543 543 The ZMQ stream or socket to use for sending the message.
544 544 msg_list : list
545 545 The serialized list of messages to send. This only includes the
546 546 [p_header,p_parent,p_content,buffer1,buffer2,...] portion of
547 547 the message.
548 548 ident : ident or list
549 549 A single ident or a list of idents to use in sending.
550 550 """
551 551 to_send = []
552 552 if isinstance(ident, bytes):
553 553 ident = [ident]
554 554 if ident is not None:
555 555 to_send.extend(ident)
556 556
557 557 to_send.append(DELIM)
558 558 to_send.append(self.sign(msg_list))
559 559 to_send.extend(msg_list)
560 560 stream.send_multipart(msg_list, flags, copy=copy)
561 561
562 562 def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True):
563 563 """Receive and unpack a message.
564 564
565 565 Parameters
566 566 ----------
567 567 socket : ZMQStream or Socket
568 568 The socket or stream to use in receiving.
569 569
570 570 Returns
571 571 -------
572 572 [idents], msg
573 573 [idents] is a list of idents and msg is a nested message dict of
574 574 same format as self.msg returns.
575 575 """
576 576 if isinstance(socket, ZMQStream):
577 577 socket = socket.socket
578 578 try:
579 579 msg_list = socket.recv_multipart(mode)
580 580 except zmq.ZMQError as e:
581 581 if e.errno == zmq.EAGAIN:
582 582 # We can convert EAGAIN to None as we know in this case
583 583 # recv_multipart won't return None.
584 584 return None,None
585 585 else:
586 586 raise
587 587 # split multipart message into identity list and message dict
588 588 # invalid large messages can cause very expensive string comparisons
589 589 idents, msg_list = self.feed_identities(msg_list, copy)
590 590 try:
591 591 return idents, self.unserialize(msg_list, content=content, copy=copy)
592 592 except Exception as e:
593 593 # TODO: handle it
594 594 raise e
595 595
596 596 def feed_identities(self, msg_list, copy=True):
597 597 """Split the identities from the rest of the message.
598 598
599 599 Feed until DELIM is reached, then return the prefix as idents and
600 600 remainder as msg_list. This is easily broken by setting an IDENT to DELIM,
601 601 but that would be silly.
602 602
603 603 Parameters
604 604 ----------
605 605 msg_list : a list of Message or bytes objects
606 606 The message to be split.
607 607 copy : bool
608 608 flag determining whether the arguments are bytes or Messages
609 609
610 610 Returns
611 611 -------
612 612 (idents, msg_list) : two lists
613 613 idents will always be a list of bytes, each of which is a ZMQ
614 614 identity. msg_list will be a list of bytes or zmq.Messages of the
615 615 form [HMAC,p_header,p_parent,p_content,buffer1,buffer2,...] and
616 616 should be unpackable/unserializable via self.unserialize at this
617 617 point.
618 618 """
619 619 if copy:
620 620 idx = msg_list.index(DELIM)
621 621 return msg_list[:idx], msg_list[idx+1:]
622 622 else:
623 623 failed = True
624 624 for idx,m in enumerate(msg_list):
625 625 if m.bytes == DELIM:
626 626 failed = False
627 627 break
628 628 if failed:
629 629 raise ValueError("DELIM not in msg_list")
630 630 idents, msg_list = msg_list[:idx], msg_list[idx+1:]
631 631 return [m.bytes for m in idents], msg_list
632 632
633 633 def unserialize(self, msg_list, content=True, copy=True):
634 634 """Unserialize a msg_list to a nested message dict.
635 635
636 636 This is roughly the inverse of serialize. The serialize/unserialize
637 637 methods work with full message lists, whereas pack/unpack work with
638 638 the individual message parts in the message list.
639 639
640 640 Parameters:
641 641 -----------
642 642 msg_list : list of bytes or Message objects
643 643 The list of message parts of the form [HMAC,p_header,p_parent,
644 644 p_content,buffer1,buffer2,...].
645 645 content : bool (True)
646 646 Whether to unpack the content dict (True), or leave it packed
647 647 (False).
648 648 copy : bool (True)
649 649 Whether to return the bytes (True), or the non-copying Message
650 650 object in each place (False).
651 651
652 652 Returns
653 653 -------
654 654 msg : dict
655 655 The nested message dict with top-level keys [header, parent_header,
656 656 content, buffers].
657 657 """
658 658 minlen = 4
659 659 message = {}
660 660 if not copy:
661 661 for i in range(minlen):
662 662 msg_list[i] = msg_list[i].bytes
663 663 if self.auth is not None:
664 664 signature = msg_list[0]
665 665 if not signature:
666 666 raise ValueError("Unsigned Message")
667 667 if signature in self.digest_history:
668 668 raise ValueError("Duplicate Signature: %r"%signature)
669 669 self.digest_history.add(signature)
670 670 check = self.sign(msg_list[1:4])
671 671 if not signature == check:
672 672 raise ValueError("Invalid Signature: %r"%signature)
673 673 if not len(msg_list) >= minlen:
674 674 raise TypeError("malformed message, must have at least %i elements"%minlen)
675 675 header = self.unpack(msg_list[1])
676 676 message['header'] = header
677 677 message['msg_id'] = header['msg_id']
678 678 message['msg_type'] = header['msg_type']
679 679 message['parent_header'] = self.unpack(msg_list[2])
680 680 if content:
681 681 message['content'] = self.unpack(msg_list[3])
682 682 else:
683 683 message['content'] = msg_list[3]
684 684
685 685 message['buffers'] = msg_list[4:]
686 686 return message
687 687
688 688 def test_msg2obj():
689 689 am = dict(x=1)
690 690 ao = Message(am)
691 691 assert ao.x == am['x']
692 692
693 693 am['y'] = dict(z=1)
694 694 ao = Message(am)
695 695 assert ao.y.z == am['y']['z']
696 696
697 697 k1, k2 = 'y', 'z'
698 698 assert ao[k1][k2] == am[k1][k2]
699 699
700 700 am2 = dict(ao)
701 701 assert am['x'] == am2['x']
702 702 assert am['y']['z'] == am2['y']['z']
703 703
General Comments 0
You need to be logged in to leave comments. Login now