##// END OF EJS Templates
remove now-obsolete use of skip_doctest outside core
Min RK -
Show More
@@ -1,707 +1,702 b''
1 1 # -*- coding: utf-8 -*-
2 2 """
3 3 ======
4 4 Rmagic
5 5 ======
6 6
7 7 Magic command interface for interactive work with R via rpy2
8 8
9 9 .. note::
10 10
11 11 The ``rpy2`` package needs to be installed separately. It
12 12 can be obtained using ``easy_install`` or ``pip``.
13 13
14 14 You will also need a working copy of R.
15 15
16 16 Usage
17 17 =====
18 18
19 19 To enable the magics below, execute ``%load_ext rmagic``.
20 20
21 21 ``%R``
22 22
23 23 {R_DOC}
24 24
25 25 ``%Rpush``
26 26
27 27 {RPUSH_DOC}
28 28
29 29 ``%Rpull``
30 30
31 31 {RPULL_DOC}
32 32
33 33 ``%Rget``
34 34
35 35 {RGET_DOC}
36 36
37 37 """
38 38 from __future__ import print_function
39 39
40 40 #-----------------------------------------------------------------------------
41 41 # Copyright (C) 2012 The IPython Development Team
42 42 #
43 43 # Distributed under the terms of the BSD License. The full license is in
44 44 # the file COPYING, distributed as part of this software.
45 45 #-----------------------------------------------------------------------------
46 46
47 47 import sys
48 48 import tempfile
49 49 from glob import glob
50 50 from shutil import rmtree
51 51 import warnings
52 52
53 53 # numpy and rpy2 imports
54 54
55 55 import numpy as np
56 56
57 57 import rpy2.rinterface as ri
58 58 import rpy2.robjects as ro
59 59 try:
60 60 from rpy2.robjects import pandas2ri
61 61 pandas2ri.activate()
62 62 except ImportError:
63 63 pandas2ri = None
64 64 from rpy2.robjects import numpy2ri
65 65 numpy2ri.activate()
66 66
67 67 # IPython imports
68 68
69 69 from IPython.core.displaypub import publish_display_data
70 70 from IPython.core.magic import (Magics, magics_class, line_magic,
71 71 line_cell_magic, needs_local_scope)
72 from IPython.testing.skipdoctest import skip_doctest
73 72 from IPython.core.magic_arguments import (
74 73 argument, magic_arguments, parse_argstring
75 74 )
76 75 from simplegeneric import generic
77 76 from IPython.utils.py3compat import (str_to_unicode, unicode_to_str, PY3,
78 77 unicode_type)
79 78 from IPython.utils.text import dedent
80 79
81 80 class RInterpreterError(ri.RRuntimeError):
82 81 """An error when running R code in a %%R magic cell."""
83 82 def __init__(self, line, err, stdout):
84 83 self.line = line
85 84 self.err = err.rstrip()
86 85 self.stdout = stdout.rstrip()
87 86
88 87 def __unicode__(self):
89 88 s = 'Failed to parse and evaluate line %r.\nR error message: %r' % \
90 89 (self.line, self.err)
91 90 if self.stdout and (self.stdout != self.err):
92 91 s += '\nR stdout:\n' + self.stdout
93 92 return s
94 93
95 94 if PY3:
96 95 __str__ = __unicode__
97 96 else:
98 97 def __str__(self):
99 98 return unicode_to_str(unicode(self), 'utf-8')
100 99
101 100 def Rconverter(Robj, dataframe=False):
102 101 """
103 102 Convert an object in R's namespace to one suitable
104 103 for ipython's namespace.
105 104
106 105 For a data.frame, it tries to return a structured array.
107 106 It first checks for colnames, then names.
108 107 If all are NULL, it returns np.asarray(Robj), else
109 108 it tries to construct a recarray
110 109
111 110 Parameters
112 111 ----------
113 112
114 113 Robj: an R object returned from rpy2
115 114 """
116 115 is_data_frame = ro.r('is.data.frame')
117 116 colnames = ro.r('colnames')
118 117 rownames = ro.r('rownames') # with pandas, these could be used for the index
119 118 names = ro.r('names')
120 119
121 120 if dataframe:
122 121 as_data_frame = ro.r('as.data.frame')
123 122 cols = colnames(Robj)
124 123 _names = names(Robj)
125 124 if cols != ri.NULL:
126 125 Robj = as_data_frame(Robj)
127 126 names = tuple(np.array(cols))
128 127 elif _names != ri.NULL:
129 128 names = tuple(np.array(_names))
130 129 else: # failed to find names
131 130 return np.asarray(Robj)
132 131 Robj = np.rec.fromarrays(Robj, names = names)
133 132 return np.asarray(Robj)
134 133
135 134 @generic
136 135 def pyconverter(pyobj):
137 136 """Convert Python objects to R objects. Add types using the decorator:
138 137
139 138 @pyconverter.when_type
140 139 """
141 140 return pyobj
142 141
143 142 # The default conversion for lists seems to make them a nested list. That has
144 143 # some advantages, but is rarely convenient, so for interactive use, we convert
145 144 # lists to a numpy array, which becomes an R vector.
146 145 @pyconverter.when_type(list)
147 146 def pyconverter_list(pyobj):
148 147 return np.asarray(pyobj)
149 148
150 149 if pandas2ri is None:
151 150 # pandas2ri was new in rpy2 2.3.3, so for now we'll fallback to pandas'
152 151 # conversion function.
153 152 try:
154 153 from pandas import DataFrame
155 154 from pandas.rpy.common import convert_to_r_dataframe
156 155 @pyconverter.when_type(DataFrame)
157 156 def pyconverter_dataframe(pyobj):
158 157 return convert_to_r_dataframe(pyobj, strings_as_factors=True)
159 158 except ImportError:
160 159 pass
161 160
162 161 @magics_class
163 162 class RMagics(Magics):
164 163 """A set of magics useful for interactive work with R via rpy2.
165 164 """
166 165
167 166 def __init__(self, shell, Rconverter=Rconverter,
168 167 pyconverter=pyconverter,
169 168 cache_display_data=False):
170 169 """
171 170 Parameters
172 171 ----------
173 172
174 173 shell : IPython shell
175 174
176 175 Rconverter : callable
177 176 To be called on values taken from R before putting them in the
178 177 IPython namespace.
179 178
180 179 pyconverter : callable
181 180 To be called on values in ipython namespace before
182 181 assigning to variables in rpy2.
183 182
184 183 cache_display_data : bool
185 184 If True, the published results of the final call to R are
186 185 cached in the variable 'display_cache'.
187 186
188 187 """
189 188 super(RMagics, self).__init__(shell)
190 189 self.cache_display_data = cache_display_data
191 190
192 191 self.r = ro.R()
193 192
194 193 self.Rstdout_cache = []
195 194 self.pyconverter = pyconverter
196 195 self.Rconverter = Rconverter
197 196
198 197 def eval(self, line):
199 198 '''
200 199 Parse and evaluate a line of R code with rpy2.
201 200 Returns the output to R's stdout() connection,
202 201 the value generated by evaluating the code, and a
203 202 boolean indicating whether the return value would be
204 203 visible if the line of code were evaluated in an R REPL.
205 204
206 205 R Code evaluation and visibility determination are
207 206 done via an R call of the form withVisible({<code>})
208 207
209 208 '''
210 209 old_writeconsole = ri.get_writeconsole()
211 210 ri.set_writeconsole(self.write_console)
212 211 try:
213 212 res = ro.r("withVisible({%s\n})" % line)
214 213 value = res[0] #value (R object)
215 214 visible = ro.conversion.ri2py(res[1])[0] #visible (boolean)
216 215 except (ri.RRuntimeError, ValueError) as exception:
217 216 warning_or_other_msg = self.flush() # otherwise next return seems to have copy of error
218 217 raise RInterpreterError(line, str_to_unicode(str(exception)), warning_or_other_msg)
219 218 text_output = self.flush()
220 219 ri.set_writeconsole(old_writeconsole)
221 220 return text_output, value, visible
222 221
223 222 def write_console(self, output):
224 223 '''
225 224 A hook to capture R's stdout in a cache.
226 225 '''
227 226 self.Rstdout_cache.append(output)
228 227
229 228 def flush(self):
230 229 '''
231 230 Flush R's stdout cache to a string, returning the string.
232 231 '''
233 232 value = ''.join([str_to_unicode(s, 'utf-8') for s in self.Rstdout_cache])
234 233 self.Rstdout_cache = []
235 234 return value
236 235
237 @skip_doctest
238 236 @needs_local_scope
239 237 @line_magic
240 238 def Rpush(self, line, local_ns=None):
241 239 '''
242 240 A line-level magic for R that pushes
243 241 variables from python to rpy2. The line should be made up
244 242 of whitespace separated variable names in the IPython
245 243 namespace::
246 244
247 245 In [7]: import numpy as np
248 246
249 247 In [8]: X = np.array([4.5,6.3,7.9])
250 248
251 249 In [9]: X.mean()
252 250 Out[9]: 6.2333333333333343
253 251
254 252 In [10]: %Rpush X
255 253
256 254 In [11]: %R mean(X)
257 255 Out[11]: array([ 6.23333333])
258 256
259 257 '''
260 258 if local_ns is None:
261 259 local_ns = {}
262 260
263 261 inputs = line.split(' ')
264 262 for input in inputs:
265 263 try:
266 264 val = local_ns[input]
267 265 except KeyError:
268 266 try:
269 267 val = self.shell.user_ns[input]
270 268 except KeyError:
271 269 # reraise the KeyError as a NameError so that it looks like
272 270 # the standard python behavior when you use an unnamed
273 271 # variable
274 272 raise NameError("name '%s' is not defined" % input)
275 273
276 274 self.r.assign(input, self.pyconverter(val))
277 275
278 @skip_doctest
279 276 @magic_arguments()
280 277 @argument(
281 278 '-d', '--as_dataframe', action='store_true',
282 279 default=False,
283 280 help='Convert objects to data.frames before returning to ipython.'
284 281 )
285 282 @argument(
286 283 'outputs',
287 284 nargs='*',
288 285 )
289 286 @line_magic
290 287 def Rpull(self, line):
291 288 '''
292 289 A line-level magic for R that pulls
293 290 variables from python to rpy2::
294 291
295 292 In [18]: _ = %R x = c(3,4,6.7); y = c(4,6,7); z = c('a',3,4)
296 293
297 294 In [19]: %Rpull x y z
298 295
299 296 In [20]: x
300 297 Out[20]: array([ 3. , 4. , 6.7])
301 298
302 299 In [21]: y
303 300 Out[21]: array([ 4., 6., 7.])
304 301
305 302 In [22]: z
306 303 Out[22]:
307 304 array(['a', '3', '4'],
308 305 dtype='|S1')
309 306
310 307
311 308 If --as_dataframe, then each object is returned as a structured array
312 309 after first passed through "as.data.frame" in R before
313 310 being calling self.Rconverter.
314 311 This is useful when a structured array is desired as output, or
315 312 when the object in R has mixed data types.
316 313 See the %%R docstring for more examples.
317 314
318 315 Notes
319 316 -----
320 317
321 318 Beware that R names can have '.' so this is not fool proof.
322 319 To avoid this, don't name your R objects with '.'s...
323 320
324 321 '''
325 322 args = parse_argstring(self.Rpull, line)
326 323 outputs = args.outputs
327 324 for output in outputs:
328 325 self.shell.push({output:self.Rconverter(self.r(output),dataframe=args.as_dataframe)})
329 326
330 @skip_doctest
331 327 @magic_arguments()
332 328 @argument(
333 329 '-d', '--as_dataframe', action='store_true',
334 330 default=False,
335 331 help='Convert objects to data.frames before returning to ipython.'
336 332 )
337 333 @argument(
338 334 'output',
339 335 nargs=1,
340 336 type=str,
341 337 )
342 338 @line_magic
343 339 def Rget(self, line):
344 340 '''
345 341 Return an object from rpy2, possibly as a structured array (if possible).
346 342 Similar to Rpull except only one argument is accepted and the value is
347 343 returned rather than pushed to self.shell.user_ns::
348 344
349 345 In [3]: dtype=[('x', '<i4'), ('y', '<f8'), ('z', '|S1')]
350 346
351 347 In [4]: datapy = np.array([(1, 2.9, 'a'), (2, 3.5, 'b'), (3, 2.1, 'c'), (4, 5, 'e')], dtype=dtype)
352 348
353 349 In [5]: %R -i datapy
354 350
355 351 In [6]: %Rget datapy
356 352 Out[6]:
357 353 array([['1', '2', '3', '4'],
358 354 ['2', '3', '2', '5'],
359 355 ['a', 'b', 'c', 'e']],
360 356 dtype='|S1')
361 357
362 358 In [7]: %Rget -d datapy
363 359 Out[7]:
364 360 array([(1, 2.9, 'a'), (2, 3.5, 'b'), (3, 2.1, 'c'), (4, 5.0, 'e')],
365 361 dtype=[('x', '<i4'), ('y', '<f8'), ('z', '|S1')])
366 362
367 363 '''
368 364 args = parse_argstring(self.Rget, line)
369 365 output = args.output
370 366 return self.Rconverter(self.r(output[0]),dataframe=args.as_dataframe)
371 367
372 368
373 @skip_doctest
374 369 @magic_arguments()
375 370 @argument(
376 371 '-i', '--input', action='append',
377 372 help='Names of input variable from shell.user_ns to be assigned to R variables of the same names after calling self.pyconverter. Multiple names can be passed separated only by commas with no whitespace.'
378 373 )
379 374 @argument(
380 375 '-o', '--output', action='append',
381 376 help='Names of variables to be pushed from rpy2 to shell.user_ns after executing cell body and applying self.Rconverter. Multiple names can be passed separated only by commas with no whitespace.'
382 377 )
383 378 @argument(
384 379 '-w', '--width', type=int,
385 380 help='Width of png plotting device sent as an argument to *png* in R.'
386 381 )
387 382 @argument(
388 383 '-h', '--height', type=int,
389 384 help='Height of png plotting device sent as an argument to *png* in R.'
390 385 )
391 386
392 387 @argument(
393 388 '-d', '--dataframe', action='append',
394 389 help='Convert these objects to data.frames and return as structured arrays.'
395 390 )
396 391 @argument(
397 392 '-u', '--units', type=unicode_type, choices=["px", "in", "cm", "mm"],
398 393 help='Units of png plotting device sent as an argument to *png* in R. One of ["px", "in", "cm", "mm"].'
399 394 )
400 395 @argument(
401 396 '-r', '--res', type=int,
402 397 help='Resolution of png plotting device sent as an argument to *png* in R. Defaults to 72 if *units* is one of ["in", "cm", "mm"].'
403 398 )
404 399 @argument(
405 400 '-p', '--pointsize', type=int,
406 401 help='Pointsize of png plotting device sent as an argument to *png* in R.'
407 402 )
408 403 @argument(
409 404 '-b', '--bg',
410 405 help='Background of png plotting device sent as an argument to *png* in R.'
411 406 )
412 407 @argument(
413 408 '-n', '--noreturn',
414 409 help='Force the magic to not return anything.',
415 410 action='store_true',
416 411 default=False
417 412 )
418 413 @argument(
419 414 'code',
420 415 nargs='*',
421 416 )
422 417 @needs_local_scope
423 418 @line_cell_magic
424 419 def R(self, line, cell=None, local_ns=None):
425 420 '''
426 421 Execute code in R, and pull some of the results back into the Python namespace.
427 422
428 423 In line mode, this will evaluate an expression and convert the returned value to a Python object.
429 424 The return value is determined by rpy2's behaviour of returning the result of evaluating the
430 425 final line.
431 426
432 427 Multiple R lines can be executed by joining them with semicolons::
433 428
434 429 In [9]: %R X=c(1,4,5,7); sd(X); mean(X)
435 430 Out[9]: array([ 4.25])
436 431
437 432 In cell mode, this will run a block of R code. The resulting value
438 433 is printed if it would printed be when evaluating the same code
439 434 within a standard R REPL.
440 435
441 436 Nothing is returned to python by default in cell mode::
442 437
443 438 In [10]: %%R
444 439 ....: Y = c(2,4,3,9)
445 440 ....: summary(lm(Y~X))
446 441
447 442 Call:
448 443 lm(formula = Y ~ X)
449 444
450 445 Residuals:
451 446 1 2 3 4
452 447 0.88 -0.24 -2.28 1.64
453 448
454 449 Coefficients:
455 450 Estimate Std. Error t value Pr(>|t|)
456 451 (Intercept) 0.0800 2.3000 0.035 0.975
457 452 X 1.0400 0.4822 2.157 0.164
458 453
459 454 Residual standard error: 2.088 on 2 degrees of freedom
460 455 Multiple R-squared: 0.6993,Adjusted R-squared: 0.549
461 456 F-statistic: 4.651 on 1 and 2 DF, p-value: 0.1638
462 457
463 458 In the notebook, plots are published as the output of the cell::
464 459
465 460 %R plot(X, Y)
466 461
467 462 will create a scatter plot of X bs Y.
468 463
469 464 If cell is not None and line has some R code, it is prepended to
470 465 the R code in cell.
471 466
472 467 Objects can be passed back and forth between rpy2 and python via the -i -o flags in line::
473 468
474 469 In [14]: Z = np.array([1,4,5,10])
475 470
476 471 In [15]: %R -i Z mean(Z)
477 472 Out[15]: array([ 5.])
478 473
479 474
480 475 In [16]: %R -o W W=Z*mean(Z)
481 476 Out[16]: array([ 5., 20., 25., 50.])
482 477
483 478 In [17]: W
484 479 Out[17]: array([ 5., 20., 25., 50.])
485 480
486 481 The return value is determined by these rules:
487 482
488 483 * If the cell is not None, the magic returns None.
489 484
490 485 * If the cell evaluates as False, the resulting value is returned
491 486 unless the final line prints something to the console, in
492 487 which case None is returned.
493 488
494 489 * If the final line results in a NULL value when evaluated
495 490 by rpy2, then None is returned.
496 491
497 492 * No attempt is made to convert the final value to a structured array.
498 493 Use the --dataframe flag or %Rget to push / return a structured array.
499 494
500 495 * If the -n flag is present, there is no return value.
501 496
502 497 * A trailing ';' will also result in no return value as the last
503 498 value in the line is an empty string.
504 499
505 500 The --dataframe argument will attempt to return structured arrays.
506 501 This is useful for dataframes with
507 502 mixed data types. Note also that for a data.frame,
508 503 if it is returned as an ndarray, it is transposed::
509 504
510 505 In [18]: dtype=[('x', '<i4'), ('y', '<f8'), ('z', '|S1')]
511 506
512 507 In [19]: datapy = np.array([(1, 2.9, 'a'), (2, 3.5, 'b'), (3, 2.1, 'c'), (4, 5, 'e')], dtype=dtype)
513 508
514 509 In [20]: %%R -o datar
515 510 datar = datapy
516 511 ....:
517 512
518 513 In [21]: datar
519 514 Out[21]:
520 515 array([['1', '2', '3', '4'],
521 516 ['2', '3', '2', '5'],
522 517 ['a', 'b', 'c', 'e']],
523 518 dtype='|S1')
524 519
525 520 In [22]: %%R -d datar
526 521 datar = datapy
527 522 ....:
528 523
529 524 In [23]: datar
530 525 Out[23]:
531 526 array([(1, 2.9, 'a'), (2, 3.5, 'b'), (3, 2.1, 'c'), (4, 5.0, 'e')],
532 527 dtype=[('x', '<i4'), ('y', '<f8'), ('z', '|S1')])
533 528
534 529 The --dataframe argument first tries colnames, then names.
535 530 If both are NULL, it returns an ndarray (i.e. unstructured)::
536 531
537 532 In [1]: %R mydata=c(4,6,8.3); NULL
538 533
539 534 In [2]: %R -d mydata
540 535
541 536 In [3]: mydata
542 537 Out[3]: array([ 4. , 6. , 8.3])
543 538
544 539 In [4]: %R names(mydata) = c('a','b','c'); NULL
545 540
546 541 In [5]: %R -d mydata
547 542
548 543 In [6]: mydata
549 544 Out[6]:
550 545 array((4.0, 6.0, 8.3),
551 546 dtype=[('a', '<f8'), ('b', '<f8'), ('c', '<f8')])
552 547
553 548 In [7]: %R -o mydata
554 549
555 550 In [8]: mydata
556 551 Out[8]: array([ 4. , 6. , 8.3])
557 552
558 553 '''
559 554
560 555 args = parse_argstring(self.R, line)
561 556
562 557 # arguments 'code' in line are prepended to
563 558 # the cell lines
564 559
565 560 if cell is None:
566 561 code = ''
567 562 return_output = True
568 563 line_mode = True
569 564 else:
570 565 code = cell
571 566 return_output = False
572 567 line_mode = False
573 568
574 569 code = ' '.join(args.code) + code
575 570
576 571 # if there is no local namespace then default to an empty dict
577 572 if local_ns is None:
578 573 local_ns = {}
579 574
580 575 if args.input:
581 576 for input in ','.join(args.input).split(','):
582 577 try:
583 578 val = local_ns[input]
584 579 except KeyError:
585 580 try:
586 581 val = self.shell.user_ns[input]
587 582 except KeyError:
588 583 raise NameError("name '%s' is not defined" % input)
589 584 self.r.assign(input, self.pyconverter(val))
590 585
591 586 if getattr(args, 'units') is not None:
592 587 if args.units != "px" and getattr(args, 'res') is None:
593 588 args.res = 72
594 589 args.units = '"%s"' % args.units
595 590
596 591 png_argdict = dict([(n, getattr(args, n)) for n in ['units', 'res', 'height', 'width', 'bg', 'pointsize']])
597 592 png_args = ','.join(['%s=%s' % (o,v) for o, v in png_argdict.items() if v is not None])
598 593 # execute the R code in a temporary directory
599 594
600 595 tmpd = tempfile.mkdtemp()
601 596 self.r('png("%s/Rplots%%03d.png",%s)' % (tmpd.replace('\\', '/'), png_args))
602 597
603 598 text_output = ''
604 599 try:
605 600 if line_mode:
606 601 for line in code.split(';'):
607 602 text_result, result, visible = self.eval(line)
608 603 text_output += text_result
609 604 if text_result:
610 605 # the last line printed something to the console so we won't return it
611 606 return_output = False
612 607 else:
613 608 text_result, result, visible = self.eval(code)
614 609 text_output += text_result
615 610 if visible:
616 611 old_writeconsole = ri.get_writeconsole()
617 612 ri.set_writeconsole(self.write_console)
618 613 ro.r.show(result)
619 614 text_output += self.flush()
620 615 ri.set_writeconsole(old_writeconsole)
621 616
622 617 except RInterpreterError as e:
623 618 print(e.stdout)
624 619 if not e.stdout.endswith(e.err):
625 620 print(e.err)
626 621 rmtree(tmpd)
627 622 return
628 623 finally:
629 624 self.r('dev.off()')
630 625
631 626 # read out all the saved .png files
632 627
633 628 images = [open(imgfile, 'rb').read() for imgfile in glob("%s/Rplots*png" % tmpd)]
634 629
635 630 # now publish the images
636 631 # mimicking IPython/zmq/pylab/backend_inline.py
637 632 fmt = 'png'
638 633 mimetypes = { 'png' : 'image/png', 'svg' : 'image/svg+xml' }
639 634 mime = mimetypes[fmt]
640 635
641 636 # publish the printed R objects, if any
642 637
643 638 display_data = []
644 639 if text_output:
645 640 display_data.append(('RMagic.R', {'text/plain':text_output}))
646 641
647 642 # flush text streams before sending figures, helps a little with output
648 643 for image in images:
649 644 # synchronization in the console (though it's a bandaid, not a real sln)
650 645 sys.stdout.flush(); sys.stderr.flush()
651 646 display_data.append(('RMagic.R', {mime: image}))
652 647
653 648 # kill the temporary directory
654 649 rmtree(tmpd)
655 650
656 651 # try to turn every output into a numpy array
657 652 # this means that output are assumed to be castable
658 653 # as numpy arrays
659 654
660 655 if args.output:
661 656 for output in ','.join(args.output).split(','):
662 657 self.shell.push({output:self.Rconverter(self.r(output), dataframe=False)})
663 658
664 659 if args.dataframe:
665 660 for output in ','.join(args.dataframe).split(','):
666 661 self.shell.push({output:self.Rconverter(self.r(output), dataframe=True)})
667 662
668 663 for tag, disp_d in display_data:
669 664 publish_display_data(data=disp_d, source=tag)
670 665
671 666 # this will keep a reference to the display_data
672 667 # which might be useful to other objects who happen to use
673 668 # this method
674 669
675 670 if self.cache_display_data:
676 671 self.display_cache = display_data
677 672
678 673 # if in line mode and return_output, return the result as an ndarray
679 674 if return_output and not args.noreturn:
680 675 if result != ri.NULL:
681 676 return self.Rconverter(result, dataframe=False)
682 677
683 678 __doc__ = __doc__.format(
684 679 R_DOC = dedent(RMagics.R.__doc__),
685 680 RPUSH_DOC = dedent(RMagics.Rpush.__doc__),
686 681 RPULL_DOC = dedent(RMagics.Rpull.__doc__),
687 682 RGET_DOC = dedent(RMagics.Rget.__doc__)
688 683 )
689 684
690 685
691 686 def load_ipython_extension(ip):
692 687 """Load the extension in IPython."""
693 688 warnings.warn("The rmagic extension in IPython is deprecated in favour of "
694 689 "rpy2.ipython. If available, that will be loaded instead.\n"
695 690 "http://rpy.sourceforge.net/")
696 691 try:
697 692 import rpy2.ipython
698 693 except ImportError:
699 694 pass # Fall back to our own implementation for now
700 695 else:
701 696 return rpy2.ipython.load_ipython_extension(ip)
702 697
703 698 ip.register_magics(RMagics)
704 699 # Initialising rpy2 interferes with readline. Since, at this point, we've
705 700 # probably just loaded rpy2, we reset the delimiters. See issue gh-2759.
706 701 if ip.has_readline:
707 702 ip.readline.set_completer_delims(ip.readline_delims)
@@ -1,243 +1,241 b''
1 1 # -*- coding: utf-8 -*-
2 2 """
3 3 %store magic for lightweight persistence.
4 4
5 5 Stores variables, aliases and macros in IPython's database.
6 6
7 7 To automatically restore stored variables at startup, add this to your
8 8 :file:`ipython_config.py` file::
9 9
10 10 c.StoreMagics.autorestore = True
11 11 """
12 12 from __future__ import print_function
13 13 #-----------------------------------------------------------------------------
14 14 # Copyright (c) 2012, The IPython Development Team.
15 15 #
16 16 # Distributed under the terms of the Modified BSD License.
17 17 #
18 18 # The full license is in the file COPYING.txt, distributed with this software.
19 19 #-----------------------------------------------------------------------------
20 20
21 21 #-----------------------------------------------------------------------------
22 22 # Imports
23 23 #-----------------------------------------------------------------------------
24 24
25 25 # Stdlib
26 26 import inspect, os, sys, textwrap
27 27
28 28 # Our own
29 29 from IPython.core.error import UsageError
30 30 from IPython.core.magic import Magics, magics_class, line_magic
31 from IPython.testing.skipdoctest import skip_doctest
32 31 from IPython.utils.traitlets import Bool
33 32 from IPython.utils.py3compat import string_types
34 33
35 34 #-----------------------------------------------------------------------------
36 35 # Functions and classes
37 36 #-----------------------------------------------------------------------------
38 37
39 38 def restore_aliases(ip):
40 39 staliases = ip.db.get('stored_aliases', {})
41 40 for k,v in staliases.items():
42 41 #print "restore alias",k,v # dbg
43 42 #self.alias_table[k] = v
44 43 ip.alias_manager.define_alias(k,v)
45 44
46 45
47 46 def refresh_variables(ip):
48 47 db = ip.db
49 48 for key in db.keys('autorestore/*'):
50 49 # strip autorestore
51 50 justkey = os.path.basename(key)
52 51 try:
53 52 obj = db[key]
54 53 except KeyError:
55 54 print("Unable to restore variable '%s', ignoring (use %%store -d to forget!)" % justkey)
56 55 print("The error was:", sys.exc_info()[0])
57 56 else:
58 57 #print "restored",justkey,"=",obj #dbg
59 58 ip.user_ns[justkey] = obj
60 59
61 60
62 61 def restore_dhist(ip):
63 62 ip.user_ns['_dh'] = ip.db.get('dhist',[])
64 63
65 64
66 65 def restore_data(ip):
67 66 refresh_variables(ip)
68 67 restore_aliases(ip)
69 68 restore_dhist(ip)
70 69
71 70
72 71 @magics_class
73 72 class StoreMagics(Magics):
74 73 """Lightweight persistence for python variables.
75 74
76 75 Provides the %store magic."""
77 76
78 77 autorestore = Bool(False, config=True, help=
79 78 """If True, any %store-d variables will be automatically restored
80 79 when IPython starts.
81 80 """
82 81 )
83 82
84 83 def __init__(self, shell):
85 84 super(StoreMagics, self).__init__(shell=shell)
86 85 self.shell.configurables.append(self)
87 86 if self.autorestore:
88 87 restore_data(self.shell)
89 88
90 @skip_doctest
91 89 @line_magic
92 90 def store(self, parameter_s=''):
93 91 """Lightweight persistence for python variables.
94 92
95 93 Example::
96 94
97 95 In [1]: l = ['hello',10,'world']
98 96 In [2]: %store l
99 97 In [3]: exit
100 98
101 99 (IPython session is closed and started again...)
102 100
103 101 ville@badger:~$ ipython
104 102 In [1]: l
105 103 NameError: name 'l' is not defined
106 104 In [2]: %store -r
107 105 In [3]: l
108 106 Out[3]: ['hello', 10, 'world']
109 107
110 108 Usage:
111 109
112 110 * ``%store`` - Show list of all variables and their current
113 111 values
114 112 * ``%store spam`` - Store the *current* value of the variable spam
115 113 to disk
116 114 * ``%store -d spam`` - Remove the variable and its value from storage
117 115 * ``%store -z`` - Remove all variables from storage
118 116 * ``%store -r`` - Refresh all variables from store (overwrite
119 117 current vals)
120 118 * ``%store -r spam bar`` - Refresh specified variables from store
121 119 (delete current val)
122 120 * ``%store foo >a.txt`` - Store value of foo to new file a.txt
123 121 * ``%store foo >>a.txt`` - Append value of foo to file a.txt
124 122
125 123 It should be noted that if you change the value of a variable, you
126 124 need to %store it again if you want to persist the new value.
127 125
128 126 Note also that the variables will need to be pickleable; most basic
129 127 python types can be safely %store'd.
130 128
131 129 Also aliases can be %store'd across sessions.
132 130 """
133 131
134 132 opts,argsl = self.parse_options(parameter_s,'drz',mode='string')
135 133 args = argsl.split(None,1)
136 134 ip = self.shell
137 135 db = ip.db
138 136 # delete
139 137 if 'd' in opts:
140 138 try:
141 139 todel = args[0]
142 140 except IndexError:
143 141 raise UsageError('You must provide the variable to forget')
144 142 else:
145 143 try:
146 144 del db['autorestore/' + todel]
147 145 except:
148 146 raise UsageError("Can't delete variable '%s'" % todel)
149 147 # reset
150 148 elif 'z' in opts:
151 149 for k in db.keys('autorestore/*'):
152 150 del db[k]
153 151
154 152 elif 'r' in opts:
155 153 if args:
156 154 for arg in args:
157 155 try:
158 156 obj = db['autorestore/' + arg]
159 157 except KeyError:
160 158 print("no stored variable %s" % arg)
161 159 else:
162 160 ip.user_ns[arg] = obj
163 161 else:
164 162 restore_data(ip)
165 163
166 164 # run without arguments -> list variables & values
167 165 elif not args:
168 166 vars = db.keys('autorestore/*')
169 167 vars.sort()
170 168 if vars:
171 169 size = max(map(len, vars))
172 170 else:
173 171 size = 0
174 172
175 173 print('Stored variables and their in-db values:')
176 174 fmt = '%-'+str(size)+'s -> %s'
177 175 get = db.get
178 176 for var in vars:
179 177 justkey = os.path.basename(var)
180 178 # print 30 first characters from every var
181 179 print(fmt % (justkey, repr(get(var, '<unavailable>'))[:50]))
182 180
183 181 # default action - store the variable
184 182 else:
185 183 # %store foo >file.txt or >>file.txt
186 184 if len(args) > 1 and args[1].startswith('>'):
187 185 fnam = os.path.expanduser(args[1].lstrip('>').lstrip())
188 186 if args[1].startswith('>>'):
189 187 fil = open(fnam, 'a')
190 188 else:
191 189 fil = open(fnam, 'w')
192 190 obj = ip.ev(args[0])
193 191 print("Writing '%s' (%s) to file '%s'." % (args[0],
194 192 obj.__class__.__name__, fnam))
195 193
196 194
197 195 if not isinstance (obj, string_types):
198 196 from pprint import pprint
199 197 pprint(obj, fil)
200 198 else:
201 199 fil.write(obj)
202 200 if not obj.endswith('\n'):
203 201 fil.write('\n')
204 202
205 203 fil.close()
206 204 return
207 205
208 206 # %store foo
209 207 try:
210 208 obj = ip.user_ns[args[0]]
211 209 except KeyError:
212 210 # it might be an alias
213 211 name = args[0]
214 212 try:
215 213 cmd = ip.alias_manager.retrieve_alias(name)
216 214 except ValueError:
217 215 raise UsageError("Unknown variable '%s'" % name)
218 216
219 217 staliases = db.get('stored_aliases',{})
220 218 staliases[name] = cmd
221 219 db['stored_aliases'] = staliases
222 220 print("Alias stored: %s (%s)" % (name, cmd))
223 221 return
224 222
225 223 else:
226 224 modname = getattr(inspect.getmodule(obj), '__name__', '')
227 225 if modname == '__main__':
228 226 print(textwrap.dedent("""\
229 227 Warning:%s is %s
230 228 Proper storage of interactively declared classes (or instances
231 229 of those classes) is not possible! Only instances
232 230 of classes in real modules on file system can be %%store'd.
233 231 """ % (args[0], obj) ))
234 232 return
235 233 #pickled = pickle.dumps(obj)
236 234 db[ 'autorestore/' + args[0] ] = obj
237 235 print("Stored '%s' (%s)" % (args[0], obj.__class__.__name__))
238 236
239 237
240 238 def load_ipython_extension(ip):
241 239 """Load the extension in IPython."""
242 240 ip.register_magics(StoreMagics)
243 241
@@ -1,111 +1,108 b''
1 1 """Link and DirectionalLink classes.
2 2
3 3 Propagate changes between widgets on the javascript side
4 4 """
5 5
6 6 # Copyright (c) IPython Development Team.
7 7 # Distributed under the terms of the Modified BSD License.
8 8
9 9 from .widget import Widget
10 from IPython.testing.skipdoctest import skip_doctest
11 10 from IPython.utils.traitlets import Unicode, Tuple, List,Instance, TraitError
12 11
13 12 class WidgetTraitTuple(Tuple):
14 13 """Traitlet for validating a single (Widget, 'trait_name') pair"""
15 14
16 15 def __init__(self, **kwargs):
17 16 super(WidgetTraitTuple, self).__init__(Instance(Widget), Unicode, **kwargs)
18 17
19 18 def validate_elements(self, obj, value):
20 19 value = super(WidgetTraitTuple, self).validate_elements(obj, value)
21 20 widget, trait_name = value
22 21 trait = widget.traits().get(trait_name)
23 22 trait_repr = "%s.%s" % (widget.__class__.__name__, trait_name)
24 23 # Can't raise TraitError because the parent will swallow the message
25 24 # and throw it away in a new, less informative TraitError
26 25 if trait is None:
27 26 raise TypeError("No such trait: %s" % trait_repr)
28 27 elif not trait.get_metadata('sync'):
29 28 raise TypeError("%s cannot be synced" % trait_repr)
30 29
31 30 return value
32 31
33 32
34 33 class Link(Widget):
35 34 """Link Widget
36 35
37 36 one trait:
38 37 widgets, a list of (widget, 'trait_name') tuples which should be linked in the frontend.
39 38 """
40 39 _model_name = Unicode('LinkModel', sync=True)
41 40 widgets = List(WidgetTraitTuple, sync=True)
42 41
43 42 def __init__(self, widgets, **kwargs):
44 43 if len(widgets) < 2:
45 44 raise TypeError("Require at least two widgets to link")
46 45 kwargs['widgets'] = widgets
47 46 super(Link, self).__init__(**kwargs)
48 47
49 48 # for compatibility with traitlet links
50 49 def unlink(self):
51 50 self.close()
52 51
53 52
54 @skip_doctest
55 53 def jslink(*args):
56 54 """Link traits from different widgets together on the frontend so they remain in sync.
57 55
58 56 Parameters
59 57 ----------
60 58 *args : two or more (Widget, 'trait_name') tuples that should be kept in sync.
61 59
62 60 Examples
63 61 --------
64 62
65 63 >>> c = link((widget1, 'value'), (widget2, 'value'), (widget3, 'value'))
66 64 """
67 65 return Link(widgets=args)
68 66
69 67
70 68 class DirectionalLink(Widget):
71 69 """A directional link
72 70
73 71 source: a (Widget, 'trait_name') tuple for the source trait
74 72 targets: one or more (Widget, 'trait_name') tuples that should be updated
75 73 when the source trait changes.
76 74 """
77 75 _model_name = Unicode('DirectionalLinkModel', sync=True)
78 76 targets = List(WidgetTraitTuple, sync=True)
79 77 source = WidgetTraitTuple(sync=True)
80 78
81 79 # Does not quite behave like other widgets but reproduces
82 80 # the behavior of IPython.utils.traitlets.directional_link
83 81 def __init__(self, source, targets, **kwargs):
84 82 if len(targets) < 1:
85 83 raise TypeError("Require at least two widgets to link")
86 84
87 85 kwargs['source'] = source
88 86 kwargs['targets'] = targets
89 87 super(DirectionalLink, self).__init__(**kwargs)
90 88
91 89 # for compatibility with traitlet links
92 90 def unlink(self):
93 91 self.close()
94 92
95 @skip_doctest
96 93 def jsdlink(source, *targets):
97 94 """Link the trait of a source widget with traits of target widgets in the frontend.
98 95
99 96 Parameters
100 97 ----------
101 98 source : a (Widget, 'trait_name') tuple for the source trait
102 99 *targets : one or more (Widget, 'trait_name') tuples that should be updated
103 100 when the source trait changes.
104 101
105 102 Examples
106 103 --------
107 104
108 105 >>> c = dlink((src_widget, 'value'), (tgt_widget1, 'value'), (tgt_widget2, 'value'))
109 106 """
110 107 return DirectionalLink(source=source, targets=targets)
111 108
@@ -1,78 +1,76 b''
1 1 """Output class.
2 2
3 3 Represents a widget that can be used to display output within the widget area.
4 4 """
5 5
6 6 # Copyright (c) IPython Development Team.
7 7 # Distributed under the terms of the Modified BSD License.
8 8
9 9 from .widget import DOMWidget
10 10 import sys
11 11 from IPython.utils.traitlets import Unicode, List
12 12 from IPython.display import clear_output
13 from IPython.testing.skipdoctest import skip_doctest
14 13 from IPython.kernel.zmq.session import Message
15 14
16 @skip_doctest
17 15 class Output(DOMWidget):
18 16 """Widget used as a context manager to display output.
19 17
20 18 This widget can capture and display stdout, stderr, and rich output. To use
21 19 it, create an instance of it and display it. Then use it as a context
22 20 manager. Any output produced while in it's context will be captured and
23 21 displayed in it instead of the standard output area.
24 22
25 23 Example
26 24 from IPython.html import widgets
27 25 from IPython.display import display
28 26 out = widgets.Output()
29 27 display(out)
30 28
31 29 print('prints to output area')
32 30
33 31 with out:
34 32 print('prints to output widget')"""
35 33 _view_name = Unicode('OutputView', sync=True)
36 34
37 35 def clear_output(self, *pargs, **kwargs):
38 36 with self:
39 37 clear_output(*pargs, **kwargs)
40 38
41 39 def __enter__(self):
42 40 """Called upon entering output widget context manager."""
43 41 self._flush()
44 42 kernel = get_ipython().kernel
45 43 session = kernel.session
46 44 send = session.send
47 45 self._original_send = send
48 46 self._session = session
49 47
50 48 def send_hook(stream, msg_or_type, content=None, parent=None, ident=None,
51 49 buffers=None, track=False, header=None, metadata=None):
52 50
53 51 # Handle both prebuild messages and unbuilt messages.
54 52 if isinstance(msg_or_type, (Message, dict)):
55 53 msg_type = msg_or_type['msg_type']
56 54 msg = dict(msg_or_type)
57 55 else:
58 56 msg_type = msg_or_type
59 57 msg = session.msg(msg_type, content=content, parent=parent,
60 58 header=header, metadata=metadata)
61 59
62 60 # If this is a message type that we want to forward, forward it.
63 61 if stream is kernel.iopub_socket and msg_type in ['clear_output', 'stream', 'display_data']:
64 62 self.send(msg)
65 63 else:
66 64 send(stream, msg, ident=ident, buffers=buffers, track=track)
67 65
68 66 session.send = send_hook
69 67
70 68 def __exit__(self, exception_type, exception_value, traceback):
71 69 """Called upon exiting output widget context manager."""
72 70 self._flush()
73 71 self._session.send = self._original_send
74 72
75 73 def _flush(self):
76 74 """Flush stdout and stderr buffers."""
77 75 sys.stdout.flush()
78 76 sys.stderr.flush()
@@ -1,512 +1,510 b''
1 1 # -*- coding: utf-8 -*-
2 2 """
3 3 Defines a variety of Pygments lexers for highlighting IPython code.
4 4
5 5 This includes:
6 6
7 7 IPythonLexer, IPython3Lexer
8 8 Lexers for pure IPython (python + magic/shell commands)
9 9
10 10 IPythonPartialTracebackLexer, IPythonTracebackLexer
11 11 Supports 2.x and 3.x via keyword `python3`. The partial traceback
12 12 lexer reads everything but the Python code appearing in a traceback.
13 13 The full lexer combines the partial lexer with an IPython lexer.
14 14
15 15 IPythonConsoleLexer
16 16 A lexer for IPython console sessions, with support for tracebacks.
17 17
18 18 IPyLexer
19 19 A friendly lexer which examines the first line of text and from it,
20 20 decides whether to use an IPython lexer or an IPython console lexer.
21 21 This is probably the only lexer that needs to be explicitly added
22 22 to Pygments.
23 23
24 24 """
25 25 #-----------------------------------------------------------------------------
26 26 # Copyright (c) 2013, the IPython Development Team.
27 27 #
28 28 # Distributed under the terms of the Modified BSD License.
29 29 #
30 30 # The full license is in the file COPYING.txt, distributed with this software.
31 31 #-----------------------------------------------------------------------------
32 32
33 33 # Standard library
34 34 import re
35 35
36 36 # Third party
37 37 from pygments.lexers import BashLexer, PythonLexer, Python3Lexer
38 38 from pygments.lexer import (
39 39 Lexer, DelegatingLexer, RegexLexer, do_insertions, bygroups, using,
40 40 )
41 41 from pygments.token import (
42 42 Comment, Generic, Keyword, Literal, Name, Operator, Other, Text, Error,
43 43 )
44 44 from pygments.util import get_bool_opt
45 45
46 46 # Local
47 from IPython.testing.skipdoctest import skip_doctest
48 47
49 48 line_re = re.compile('.*?\n')
50 49
51 50 __all__ = ['build_ipy_lexer', 'IPython3Lexer', 'IPythonLexer',
52 51 'IPythonPartialTracebackLexer', 'IPythonTracebackLexer',
53 52 'IPythonConsoleLexer', 'IPyLexer']
54 53
55 54 ipython_tokens = [
56 55 (r"(?s)(\s*)(%%)(\w+)(.*)", bygroups(Text, Operator, Keyword, Text)),
57 56 (r'(?s)(^\s*)(%%!)([^\n]*\n)(.*)', bygroups(Text, Operator, Text, using(BashLexer))),
58 57 (r"(%%?)(\w+)(\?\??)$", bygroups(Operator, Keyword, Operator)),
59 58 (r"\b(\?\??)(\s*)$", bygroups(Operator, Text)),
60 59 (r'(%)(sx|sc|system)(.*)(\n)', bygroups(Operator, Keyword,
61 60 using(BashLexer), Text)),
62 61 (r'(%)(\w+)(.*\n)', bygroups(Operator, Keyword, Text)),
63 62 (r'^(!!)(.+)(\n)', bygroups(Operator, using(BashLexer), Text)),
64 63 (r'(!)(?!=)(.+)(\n)', bygroups(Operator, using(BashLexer), Text)),
65 64 (r'^(\s*)(\?\??)(\s*%{0,2}[\w\.\*]*)', bygroups(Text, Operator, Text)),
66 65 ]
67 66
68 67 def build_ipy_lexer(python3):
69 68 """Builds IPython lexers depending on the value of `python3`.
70 69
71 70 The lexer inherits from an appropriate Python lexer and then adds
72 71 information about IPython specific keywords (i.e. magic commands,
73 72 shell commands, etc.)
74 73
75 74 Parameters
76 75 ----------
77 76 python3 : bool
78 77 If `True`, then build an IPython lexer from a Python 3 lexer.
79 78
80 79 """
81 80 # It would be nice to have a single IPython lexer class which takes
82 81 # a boolean `python3`. But since there are two Python lexer classes,
83 82 # we will also have two IPython lexer classes.
84 83 if python3:
85 84 PyLexer = Python3Lexer
86 85 clsname = 'IPython3Lexer'
87 86 name = 'IPython3'
88 87 aliases = ['ipython3']
89 88 doc = """IPython3 Lexer"""
90 89 else:
91 90 PyLexer = PythonLexer
92 91 clsname = 'IPythonLexer'
93 92 name = 'IPython'
94 93 aliases = ['ipython2', 'ipython']
95 94 doc = """IPython Lexer"""
96 95
97 96 tokens = PyLexer.tokens.copy()
98 97 tokens['root'] = ipython_tokens + tokens['root']
99 98
100 99 attrs = {'name': name, 'aliases': aliases, 'filenames': [],
101 100 '__doc__': doc, 'tokens': tokens}
102 101
103 102 return type(name, (PyLexer,), attrs)
104 103
105 104
106 105 IPython3Lexer = build_ipy_lexer(python3=True)
107 106 IPythonLexer = build_ipy_lexer(python3=False)
108 107
109 108
110 109 class IPythonPartialTracebackLexer(RegexLexer):
111 110 """
112 111 Partial lexer for IPython tracebacks.
113 112
114 113 Handles all the non-python output. This works for both Python 2.x and 3.x.
115 114
116 115 """
117 116 name = 'IPython Partial Traceback'
118 117
119 118 tokens = {
120 119 'root': [
121 120 # Tracebacks for syntax errors have a different style.
122 121 # For both types of tracebacks, we mark the first line with
123 122 # Generic.Traceback. For syntax errors, we mark the filename
124 123 # as we mark the filenames for non-syntax tracebacks.
125 124 #
126 125 # These two regexps define how IPythonConsoleLexer finds a
127 126 # traceback.
128 127 #
129 128 ## Non-syntax traceback
130 129 (r'^(\^C)?(-+\n)', bygroups(Error, Generic.Traceback)),
131 130 ## Syntax traceback
132 131 (r'^( File)(.*)(, line )(\d+\n)',
133 132 bygroups(Generic.Traceback, Name.Namespace,
134 133 Generic.Traceback, Literal.Number.Integer)),
135 134
136 135 # (Exception Identifier)(Whitespace)(Traceback Message)
137 136 (r'(?u)(^[^\d\W]\w*)(\s*)(Traceback.*?\n)',
138 137 bygroups(Name.Exception, Generic.Whitespace, Text)),
139 138 # (Module/Filename)(Text)(Callee)(Function Signature)
140 139 # Better options for callee and function signature?
141 140 (r'(.*)( in )(.*)(\(.*\)\n)',
142 141 bygroups(Name.Namespace, Text, Name.Entity, Name.Tag)),
143 142 # Regular line: (Whitespace)(Line Number)(Python Code)
144 143 (r'(\s*?)(\d+)(.*?\n)',
145 144 bygroups(Generic.Whitespace, Literal.Number.Integer, Other)),
146 145 # Emphasized line: (Arrow)(Line Number)(Python Code)
147 146 # Using Exception token so arrow color matches the Exception.
148 147 (r'(-*>?\s?)(\d+)(.*?\n)',
149 148 bygroups(Name.Exception, Literal.Number.Integer, Other)),
150 149 # (Exception Identifier)(Message)
151 150 (r'(?u)(^[^\d\W]\w*)(:.*?\n)',
152 151 bygroups(Name.Exception, Text)),
153 152 # Tag everything else as Other, will be handled later.
154 153 (r'.*\n', Other),
155 154 ],
156 155 }
157 156
158 157
159 158 class IPythonTracebackLexer(DelegatingLexer):
160 159 """
161 160 IPython traceback lexer.
162 161
163 162 For doctests, the tracebacks can be snipped as much as desired with the
164 163 exception to the lines that designate a traceback. For non-syntax error
165 164 tracebacks, this is the line of hyphens. For syntax error tracebacks,
166 165 this is the line which lists the File and line number.
167 166
168 167 """
169 168 # The lexer inherits from DelegatingLexer. The "root" lexer is an
170 169 # appropriate IPython lexer, which depends on the value of the boolean
171 170 # `python3`. First, we parse with the partial IPython traceback lexer.
172 171 # Then, any code marked with the "Other" token is delegated to the root
173 172 # lexer.
174 173 #
175 174 name = 'IPython Traceback'
176 175 aliases = ['ipythontb']
177 176
178 177 def __init__(self, **options):
179 178 self.python3 = get_bool_opt(options, 'python3', False)
180 179 if self.python3:
181 180 self.aliases = ['ipython3tb']
182 181 else:
183 182 self.aliases = ['ipython2tb', 'ipythontb']
184 183
185 184 if self.python3:
186 185 IPyLexer = IPython3Lexer
187 186 else:
188 187 IPyLexer = IPythonLexer
189 188
190 189 DelegatingLexer.__init__(self, IPyLexer,
191 190 IPythonPartialTracebackLexer, **options)
192 191
193 @skip_doctest
194 192 class IPythonConsoleLexer(Lexer):
195 193 """
196 194 An IPython console lexer for IPython code-blocks and doctests, such as:
197 195
198 196 .. code-block:: rst
199 197
200 198 .. code-block:: ipythonconsole
201 199
202 200 In [1]: a = 'foo'
203 201
204 202 In [2]: a
205 203 Out[2]: 'foo'
206 204
207 205 In [3]: print a
208 206 foo
209 207
210 208 In [4]: 1 / 0
211 209
212 210
213 211 Support is also provided for IPython exceptions:
214 212
215 213 .. code-block:: rst
216 214
217 215 .. code-block:: ipythonconsole
218 216
219 217 In [1]: raise Exception
220 218
221 219 ---------------------------------------------------------------------------
222 220 Exception Traceback (most recent call last)
223 221 <ipython-input-1-fca2ab0ca76b> in <module>()
224 222 ----> 1 raise Exception
225 223
226 224 Exception:
227 225
228 226 """
229 227 name = 'IPython console session'
230 228 aliases = ['ipythonconsole']
231 229 mimetypes = ['text/x-ipython-console']
232 230
233 231 # The regexps used to determine what is input and what is output.
234 232 # The default prompts for IPython are:
235 233 #
236 234 # c.PromptManager.in_template = 'In [\#]: '
237 235 # c.PromptManager.in2_template = ' .\D.: '
238 236 # c.PromptManager.out_template = 'Out[\#]: '
239 237 #
240 238 in1_regex = r'In \[[0-9]+\]: '
241 239 in2_regex = r' \.\.+\.: '
242 240 out_regex = r'Out\[[0-9]+\]: '
243 241
244 242 #: The regex to determine when a traceback starts.
245 243 ipytb_start = re.compile(r'^(\^C)?(-+\n)|^( File)(.*)(, line )(\d+\n)')
246 244
247 245 def __init__(self, **options):
248 246 """Initialize the IPython console lexer.
249 247
250 248 Parameters
251 249 ----------
252 250 python3 : bool
253 251 If `True`, then the console inputs are parsed using a Python 3
254 252 lexer. Otherwise, they are parsed using a Python 2 lexer.
255 253 in1_regex : RegexObject
256 254 The compiled regular expression used to detect the start
257 255 of inputs. Although the IPython configuration setting may have a
258 256 trailing whitespace, do not include it in the regex. If `None`,
259 257 then the default input prompt is assumed.
260 258 in2_regex : RegexObject
261 259 The compiled regular expression used to detect the continuation
262 260 of inputs. Although the IPython configuration setting may have a
263 261 trailing whitespace, do not include it in the regex. If `None`,
264 262 then the default input prompt is assumed.
265 263 out_regex : RegexObject
266 264 The compiled regular expression used to detect outputs. If `None`,
267 265 then the default output prompt is assumed.
268 266
269 267 """
270 268 self.python3 = get_bool_opt(options, 'python3', False)
271 269 if self.python3:
272 270 self.aliases = ['ipython3console']
273 271 else:
274 272 self.aliases = ['ipython2console', 'ipythonconsole']
275 273
276 274 in1_regex = options.get('in1_regex', self.in1_regex)
277 275 in2_regex = options.get('in2_regex', self.in2_regex)
278 276 out_regex = options.get('out_regex', self.out_regex)
279 277
280 278 # So that we can work with input and output prompts which have been
281 279 # rstrip'd (possibly by editors) we also need rstrip'd variants. If
282 280 # we do not do this, then such prompts will be tagged as 'output'.
283 281 # The reason can't just use the rstrip'd variants instead is because
284 282 # we want any whitespace associated with the prompt to be inserted
285 283 # with the token. This allows formatted code to be modified so as hide
286 284 # the appearance of prompts, with the whitespace included. One example
287 285 # use of this is in copybutton.js from the standard lib Python docs.
288 286 in1_regex_rstrip = in1_regex.rstrip() + '\n'
289 287 in2_regex_rstrip = in2_regex.rstrip() + '\n'
290 288 out_regex_rstrip = out_regex.rstrip() + '\n'
291 289
292 290 # Compile and save them all.
293 291 attrs = ['in1_regex', 'in2_regex', 'out_regex',
294 292 'in1_regex_rstrip', 'in2_regex_rstrip', 'out_regex_rstrip']
295 293 for attr in attrs:
296 294 self.__setattr__(attr, re.compile(locals()[attr]))
297 295
298 296 Lexer.__init__(self, **options)
299 297
300 298 if self.python3:
301 299 pylexer = IPython3Lexer
302 300 tblexer = IPythonTracebackLexer
303 301 else:
304 302 pylexer = IPythonLexer
305 303 tblexer = IPythonTracebackLexer
306 304
307 305 self.pylexer = pylexer(**options)
308 306 self.tblexer = tblexer(**options)
309 307
310 308 self.reset()
311 309
312 310 def reset(self):
313 311 self.mode = 'output'
314 312 self.index = 0
315 313 self.buffer = u''
316 314 self.insertions = []
317 315
318 316 def buffered_tokens(self):
319 317 """
320 318 Generator of unprocessed tokens after doing insertions and before
321 319 changing to a new state.
322 320
323 321 """
324 322 if self.mode == 'output':
325 323 tokens = [(0, Generic.Output, self.buffer)]
326 324 elif self.mode == 'input':
327 325 tokens = self.pylexer.get_tokens_unprocessed(self.buffer)
328 326 else: # traceback
329 327 tokens = self.tblexer.get_tokens_unprocessed(self.buffer)
330 328
331 329 for i, t, v in do_insertions(self.insertions, tokens):
332 330 # All token indexes are relative to the buffer.
333 331 yield self.index + i, t, v
334 332
335 333 # Clear it all
336 334 self.index += len(self.buffer)
337 335 self.buffer = u''
338 336 self.insertions = []
339 337
340 338 def get_mci(self, line):
341 339 """
342 340 Parses the line and returns a 3-tuple: (mode, code, insertion).
343 341
344 342 `mode` is the next mode (or state) of the lexer, and is always equal
345 343 to 'input', 'output', or 'tb'.
346 344
347 345 `code` is a portion of the line that should be added to the buffer
348 346 corresponding to the next mode and eventually lexed by another lexer.
349 347 For example, `code` could be Python code if `mode` were 'input'.
350 348
351 349 `insertion` is a 3-tuple (index, token, text) representing an
352 350 unprocessed "token" that will be inserted into the stream of tokens
353 351 that are created from the buffer once we change modes. This is usually
354 352 the input or output prompt.
355 353
356 354 In general, the next mode depends on current mode and on the contents
357 355 of `line`.
358 356
359 357 """
360 358 # To reduce the number of regex match checks, we have multiple
361 359 # 'if' blocks instead of 'if-elif' blocks.
362 360
363 361 # Check for possible end of input
364 362 in2_match = self.in2_regex.match(line)
365 363 in2_match_rstrip = self.in2_regex_rstrip.match(line)
366 364 if (in2_match and in2_match.group().rstrip() == line.rstrip()) or \
367 365 in2_match_rstrip:
368 366 end_input = True
369 367 else:
370 368 end_input = False
371 369 if end_input and self.mode != 'tb':
372 370 # Only look for an end of input when not in tb mode.
373 371 # An ellipsis could appear within the traceback.
374 372 mode = 'output'
375 373 code = u''
376 374 insertion = (0, Generic.Prompt, line)
377 375 return mode, code, insertion
378 376
379 377 # Check for output prompt
380 378 out_match = self.out_regex.match(line)
381 379 out_match_rstrip = self.out_regex_rstrip.match(line)
382 380 if out_match or out_match_rstrip:
383 381 mode = 'output'
384 382 if out_match:
385 383 idx = out_match.end()
386 384 else:
387 385 idx = out_match_rstrip.end()
388 386 code = line[idx:]
389 387 # Use the 'heading' token for output. We cannot use Generic.Error
390 388 # since it would conflict with exceptions.
391 389 insertion = (0, Generic.Heading, line[:idx])
392 390 return mode, code, insertion
393 391
394 392
395 393 # Check for input or continuation prompt (non stripped version)
396 394 in1_match = self.in1_regex.match(line)
397 395 if in1_match or (in2_match and self.mode != 'tb'):
398 396 # New input or when not in tb, continued input.
399 397 # We do not check for continued input when in tb since it is
400 398 # allowable to replace a long stack with an ellipsis.
401 399 mode = 'input'
402 400 if in1_match:
403 401 idx = in1_match.end()
404 402 else: # in2_match
405 403 idx = in2_match.end()
406 404 code = line[idx:]
407 405 insertion = (0, Generic.Prompt, line[:idx])
408 406 return mode, code, insertion
409 407
410 408 # Check for input or continuation prompt (stripped version)
411 409 in1_match_rstrip = self.in1_regex_rstrip.match(line)
412 410 if in1_match_rstrip or (in2_match_rstrip and self.mode != 'tb'):
413 411 # New input or when not in tb, continued input.
414 412 # We do not check for continued input when in tb since it is
415 413 # allowable to replace a long stack with an ellipsis.
416 414 mode = 'input'
417 415 if in1_match_rstrip:
418 416 idx = in1_match_rstrip.end()
419 417 else: # in2_match
420 418 idx = in2_match_rstrip.end()
421 419 code = line[idx:]
422 420 insertion = (0, Generic.Prompt, line[:idx])
423 421 return mode, code, insertion
424 422
425 423 # Check for traceback
426 424 if self.ipytb_start.match(line):
427 425 mode = 'tb'
428 426 code = line
429 427 insertion = None
430 428 return mode, code, insertion
431 429
432 430 # All other stuff...
433 431 if self.mode in ('input', 'output'):
434 432 # We assume all other text is output. Multiline input that
435 433 # does not use the continuation marker cannot be detected.
436 434 # For example, the 3 in the following is clearly output:
437 435 #
438 436 # In [1]: print 3
439 437 # 3
440 438 #
441 439 # But the following second line is part of the input:
442 440 #
443 441 # In [2]: while True:
444 442 # print True
445 443 #
446 444 # In both cases, the 2nd line will be 'output'.
447 445 #
448 446 mode = 'output'
449 447 else:
450 448 mode = 'tb'
451 449
452 450 code = line
453 451 insertion = None
454 452
455 453 return mode, code, insertion
456 454
457 455 def get_tokens_unprocessed(self, text):
458 456 self.reset()
459 457 for match in line_re.finditer(text):
460 458 line = match.group()
461 459 mode, code, insertion = self.get_mci(line)
462 460
463 461 if mode != self.mode:
464 462 # Yield buffered tokens before transitioning to new mode.
465 463 for token in self.buffered_tokens():
466 464 yield token
467 465 self.mode = mode
468 466
469 467 if insertion:
470 468 self.insertions.append((len(self.buffer), [insertion]))
471 469 self.buffer += code
472 470 else:
473 471 for token in self.buffered_tokens():
474 472 yield token
475 473
476 474 class IPyLexer(Lexer):
477 475 """
478 476 Primary lexer for all IPython-like code.
479 477
480 478 This is a simple helper lexer. If the first line of the text begins with
481 479 "In \[[0-9]+\]:", then the entire text is parsed with an IPython console
482 480 lexer. If not, then the entire text is parsed with an IPython lexer.
483 481
484 482 The goal is to reduce the number of lexers that are registered
485 483 with Pygments.
486 484
487 485 """
488 486 name = 'IPy session'
489 487 aliases = ['ipy']
490 488
491 489 def __init__(self, **options):
492 490 self.python3 = get_bool_opt(options, 'python3', False)
493 491 if self.python3:
494 492 self.aliases = ['ipy3']
495 493 else:
496 494 self.aliases = ['ipy2', 'ipy']
497 495
498 496 Lexer.__init__(self, **options)
499 497
500 498 self.IPythonLexer = IPythonLexer(**options)
501 499 self.IPythonConsoleLexer = IPythonConsoleLexer(**options)
502 500
503 501 def get_tokens_unprocessed(self, text):
504 502 # Search for the input prompt anywhere...this allows code blocks to
505 503 # begin with comments as well.
506 504 if re.match(r'.*(In \[[0-9]+\]:)', text.strip(), re.DOTALL):
507 505 lex = self.IPythonConsoleLexer
508 506 else:
509 507 lex = self.IPythonLexer
510 508 for token in lex.get_tokens_unprocessed(text):
511 509 yield token
512 510
@@ -1,116 +1,114 b''
1 1 """
2 2 Password generation for the IPython notebook.
3 3 """
4 4 #-----------------------------------------------------------------------------
5 5 # Imports
6 6 #-----------------------------------------------------------------------------
7 7 # Stdlib
8 8 import getpass
9 9 import hashlib
10 10 import random
11 11
12 12 # Our own
13 13 from IPython.core.error import UsageError
14 from IPython.testing.skipdoctest import skip_doctest
15 14 from IPython.utils.py3compat import cast_bytes, str_to_bytes
16 15
17 16 #-----------------------------------------------------------------------------
18 17 # Globals
19 18 #-----------------------------------------------------------------------------
20 19
21 20 # Length of the salt in nr of hex chars, which implies salt_len * 4
22 21 # bits of randomness.
23 22 salt_len = 12
24 23
25 24 #-----------------------------------------------------------------------------
26 25 # Functions
27 26 #-----------------------------------------------------------------------------
28 27
29 @skip_doctest
30 28 def passwd(passphrase=None, algorithm='sha1'):
31 29 """Generate hashed password and salt for use in notebook configuration.
32 30
33 31 In the notebook configuration, set `c.NotebookApp.password` to
34 32 the generated string.
35 33
36 34 Parameters
37 35 ----------
38 36 passphrase : str
39 37 Password to hash. If unspecified, the user is asked to input
40 38 and verify a password.
41 39 algorithm : str
42 40 Hashing algorithm to use (e.g, 'sha1' or any argument supported
43 41 by :func:`hashlib.new`).
44 42
45 43 Returns
46 44 -------
47 45 hashed_passphrase : str
48 46 Hashed password, in the format 'hash_algorithm:salt:passphrase_hash'.
49 47
50 48 Examples
51 49 --------
52 50 >>> passwd('mypassword')
53 51 'sha1:7cf3:b7d6da294ea9592a9480c8f52e63cd42cfb9dd12'
54 52
55 53 """
56 54 if passphrase is None:
57 55 for i in range(3):
58 56 p0 = getpass.getpass('Enter password: ')
59 57 p1 = getpass.getpass('Verify password: ')
60 58 if p0 == p1:
61 59 passphrase = p0
62 60 break
63 61 else:
64 62 print('Passwords do not match.')
65 63 else:
66 64 raise UsageError('No matching passwords found. Giving up.')
67 65
68 66 h = hashlib.new(algorithm)
69 67 salt = ('%0' + str(salt_len) + 'x') % random.getrandbits(4 * salt_len)
70 68 h.update(cast_bytes(passphrase, 'utf-8') + str_to_bytes(salt, 'ascii'))
71 69
72 70 return ':'.join((algorithm, salt, h.hexdigest()))
73 71
74 72
75 73 def passwd_check(hashed_passphrase, passphrase):
76 74 """Verify that a given passphrase matches its hashed version.
77 75
78 76 Parameters
79 77 ----------
80 78 hashed_passphrase : str
81 79 Hashed password, in the format returned by `passwd`.
82 80 passphrase : str
83 81 Passphrase to validate.
84 82
85 83 Returns
86 84 -------
87 85 valid : bool
88 86 True if the passphrase matches the hash.
89 87
90 88 Examples
91 89 --------
92 90 >>> from IPython.lib.security import passwd_check
93 91 >>> passwd_check('sha1:0e112c3ddfce:a68df677475c2b47b6e86d0467eec97ac5f4b85a',
94 92 ... 'mypassword')
95 93 True
96 94
97 95 >>> passwd_check('sha1:0e112c3ddfce:a68df677475c2b47b6e86d0467eec97ac5f4b85a',
98 96 ... 'anotherpassword')
99 97 False
100 98 """
101 99 try:
102 100 algorithm, salt, pw_digest = hashed_passphrase.split(':', 2)
103 101 except (ValueError, TypeError):
104 102 return False
105 103
106 104 try:
107 105 h = hashlib.new(algorithm)
108 106 except ValueError:
109 107 return False
110 108
111 109 if len(pw_digest) == 0:
112 110 return False
113 111
114 112 h.update(cast_bytes(passphrase, 'utf-8') + cast_bytes(salt, 'ascii'))
115 113
116 114 return h.hexdigest() == pw_digest
@@ -1,642 +1,640 b''
1 1 # -*- coding: utf-8 -*-
2 2 """Subclass of InteractiveShell for terminal based frontends."""
3 3
4 4 #-----------------------------------------------------------------------------
5 5 # Copyright (C) 2001 Janko Hauser <jhauser@zscout.de>
6 6 # Copyright (C) 2001-2007 Fernando Perez. <fperez@colorado.edu>
7 7 # Copyright (C) 2008-2011 The IPython Development Team
8 8 #
9 9 # Distributed under the terms of the BSD License. The full license is in
10 10 # the file COPYING, distributed as part of this software.
11 11 #-----------------------------------------------------------------------------
12 12
13 13 #-----------------------------------------------------------------------------
14 14 # Imports
15 15 #-----------------------------------------------------------------------------
16 16 from __future__ import print_function
17 17
18 18 import bdb
19 19 import os
20 20 import sys
21 21
22 22 from IPython.core.error import TryNext, UsageError
23 23 from IPython.core.usage import interactive_usage
24 24 from IPython.core.inputsplitter import IPythonInputSplitter
25 25 from IPython.core.interactiveshell import InteractiveShell, InteractiveShellABC
26 26 from IPython.core.magic import Magics, magics_class, line_magic
27 27 from IPython.lib.clipboard import ClipboardEmpty
28 from IPython.testing.skipdoctest import skip_doctest
29 28 from IPython.utils.encoding import get_stream_enc
30 29 from IPython.utils import py3compat
31 30 from IPython.utils.terminal import toggle_set_term_title, set_term_title
32 31 from IPython.utils.process import abbrev_cwd
33 32 from IPython.utils.warn import warn, error
34 33 from IPython.utils.text import num_ini_spaces, SList, strip_email_quotes
35 34 from IPython.utils.traitlets import Integer, CBool, Unicode
36 35
37 36 #-----------------------------------------------------------------------------
38 37 # Utilities
39 38 #-----------------------------------------------------------------------------
40 39
41 40 def get_default_editor():
42 41 try:
43 42 ed = os.environ['EDITOR']
44 43 if not py3compat.PY3:
45 44 ed = ed.decode()
46 45 return ed
47 46 except KeyError:
48 47 pass
49 48 except UnicodeError:
50 49 warn("$EDITOR environment variable is not pure ASCII. Using platform "
51 50 "default editor.")
52 51
53 52 if os.name == 'posix':
54 53 return 'vi' # the only one guaranteed to be there!
55 54 else:
56 55 return 'notepad' # same in Windows!
57 56
58 57 def get_pasted_lines(sentinel, l_input=py3compat.input, quiet=False):
59 58 """ Yield pasted lines until the user enters the given sentinel value.
60 59 """
61 60 if not quiet:
62 61 print("Pasting code; enter '%s' alone on the line to stop or use Ctrl-D." \
63 62 % sentinel)
64 63 prompt = ":"
65 64 else:
66 65 prompt = ""
67 66 while True:
68 67 try:
69 68 l = py3compat.str_to_unicode(l_input(prompt))
70 69 if l == sentinel:
71 70 return
72 71 else:
73 72 yield l
74 73 except EOFError:
75 74 print('<EOF>')
76 75 return
77 76
78 77
79 78 #------------------------------------------------------------------------
80 79 # Terminal-specific magics
81 80 #------------------------------------------------------------------------
82 81
83 82 @magics_class
84 83 class TerminalMagics(Magics):
85 84 def __init__(self, shell):
86 85 super(TerminalMagics, self).__init__(shell)
87 86 self.input_splitter = IPythonInputSplitter()
88 87
89 88 def store_or_execute(self, block, name):
90 89 """ Execute a block, or store it in a variable, per the user's request.
91 90 """
92 91 if name:
93 92 # If storing it for further editing
94 93 self.shell.user_ns[name] = SList(block.splitlines())
95 94 print("Block assigned to '%s'" % name)
96 95 else:
97 96 b = self.preclean_input(block)
98 97 self.shell.user_ns['pasted_block'] = b
99 98 self.shell.using_paste_magics = True
100 99 try:
101 100 self.shell.run_cell(b)
102 101 finally:
103 102 self.shell.using_paste_magics = False
104 103
105 104 def preclean_input(self, block):
106 105 lines = block.splitlines()
107 106 while lines and not lines[0].strip():
108 107 lines = lines[1:]
109 108 return strip_email_quotes('\n'.join(lines))
110 109
111 110 def rerun_pasted(self, name='pasted_block'):
112 111 """ Rerun a previously pasted command.
113 112 """
114 113 b = self.shell.user_ns.get(name)
115 114
116 115 # Sanity checks
117 116 if b is None:
118 117 raise UsageError('No previous pasted block available')
119 118 if not isinstance(b, py3compat.string_types):
120 119 raise UsageError(
121 120 "Variable 'pasted_block' is not a string, can't execute")
122 121
123 122 print("Re-executing '%s...' (%d chars)"% (b.split('\n',1)[0], len(b)))
124 123 self.shell.run_cell(b)
125 124
126 125 @line_magic
127 126 def autoindent(self, parameter_s = ''):
128 127 """Toggle autoindent on/off (if available)."""
129 128
130 129 self.shell.set_autoindent()
131 130 print("Automatic indentation is:",['OFF','ON'][self.shell.autoindent])
132 131
133 @skip_doctest
134 132 @line_magic
135 133 def cpaste(self, parameter_s=''):
136 134 """Paste & execute a pre-formatted code block from clipboard.
137 135
138 136 You must terminate the block with '--' (two minus-signs) or Ctrl-D
139 137 alone on the line. You can also provide your own sentinel with '%paste
140 138 -s %%' ('%%' is the new sentinel for this operation).
141 139
142 140 The block is dedented prior to execution to enable execution of method
143 141 definitions. '>' and '+' characters at the beginning of a line are
144 142 ignored, to allow pasting directly from e-mails, diff files and
145 143 doctests (the '...' continuation prompt is also stripped). The
146 144 executed block is also assigned to variable named 'pasted_block' for
147 145 later editing with '%edit pasted_block'.
148 146
149 147 You can also pass a variable name as an argument, e.g. '%cpaste foo'.
150 148 This assigns the pasted block to variable 'foo' as string, without
151 149 dedenting or executing it (preceding >>> and + is still stripped)
152 150
153 151 '%cpaste -r' re-executes the block previously entered by cpaste.
154 152 '%cpaste -q' suppresses any additional output messages.
155 153
156 154 Do not be alarmed by garbled output on Windows (it's a readline bug).
157 155 Just press enter and type -- (and press enter again) and the block
158 156 will be what was just pasted.
159 157
160 158 IPython statements (magics, shell escapes) are not supported (yet).
161 159
162 160 See also
163 161 --------
164 162 paste: automatically pull code from clipboard.
165 163
166 164 Examples
167 165 --------
168 166 ::
169 167
170 168 In [8]: %cpaste
171 169 Pasting code; enter '--' alone on the line to stop.
172 170 :>>> a = ["world!", "Hello"]
173 171 :>>> print " ".join(sorted(a))
174 172 :--
175 173 Hello world!
176 174 """
177 175 opts, name = self.parse_options(parameter_s, 'rqs:', mode='string')
178 176 if 'r' in opts:
179 177 self.rerun_pasted()
180 178 return
181 179
182 180 quiet = ('q' in opts)
183 181
184 182 sentinel = opts.get('s', u'--')
185 183 block = '\n'.join(get_pasted_lines(sentinel, quiet=quiet))
186 184 self.store_or_execute(block, name)
187 185
188 186 @line_magic
189 187 def paste(self, parameter_s=''):
190 188 """Paste & execute a pre-formatted code block from clipboard.
191 189
192 190 The text is pulled directly from the clipboard without user
193 191 intervention and printed back on the screen before execution (unless
194 192 the -q flag is given to force quiet mode).
195 193
196 194 The block is dedented prior to execution to enable execution of method
197 195 definitions. '>' and '+' characters at the beginning of a line are
198 196 ignored, to allow pasting directly from e-mails, diff files and
199 197 doctests (the '...' continuation prompt is also stripped). The
200 198 executed block is also assigned to variable named 'pasted_block' for
201 199 later editing with '%edit pasted_block'.
202 200
203 201 You can also pass a variable name as an argument, e.g. '%paste foo'.
204 202 This assigns the pasted block to variable 'foo' as string, without
205 203 executing it (preceding >>> and + is still stripped).
206 204
207 205 Options:
208 206
209 207 -r: re-executes the block previously entered by cpaste.
210 208
211 209 -q: quiet mode: do not echo the pasted text back to the terminal.
212 210
213 211 IPython statements (magics, shell escapes) are not supported (yet).
214 212
215 213 See also
216 214 --------
217 215 cpaste: manually paste code into terminal until you mark its end.
218 216 """
219 217 opts, name = self.parse_options(parameter_s, 'rq', mode='string')
220 218 if 'r' in opts:
221 219 self.rerun_pasted()
222 220 return
223 221 try:
224 222 block = self.shell.hooks.clipboard_get()
225 223 except TryNext as clipboard_exc:
226 224 message = getattr(clipboard_exc, 'args')
227 225 if message:
228 226 error(message[0])
229 227 else:
230 228 error('Could not get text from the clipboard.')
231 229 return
232 230 except ClipboardEmpty:
233 231 raise UsageError("The clipboard appears to be empty")
234 232
235 233 # By default, echo back to terminal unless quiet mode is requested
236 234 if 'q' not in opts:
237 235 write = self.shell.write
238 236 write(self.shell.pycolorize(block))
239 237 if not block.endswith('\n'):
240 238 write('\n')
241 239 write("## -- End pasted text --\n")
242 240
243 241 self.store_or_execute(block, name)
244 242
245 243 # Class-level: add a '%cls' magic only on Windows
246 244 if sys.platform == 'win32':
247 245 @line_magic
248 246 def cls(self, s):
249 247 """Clear screen.
250 248 """
251 249 os.system("cls")
252 250
253 251 #-----------------------------------------------------------------------------
254 252 # Main class
255 253 #-----------------------------------------------------------------------------
256 254
257 255 class TerminalInteractiveShell(InteractiveShell):
258 256
259 257 autoedit_syntax = CBool(False, config=True,
260 258 help="auto editing of files with syntax errors.")
261 259 confirm_exit = CBool(True, config=True,
262 260 help="""
263 261 Set to confirm when you try to exit IPython with an EOF (Control-D
264 262 in Unix, Control-Z/Enter in Windows). By typing 'exit' or 'quit',
265 263 you can force a direct exit without any confirmation.""",
266 264 )
267 265 # This display_banner only controls whether or not self.show_banner()
268 266 # is called when mainloop/interact are called. The default is False
269 267 # because for the terminal based application, the banner behavior
270 268 # is controlled by the application.
271 269 display_banner = CBool(False) # This isn't configurable!
272 270 embedded = CBool(False)
273 271 embedded_active = CBool(False)
274 272 editor = Unicode(get_default_editor(), config=True,
275 273 help="Set the editor used by IPython (default to $EDITOR/vi/notepad)."
276 274 )
277 275 pager = Unicode('less', config=True,
278 276 help="The shell program to be used for paging.")
279 277
280 278 screen_length = Integer(0, config=True,
281 279 help=
282 280 """Number of lines of your screen, used to control printing of very
283 281 long strings. Strings longer than this number of lines will be sent
284 282 through a pager instead of directly printed. The default value for
285 283 this is 0, which means IPython will auto-detect your screen size every
286 284 time it needs to print certain potentially long strings (this doesn't
287 285 change the behavior of the 'print' keyword, it's only triggered
288 286 internally). If for some reason this isn't working well (it needs
289 287 curses support), specify it yourself. Otherwise don't change the
290 288 default.""",
291 289 )
292 290 term_title = CBool(False, config=True,
293 291 help="Enable auto setting the terminal title."
294 292 )
295 293 usage = Unicode(interactive_usage)
296 294
297 295 # This `using_paste_magics` is used to detect whether the code is being
298 296 # executed via paste magics functions
299 297 using_paste_magics = CBool(False)
300 298
301 299 # In the terminal, GUI control is done via PyOS_InputHook
302 300 @staticmethod
303 301 def enable_gui(gui=None, app=None):
304 302 """Switch amongst GUI input hooks by name.
305 303 """
306 304 # Deferred import
307 305 from IPython.lib.inputhook import enable_gui as real_enable_gui
308 306 try:
309 307 return real_enable_gui(gui, app)
310 308 except ValueError as e:
311 309 raise UsageError("%s" % e)
312 310
313 311 system = InteractiveShell.system_raw
314 312
315 313 #-------------------------------------------------------------------------
316 314 # Overrides of init stages
317 315 #-------------------------------------------------------------------------
318 316
319 317 def init_display_formatter(self):
320 318 super(TerminalInteractiveShell, self).init_display_formatter()
321 319 # terminal only supports plaintext
322 320 self.display_formatter.active_types = ['text/plain']
323 321
324 322 #-------------------------------------------------------------------------
325 323 # Things related to the terminal
326 324 #-------------------------------------------------------------------------
327 325
328 326 @property
329 327 def usable_screen_length(self):
330 328 if self.screen_length == 0:
331 329 return 0
332 330 else:
333 331 num_lines_bot = self.separate_in.count('\n')+1
334 332 return self.screen_length - num_lines_bot
335 333
336 334 def _term_title_changed(self, name, new_value):
337 335 self.init_term_title()
338 336
339 337 def init_term_title(self):
340 338 # Enable or disable the terminal title.
341 339 if self.term_title:
342 340 toggle_set_term_title(True)
343 341 set_term_title('IPython: ' + abbrev_cwd())
344 342 else:
345 343 toggle_set_term_title(False)
346 344
347 345 #-------------------------------------------------------------------------
348 346 # Things related to aliases
349 347 #-------------------------------------------------------------------------
350 348
351 349 def init_alias(self):
352 350 # The parent class defines aliases that can be safely used with any
353 351 # frontend.
354 352 super(TerminalInteractiveShell, self).init_alias()
355 353
356 354 # Now define aliases that only make sense on the terminal, because they
357 355 # need direct access to the console in a way that we can't emulate in
358 356 # GUI or web frontend
359 357 if os.name == 'posix':
360 358 aliases = [('clear', 'clear'), ('more', 'more'), ('less', 'less'),
361 359 ('man', 'man')]
362 360 else :
363 361 aliases = []
364 362
365 363 for name, cmd in aliases:
366 364 self.alias_manager.soft_define_alias(name, cmd)
367 365
368 366 #-------------------------------------------------------------------------
369 367 # Mainloop and code execution logic
370 368 #-------------------------------------------------------------------------
371 369
372 370 def mainloop(self, display_banner=None):
373 371 """Start the mainloop.
374 372
375 373 If an optional banner argument is given, it will override the
376 374 internally created default banner.
377 375 """
378 376
379 377 with self.builtin_trap, self.display_trap:
380 378
381 379 while 1:
382 380 try:
383 381 self.interact(display_banner=display_banner)
384 382 #self.interact_with_readline()
385 383 # XXX for testing of a readline-decoupled repl loop, call
386 384 # interact_with_readline above
387 385 break
388 386 except KeyboardInterrupt:
389 387 # this should not be necessary, but KeyboardInterrupt
390 388 # handling seems rather unpredictable...
391 389 self.write("\nKeyboardInterrupt in interact()\n")
392 390
393 391 def _replace_rlhist_multiline(self, source_raw, hlen_before_cell):
394 392 """Store multiple lines as a single entry in history"""
395 393
396 394 # do nothing without readline or disabled multiline
397 395 if not self.has_readline or not self.multiline_history:
398 396 return hlen_before_cell
399 397
400 398 # windows rl has no remove_history_item
401 399 if not hasattr(self.readline, "remove_history_item"):
402 400 return hlen_before_cell
403 401
404 402 # skip empty cells
405 403 if not source_raw.rstrip():
406 404 return hlen_before_cell
407 405
408 406 # nothing changed do nothing, e.g. when rl removes consecutive dups
409 407 hlen = self.readline.get_current_history_length()
410 408 if hlen == hlen_before_cell:
411 409 return hlen_before_cell
412 410
413 411 for i in range(hlen - hlen_before_cell):
414 412 self.readline.remove_history_item(hlen - i - 1)
415 413 stdin_encoding = get_stream_enc(sys.stdin, 'utf-8')
416 414 self.readline.add_history(py3compat.unicode_to_str(source_raw.rstrip(),
417 415 stdin_encoding))
418 416 return self.readline.get_current_history_length()
419 417
420 418 def interact(self, display_banner=None):
421 419 """Closely emulate the interactive Python console."""
422 420
423 421 # batch run -> do not interact
424 422 if self.exit_now:
425 423 return
426 424
427 425 if display_banner is None:
428 426 display_banner = self.display_banner
429 427
430 428 if isinstance(display_banner, py3compat.string_types):
431 429 self.show_banner(display_banner)
432 430 elif display_banner:
433 431 self.show_banner()
434 432
435 433 more = False
436 434
437 435 if self.has_readline:
438 436 self.readline_startup_hook(self.pre_readline)
439 437 hlen_b4_cell = self.readline.get_current_history_length()
440 438 else:
441 439 hlen_b4_cell = 0
442 440 # exit_now is set by a call to %Exit or %Quit, through the
443 441 # ask_exit callback.
444 442
445 443 while not self.exit_now:
446 444 self.hooks.pre_prompt_hook()
447 445 if more:
448 446 try:
449 447 prompt = self.prompt_manager.render('in2')
450 448 except:
451 449 self.showtraceback()
452 450 if self.autoindent:
453 451 self.rl_do_indent = True
454 452
455 453 else:
456 454 try:
457 455 prompt = self.separate_in + self.prompt_manager.render('in')
458 456 except:
459 457 self.showtraceback()
460 458 try:
461 459 line = self.raw_input(prompt)
462 460 if self.exit_now:
463 461 # quick exit on sys.std[in|out] close
464 462 break
465 463 if self.autoindent:
466 464 self.rl_do_indent = False
467 465
468 466 except KeyboardInterrupt:
469 467 #double-guard against keyboardinterrupts during kbdint handling
470 468 try:
471 469 self.write('\n' + self.get_exception_only())
472 470 source_raw = self.input_splitter.raw_reset()
473 471 hlen_b4_cell = \
474 472 self._replace_rlhist_multiline(source_raw, hlen_b4_cell)
475 473 more = False
476 474 except KeyboardInterrupt:
477 475 pass
478 476 except EOFError:
479 477 if self.autoindent:
480 478 self.rl_do_indent = False
481 479 if self.has_readline:
482 480 self.readline_startup_hook(None)
483 481 self.write('\n')
484 482 self.exit()
485 483 except bdb.BdbQuit:
486 484 warn('The Python debugger has exited with a BdbQuit exception.\n'
487 485 'Because of how pdb handles the stack, it is impossible\n'
488 486 'for IPython to properly format this particular exception.\n'
489 487 'IPython will resume normal operation.')
490 488 except:
491 489 # exceptions here are VERY RARE, but they can be triggered
492 490 # asynchronously by signal handlers, for example.
493 491 self.showtraceback()
494 492 else:
495 493 try:
496 494 self.input_splitter.push(line)
497 495 more = self.input_splitter.push_accepts_more()
498 496 except SyntaxError:
499 497 # Run the code directly - run_cell takes care of displaying
500 498 # the exception.
501 499 more = False
502 500 if (self.SyntaxTB.last_syntax_error and
503 501 self.autoedit_syntax):
504 502 self.edit_syntax_error()
505 503 if not more:
506 504 source_raw = self.input_splitter.raw_reset()
507 505 self.run_cell(source_raw, store_history=True)
508 506 hlen_b4_cell = \
509 507 self._replace_rlhist_multiline(source_raw, hlen_b4_cell)
510 508
511 509 # Turn off the exit flag, so the mainloop can be restarted if desired
512 510 self.exit_now = False
513 511
514 512 def raw_input(self, prompt=''):
515 513 """Write a prompt and read a line.
516 514
517 515 The returned line does not include the trailing newline.
518 516 When the user enters the EOF key sequence, EOFError is raised.
519 517
520 518 Parameters
521 519 ----------
522 520
523 521 prompt : str, optional
524 522 A string to be printed to prompt the user.
525 523 """
526 524 # raw_input expects str, but we pass it unicode sometimes
527 525 prompt = py3compat.cast_bytes_py2(prompt)
528 526
529 527 try:
530 528 line = py3compat.str_to_unicode(self.raw_input_original(prompt))
531 529 except ValueError:
532 530 warn("\n********\nYou or a %run:ed script called sys.stdin.close()"
533 531 " or sys.stdout.close()!\nExiting IPython!\n")
534 532 self.ask_exit()
535 533 return ""
536 534
537 535 # Try to be reasonably smart about not re-indenting pasted input more
538 536 # than necessary. We do this by trimming out the auto-indent initial
539 537 # spaces, if the user's actual input started itself with whitespace.
540 538 if self.autoindent:
541 539 if num_ini_spaces(line) > self.indent_current_nsp:
542 540 line = line[self.indent_current_nsp:]
543 541 self.indent_current_nsp = 0
544 542
545 543 return line
546 544
547 545 #-------------------------------------------------------------------------
548 546 # Methods to support auto-editing of SyntaxErrors.
549 547 #-------------------------------------------------------------------------
550 548
551 549 def edit_syntax_error(self):
552 550 """The bottom half of the syntax error handler called in the main loop.
553 551
554 552 Loop until syntax error is fixed or user cancels.
555 553 """
556 554
557 555 while self.SyntaxTB.last_syntax_error:
558 556 # copy and clear last_syntax_error
559 557 err = self.SyntaxTB.clear_err_state()
560 558 if not self._should_recompile(err):
561 559 return
562 560 try:
563 561 # may set last_syntax_error again if a SyntaxError is raised
564 562 self.safe_execfile(err.filename,self.user_ns)
565 563 except:
566 564 self.showtraceback()
567 565 else:
568 566 try:
569 567 f = open(err.filename)
570 568 try:
571 569 # This should be inside a display_trap block and I
572 570 # think it is.
573 571 sys.displayhook(f.read())
574 572 finally:
575 573 f.close()
576 574 except:
577 575 self.showtraceback()
578 576
579 577 def _should_recompile(self,e):
580 578 """Utility routine for edit_syntax_error"""
581 579
582 580 if e.filename in ('<ipython console>','<input>','<string>',
583 581 '<console>','<BackgroundJob compilation>',
584 582 None):
585 583
586 584 return False
587 585 try:
588 586 if (self.autoedit_syntax and
589 587 not self.ask_yes_no('Return to editor to correct syntax error? '
590 588 '[Y/n] ','y')):
591 589 return False
592 590 except EOFError:
593 591 return False
594 592
595 593 def int0(x):
596 594 try:
597 595 return int(x)
598 596 except TypeError:
599 597 return 0
600 598 # always pass integer line and offset values to editor hook
601 599 try:
602 600 self.hooks.fix_error_editor(e.filename,
603 601 int0(e.lineno),int0(e.offset),e.msg)
604 602 except TryNext:
605 603 warn('Could not open editor')
606 604 return False
607 605 return True
608 606
609 607 #-------------------------------------------------------------------------
610 608 # Things related to exiting
611 609 #-------------------------------------------------------------------------
612 610
613 611 def ask_exit(self):
614 612 """ Ask the shell to exit. Can be overiden and used as a callback. """
615 613 self.exit_now = True
616 614
617 615 def exit(self):
618 616 """Handle interactive exit.
619 617
620 618 This method calls the ask_exit callback."""
621 619 if self.confirm_exit:
622 620 if self.ask_yes_no('Do you really want to exit ([y]/n)?','y'):
623 621 self.ask_exit()
624 622 else:
625 623 self.ask_exit()
626 624
627 625 #-------------------------------------------------------------------------
628 626 # Things related to magics
629 627 #-------------------------------------------------------------------------
630 628
631 629 def init_magics(self):
632 630 super(TerminalInteractiveShell, self).init_magics()
633 631 self.register_magics(TerminalMagics)
634 632
635 633 def showindentationerror(self):
636 634 super(TerminalInteractiveShell, self).showindentationerror()
637 635 if not self.using_paste_magics:
638 636 print("If you want to paste code into IPython, try the "
639 637 "%paste and %cpaste magic functions.")
640 638
641 639
642 640 InteractiveShellABC.register(TerminalInteractiveShell)
@@ -1,169 +1,165 b''
1 1 """Tests for the decorators we've created for IPython.
2 2 """
3 3 from __future__ import print_function
4 4
5 5 # Module imports
6 6 # Std lib
7 7 import inspect
8 8 import sys
9 9
10 10 # Third party
11 11 import nose.tools as nt
12 12
13 13 # Our own
14 14 from IPython.testing import decorators as dec
15 from IPython.testing.skipdoctest import skip_doctest
16 15
17 16 #-----------------------------------------------------------------------------
18 17 # Utilities
19 18
20 19 # Note: copied from OInspect, kept here so the testing stuff doesn't create
21 20 # circular dependencies and is easier to reuse.
22 21 def getargspec(obj):
23 22 """Get the names and default values of a function's arguments.
24 23
25 24 A tuple of four things is returned: (args, varargs, varkw, defaults).
26 25 'args' is a list of the argument names (it may contain nested lists).
27 26 'varargs' and 'varkw' are the names of the * and ** arguments or None.
28 27 'defaults' is an n-tuple of the default values of the last n arguments.
29 28
30 29 Modified version of inspect.getargspec from the Python Standard
31 30 Library."""
32 31
33 32 if inspect.isfunction(obj):
34 33 func_obj = obj
35 34 elif inspect.ismethod(obj):
36 35 func_obj = obj.__func__
37 36 else:
38 37 raise TypeError('arg is not a Python function')
39 38 args, varargs, varkw = inspect.getargs(func_obj.__code__)
40 39 return args, varargs, varkw, func_obj.__defaults__
41 40
42 41 #-----------------------------------------------------------------------------
43 42 # Testing functions
44 43
45 44 @dec.as_unittest
46 45 def trivial():
47 46 """A trivial test"""
48 47 pass
49 48
50 49
51 50 @dec.skip
52 51 def test_deliberately_broken():
53 52 """A deliberately broken test - we want to skip this one."""
54 53 1/0
55 54
56 55 @dec.skip('Testing the skip decorator')
57 56 def test_deliberately_broken2():
58 57 """Another deliberately broken test - we want to skip this one."""
59 58 1/0
60 59
61 60
62 61 # Verify that we can correctly skip the doctest for a function at will, but
63 62 # that the docstring itself is NOT destroyed by the decorator.
64 @skip_doctest
65 63 def doctest_bad(x,y=1,**k):
66 64 """A function whose doctest we need to skip.
67 65
68 66 >>> 1+1
69 67 3
70 68 """
71 69 print('x:',x)
72 70 print('y:',y)
73 71 print('k:',k)
74 72
75 73
76 74 def call_doctest_bad():
77 75 """Check that we can still call the decorated functions.
78 76
79 77 >>> doctest_bad(3,y=4)
80 78 x: 3
81 79 y: 4
82 80 k: {}
83 81 """
84 82 pass
85 83
86 84
87 85 def test_skip_dt_decorator():
88 86 """Doctest-skipping decorator should preserve the docstring.
89 87 """
90 88 # Careful: 'check' must be a *verbatim* copy of the doctest_bad docstring!
91 89 check = """A function whose doctest we need to skip.
92 90
93 91 >>> 1+1
94 92 3
95 93 """
96 94 # Fetch the docstring from doctest_bad after decoration.
97 95 val = doctest_bad.__doc__
98 96
99 97 nt.assert_equal(check,val,"doctest_bad docstrings don't match")
100 98
101 99
102 100 # Doctest skipping should work for class methods too
103 101 class FooClass(object):
104 102 """FooClass
105 103
106 104 Example:
107 105
108 106 >>> 1+1
109 107 2
110 108 """
111 109
112 @skip_doctest
113 110 def __init__(self,x):
114 111 """Make a FooClass.
115 112
116 113 Example:
117 114
118 115 >>> f = FooClass(3)
119 116 junk
120 117 """
121 118 print('Making a FooClass.')
122 119 self.x = x
123 120
124 @skip_doctest
125 121 def bar(self,y):
126 122 """Example:
127 123
128 124 >>> ff = FooClass(3)
129 125 >>> ff.bar(0)
130 126 boom!
131 127 >>> 1/0
132 128 bam!
133 129 """
134 130 return 1/y
135 131
136 132 def baz(self,y):
137 133 """Example:
138 134
139 135 >>> ff2 = FooClass(3)
140 136 Making a FooClass.
141 137 >>> ff2.baz(3)
142 138 True
143 139 """
144 140 return self.x==y
145 141
146 142
147 143 def test_skip_dt_decorator2():
148 144 """Doctest-skipping decorator should preserve function signature.
149 145 """
150 146 # Hardcoded correct answer
151 147 dtargs = (['x', 'y'], None, 'k', (1,))
152 148 # Introspect out the value
153 149 dtargsr = getargspec(doctest_bad)
154 150 assert dtargsr==dtargs, \
155 151 "Incorrectly reconstructed args for doctest_bad: %s" % (dtargsr,)
156 152
157 153
158 154 @dec.skip_linux
159 155 def test_linux():
160 156 nt.assert_false(sys.platform.startswith('linux'),"This test can't run under linux")
161 157
162 158 @dec.skip_win32
163 159 def test_win32():
164 160 nt.assert_not_equal(sys.platform,'win32',"This test can't run under windows")
165 161
166 162 @dec.skip_osx
167 163 def test_osx():
168 164 nt.assert_not_equal(sys.platform,'darwin',"This test can't run under osx")
169 165
@@ -1,448 +1,446 b''
1 1 # encoding: utf-8
2 2 """
3 3 Utilities for path handling.
4 4 """
5 5
6 6 # Copyright (c) IPython Development Team.
7 7 # Distributed under the terms of the Modified BSD License.
8 8
9 9 import os
10 10 import sys
11 11 import errno
12 12 import shutil
13 13 import random
14 14 import tempfile
15 15 import glob
16 16 from warnings import warn
17 17 from hashlib import md5
18 18
19 from IPython.testing.skipdoctest import skip_doctest
20 19 from IPython.utils.process import system
21 20 from IPython.utils import py3compat
22 21 from IPython.utils.decorators import undoc
23 22
24 23 #-----------------------------------------------------------------------------
25 24 # Code
26 25 #-----------------------------------------------------------------------------
27 26
28 27 fs_encoding = sys.getfilesystemencoding()
29 28
30 29 def _writable_dir(path):
31 30 """Whether `path` is a directory, to which the user has write access."""
32 31 return os.path.isdir(path) and os.access(path, os.W_OK)
33 32
34 33 if sys.platform == 'win32':
35 @skip_doctest
36 34 def _get_long_path_name(path):
37 35 """Get a long path name (expand ~) on Windows using ctypes.
38 36
39 37 Examples
40 38 --------
41 39
42 40 >>> get_long_path_name('c:\\docume~1')
43 41 u'c:\\\\Documents and Settings'
44 42
45 43 """
46 44 try:
47 45 import ctypes
48 46 except ImportError:
49 47 raise ImportError('you need to have ctypes installed for this to work')
50 48 _GetLongPathName = ctypes.windll.kernel32.GetLongPathNameW
51 49 _GetLongPathName.argtypes = [ctypes.c_wchar_p, ctypes.c_wchar_p,
52 50 ctypes.c_uint ]
53 51
54 52 buf = ctypes.create_unicode_buffer(260)
55 53 rv = _GetLongPathName(path, buf, 260)
56 54 if rv == 0 or rv > 260:
57 55 return path
58 56 else:
59 57 return buf.value
60 58 else:
61 59 def _get_long_path_name(path):
62 60 """Dummy no-op."""
63 61 return path
64 62
65 63
66 64
67 65 def get_long_path_name(path):
68 66 """Expand a path into its long form.
69 67
70 68 On Windows this expands any ~ in the paths. On other platforms, it is
71 69 a null operation.
72 70 """
73 71 return _get_long_path_name(path)
74 72
75 73
76 74 def unquote_filename(name, win32=(sys.platform=='win32')):
77 75 """ On Windows, remove leading and trailing quotes from filenames.
78 76 """
79 77 if win32:
80 78 if name.startswith(("'", '"')) and name.endswith(("'", '"')):
81 79 name = name[1:-1]
82 80 return name
83 81
84 82 def compress_user(path):
85 83 """Reverse of :func:`os.path.expanduser`
86 84 """
87 85 home = os.path.expanduser('~')
88 86 if path.startswith(home):
89 87 path = "~" + path[len(home):]
90 88 return path
91 89
92 90 def get_py_filename(name, force_win32=None):
93 91 """Return a valid python filename in the current directory.
94 92
95 93 If the given name is not a file, it adds '.py' and searches again.
96 94 Raises IOError with an informative message if the file isn't found.
97 95
98 96 On Windows, apply Windows semantics to the filename. In particular, remove
99 97 any quoting that has been applied to it. This option can be forced for
100 98 testing purposes.
101 99 """
102 100
103 101 name = os.path.expanduser(name)
104 102 if force_win32 is None:
105 103 win32 = (sys.platform == 'win32')
106 104 else:
107 105 win32 = force_win32
108 106 name = unquote_filename(name, win32=win32)
109 107 if not os.path.isfile(name) and not name.endswith('.py'):
110 108 name += '.py'
111 109 if os.path.isfile(name):
112 110 return name
113 111 else:
114 112 raise IOError('File `%r` not found.' % name)
115 113
116 114
117 115 def filefind(filename, path_dirs=None):
118 116 """Find a file by looking through a sequence of paths.
119 117
120 118 This iterates through a sequence of paths looking for a file and returns
121 119 the full, absolute path of the first occurence of the file. If no set of
122 120 path dirs is given, the filename is tested as is, after running through
123 121 :func:`expandvars` and :func:`expanduser`. Thus a simple call::
124 122
125 123 filefind('myfile.txt')
126 124
127 125 will find the file in the current working dir, but::
128 126
129 127 filefind('~/myfile.txt')
130 128
131 129 Will find the file in the users home directory. This function does not
132 130 automatically try any paths, such as the cwd or the user's home directory.
133 131
134 132 Parameters
135 133 ----------
136 134 filename : str
137 135 The filename to look for.
138 136 path_dirs : str, None or sequence of str
139 137 The sequence of paths to look for the file in. If None, the filename
140 138 need to be absolute or be in the cwd. If a string, the string is
141 139 put into a sequence and the searched. If a sequence, walk through
142 140 each element and join with ``filename``, calling :func:`expandvars`
143 141 and :func:`expanduser` before testing for existence.
144 142
145 143 Returns
146 144 -------
147 145 Raises :exc:`IOError` or returns absolute path to file.
148 146 """
149 147
150 148 # If paths are quoted, abspath gets confused, strip them...
151 149 filename = filename.strip('"').strip("'")
152 150 # If the input is an absolute path, just check it exists
153 151 if os.path.isabs(filename) and os.path.isfile(filename):
154 152 return filename
155 153
156 154 if path_dirs is None:
157 155 path_dirs = ("",)
158 156 elif isinstance(path_dirs, py3compat.string_types):
159 157 path_dirs = (path_dirs,)
160 158
161 159 for path in path_dirs:
162 160 if path == '.': path = py3compat.getcwd()
163 161 testname = expand_path(os.path.join(path, filename))
164 162 if os.path.isfile(testname):
165 163 return os.path.abspath(testname)
166 164
167 165 raise IOError("File %r does not exist in any of the search paths: %r" %
168 166 (filename, path_dirs) )
169 167
170 168
171 169 class HomeDirError(Exception):
172 170 pass
173 171
174 172
175 173 def get_home_dir(require_writable=False):
176 174 """Return the 'home' directory, as a unicode string.
177 175
178 176 Uses os.path.expanduser('~'), and checks for writability.
179 177
180 178 See stdlib docs for how this is determined.
181 179 $HOME is first priority on *ALL* platforms.
182 180
183 181 Parameters
184 182 ----------
185 183
186 184 require_writable : bool [default: False]
187 185 if True:
188 186 guarantees the return value is a writable directory, otherwise
189 187 raises HomeDirError
190 188 if False:
191 189 The path is resolved, but it is not guaranteed to exist or be writable.
192 190 """
193 191
194 192 homedir = os.path.expanduser('~')
195 193 # Next line will make things work even when /home/ is a symlink to
196 194 # /usr/home as it is on FreeBSD, for example
197 195 homedir = os.path.realpath(homedir)
198 196
199 197 if not _writable_dir(homedir) and os.name == 'nt':
200 198 # expanduser failed, use the registry to get the 'My Documents' folder.
201 199 try:
202 200 try:
203 201 import winreg as wreg # Py 3
204 202 except ImportError:
205 203 import _winreg as wreg # Py 2
206 204 key = wreg.OpenKey(
207 205 wreg.HKEY_CURRENT_USER,
208 206 "Software\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders"
209 207 )
210 208 homedir = wreg.QueryValueEx(key,'Personal')[0]
211 209 key.Close()
212 210 except:
213 211 pass
214 212
215 213 if (not require_writable) or _writable_dir(homedir):
216 214 return py3compat.cast_unicode(homedir, fs_encoding)
217 215 else:
218 216 raise HomeDirError('%s is not a writable dir, '
219 217 'set $HOME environment variable to override' % homedir)
220 218
221 219 def get_xdg_dir():
222 220 """Return the XDG_CONFIG_HOME, if it is defined and exists, else None.
223 221
224 222 This is only for non-OS X posix (Linux,Unix,etc.) systems.
225 223 """
226 224
227 225 env = os.environ
228 226
229 227 if os.name == 'posix' and sys.platform != 'darwin':
230 228 # Linux, Unix, AIX, etc.
231 229 # use ~/.config if empty OR not set
232 230 xdg = env.get("XDG_CONFIG_HOME", None) or os.path.join(get_home_dir(), '.config')
233 231 if xdg and _writable_dir(xdg):
234 232 return py3compat.cast_unicode(xdg, fs_encoding)
235 233
236 234 return None
237 235
238 236
239 237 def get_xdg_cache_dir():
240 238 """Return the XDG_CACHE_HOME, if it is defined and exists, else None.
241 239
242 240 This is only for non-OS X posix (Linux,Unix,etc.) systems.
243 241 """
244 242
245 243 env = os.environ
246 244
247 245 if os.name == 'posix' and sys.platform != 'darwin':
248 246 # Linux, Unix, AIX, etc.
249 247 # use ~/.cache if empty OR not set
250 248 xdg = env.get("XDG_CACHE_HOME", None) or os.path.join(get_home_dir(), '.cache')
251 249 if xdg and _writable_dir(xdg):
252 250 return py3compat.cast_unicode(xdg, fs_encoding)
253 251
254 252 return None
255 253
256 254
257 255 @undoc
258 256 def get_ipython_dir():
259 257 warn("get_ipython_dir has moved to the IPython.paths module")
260 258 from IPython.paths import get_ipython_dir
261 259 return get_ipython_dir()
262 260
263 261 @undoc
264 262 def get_ipython_cache_dir():
265 263 warn("get_ipython_cache_dir has moved to the IPython.paths module")
266 264 from IPython.paths import get_ipython_cache_dir
267 265 return get_ipython_cache_dir()
268 266
269 267 @undoc
270 268 def get_ipython_package_dir():
271 269 warn("get_ipython_package_dir has moved to the IPython.paths module")
272 270 from IPython.paths import get_ipython_package_dir
273 271 return get_ipython_package_dir()
274 272
275 273 @undoc
276 274 def get_ipython_module_path(module_str):
277 275 warn("get_ipython_module_path has moved to the IPython.paths module")
278 276 from IPython.paths import get_ipython_module_path
279 277 return get_ipython_module_path(module_str)
280 278
281 279 @undoc
282 280 def locate_profile(profile='default'):
283 281 warn("locate_profile has moved to the IPython.paths module")
284 282 from IPython.paths import locate_profile
285 283 return locate_profile(profile=profile)
286 284
287 285 def expand_path(s):
288 286 """Expand $VARS and ~names in a string, like a shell
289 287
290 288 :Examples:
291 289
292 290 In [2]: os.environ['FOO']='test'
293 291
294 292 In [3]: expand_path('variable FOO is $FOO')
295 293 Out[3]: 'variable FOO is test'
296 294 """
297 295 # This is a pretty subtle hack. When expand user is given a UNC path
298 296 # on Windows (\\server\share$\%username%), os.path.expandvars, removes
299 297 # the $ to get (\\server\share\%username%). I think it considered $
300 298 # alone an empty var. But, we need the $ to remains there (it indicates
301 299 # a hidden share).
302 300 if os.name=='nt':
303 301 s = s.replace('$\\', 'IPYTHON_TEMP')
304 302 s = os.path.expandvars(os.path.expanduser(s))
305 303 if os.name=='nt':
306 304 s = s.replace('IPYTHON_TEMP', '$\\')
307 305 return s
308 306
309 307
310 308 def unescape_glob(string):
311 309 """Unescape glob pattern in `string`."""
312 310 def unescape(s):
313 311 for pattern in '*[]!?':
314 312 s = s.replace(r'\{0}'.format(pattern), pattern)
315 313 return s
316 314 return '\\'.join(map(unescape, string.split('\\\\')))
317 315
318 316
319 317 def shellglob(args):
320 318 """
321 319 Do glob expansion for each element in `args` and return a flattened list.
322 320
323 321 Unmatched glob pattern will remain as-is in the returned list.
324 322
325 323 """
326 324 expanded = []
327 325 # Do not unescape backslash in Windows as it is interpreted as
328 326 # path separator:
329 327 unescape = unescape_glob if sys.platform != 'win32' else lambda x: x
330 328 for a in args:
331 329 expanded.extend(glob.glob(a) or [unescape(a)])
332 330 return expanded
333 331
334 332
335 333 def target_outdated(target,deps):
336 334 """Determine whether a target is out of date.
337 335
338 336 target_outdated(target,deps) -> 1/0
339 337
340 338 deps: list of filenames which MUST exist.
341 339 target: single filename which may or may not exist.
342 340
343 341 If target doesn't exist or is older than any file listed in deps, return
344 342 true, otherwise return false.
345 343 """
346 344 try:
347 345 target_time = os.path.getmtime(target)
348 346 except os.error:
349 347 return 1
350 348 for dep in deps:
351 349 dep_time = os.path.getmtime(dep)
352 350 if dep_time > target_time:
353 351 #print "For target",target,"Dep failed:",dep # dbg
354 352 #print "times (dep,tar):",dep_time,target_time # dbg
355 353 return 1
356 354 return 0
357 355
358 356
359 357 def target_update(target,deps,cmd):
360 358 """Update a target with a given command given a list of dependencies.
361 359
362 360 target_update(target,deps,cmd) -> runs cmd if target is outdated.
363 361
364 362 This is just a wrapper around target_outdated() which calls the given
365 363 command if target is outdated."""
366 364
367 365 if target_outdated(target,deps):
368 366 system(cmd)
369 367
370 368 @undoc
371 369 def filehash(path):
372 370 """Make an MD5 hash of a file, ignoring any differences in line
373 371 ending characters."""
374 372 warn("filehash() is deprecated")
375 373 with open(path, "rU") as f:
376 374 return md5(py3compat.str_to_bytes(f.read())).hexdigest()
377 375
378 376 ENOLINK = 1998
379 377
380 378 def link(src, dst):
381 379 """Hard links ``src`` to ``dst``, returning 0 or errno.
382 380
383 381 Note that the special errno ``ENOLINK`` will be returned if ``os.link`` isn't
384 382 supported by the operating system.
385 383 """
386 384
387 385 if not hasattr(os, "link"):
388 386 return ENOLINK
389 387 link_errno = 0
390 388 try:
391 389 os.link(src, dst)
392 390 except OSError as e:
393 391 link_errno = e.errno
394 392 return link_errno
395 393
396 394
397 395 def link_or_copy(src, dst):
398 396 """Attempts to hardlink ``src`` to ``dst``, copying if the link fails.
399 397
400 398 Attempts to maintain the semantics of ``shutil.copy``.
401 399
402 400 Because ``os.link`` does not overwrite files, a unique temporary file
403 401 will be used if the target already exists, then that file will be moved
404 402 into place.
405 403 """
406 404
407 405 if os.path.isdir(dst):
408 406 dst = os.path.join(dst, os.path.basename(src))
409 407
410 408 link_errno = link(src, dst)
411 409 if link_errno == errno.EEXIST:
412 410 if os.stat(src).st_ino == os.stat(dst).st_ino:
413 411 # dst is already a hard link to the correct file, so we don't need
414 412 # to do anything else. If we try to link and rename the file
415 413 # anyway, we get duplicate files - see http://bugs.python.org/issue21876
416 414 return
417 415
418 416 new_dst = dst + "-temp-%04X" %(random.randint(1, 16**4), )
419 417 try:
420 418 link_or_copy(src, new_dst)
421 419 except:
422 420 try:
423 421 os.remove(new_dst)
424 422 except OSError:
425 423 pass
426 424 raise
427 425 os.rename(new_dst, dst)
428 426 elif link_errno != 0:
429 427 # Either link isn't supported, or the filesystem doesn't support
430 428 # linking, or 'src' and 'dst' are on different filesystems.
431 429 shutil.copy(src, dst)
432 430
433 431 def ensure_dir_exists(path, mode=0o755):
434 432 """ensure that a directory exists
435 433
436 434 If it doesn't exist, try to create it and protect against a race condition
437 435 if another process is doing the same.
438 436
439 437 The default permissions are 755, which differ from os.makedirs default of 777.
440 438 """
441 439 if not os.path.exists(path):
442 440 try:
443 441 os.makedirs(path, mode=mode)
444 442 except OSError as e:
445 443 if e.errno != errno.EEXIST:
446 444 raise
447 445 elif not os.path.isdir(path):
448 446 raise IOError("%r exists but is not a directory" % path)
@@ -1,765 +1,764 b''
1 1 # encoding: utf-8
2 2 """
3 3 Utilities for working with strings and text.
4 4
5 5 Inheritance diagram:
6 6
7 7 .. inheritance-diagram:: IPython.utils.text
8 8 :parts: 3
9 9 """
10 10
11 11 import os
12 12 import re
13 13 import sys
14 14 import textwrap
15 15 from string import Formatter
16 16
17 17 from IPython.testing.skipdoctest import skip_doctest_py3, skip_doctest
18 18 from IPython.utils import py3compat
19 19
20 20 # datetime.strftime date format for ipython
21 21 if sys.platform == 'win32':
22 22 date_format = "%B %d, %Y"
23 23 else:
24 24 date_format = "%B %-d, %Y"
25 25
26 26 class LSString(str):
27 27 """String derivative with a special access attributes.
28 28
29 29 These are normal strings, but with the special attributes:
30 30
31 31 .l (or .list) : value as list (split on newlines).
32 32 .n (or .nlstr): original value (the string itself).
33 33 .s (or .spstr): value as whitespace-separated string.
34 34 .p (or .paths): list of path objects (requires path.py package)
35 35
36 36 Any values which require transformations are computed only once and
37 37 cached.
38 38
39 39 Such strings are very useful to efficiently interact with the shell, which
40 40 typically only understands whitespace-separated options for commands."""
41 41
42 42 def get_list(self):
43 43 try:
44 44 return self.__list
45 45 except AttributeError:
46 46 self.__list = self.split('\n')
47 47 return self.__list
48 48
49 49 l = list = property(get_list)
50 50
51 51 def get_spstr(self):
52 52 try:
53 53 return self.__spstr
54 54 except AttributeError:
55 55 self.__spstr = self.replace('\n',' ')
56 56 return self.__spstr
57 57
58 58 s = spstr = property(get_spstr)
59 59
60 60 def get_nlstr(self):
61 61 return self
62 62
63 63 n = nlstr = property(get_nlstr)
64 64
65 65 def get_paths(self):
66 66 from path import path
67 67 try:
68 68 return self.__paths
69 69 except AttributeError:
70 70 self.__paths = [path(p) for p in self.split('\n') if os.path.exists(p)]
71 71 return self.__paths
72 72
73 73 p = paths = property(get_paths)
74 74
75 75 # FIXME: We need to reimplement type specific displayhook and then add this
76 76 # back as a custom printer. This should also be moved outside utils into the
77 77 # core.
78 78
79 79 # def print_lsstring(arg):
80 80 # """ Prettier (non-repr-like) and more informative printer for LSString """
81 81 # print "LSString (.p, .n, .l, .s available). Value:"
82 82 # print arg
83 83 #
84 84 #
85 85 # print_lsstring = result_display.when_type(LSString)(print_lsstring)
86 86
87 87
88 88 class SList(list):
89 89 """List derivative with a special access attributes.
90 90
91 91 These are normal lists, but with the special attributes:
92 92
93 93 * .l (or .list) : value as list (the list itself).
94 94 * .n (or .nlstr): value as a string, joined on newlines.
95 95 * .s (or .spstr): value as a string, joined on spaces.
96 96 * .p (or .paths): list of path objects (requires path.py package)
97 97
98 98 Any values which require transformations are computed only once and
99 99 cached."""
100 100
101 101 def get_list(self):
102 102 return self
103 103
104 104 l = list = property(get_list)
105 105
106 106 def get_spstr(self):
107 107 try:
108 108 return self.__spstr
109 109 except AttributeError:
110 110 self.__spstr = ' '.join(self)
111 111 return self.__spstr
112 112
113 113 s = spstr = property(get_spstr)
114 114
115 115 def get_nlstr(self):
116 116 try:
117 117 return self.__nlstr
118 118 except AttributeError:
119 119 self.__nlstr = '\n'.join(self)
120 120 return self.__nlstr
121 121
122 122 n = nlstr = property(get_nlstr)
123 123
124 124 def get_paths(self):
125 125 from path import path
126 126 try:
127 127 return self.__paths
128 128 except AttributeError:
129 129 self.__paths = [path(p) for p in self if os.path.exists(p)]
130 130 return self.__paths
131 131
132 132 p = paths = property(get_paths)
133 133
134 134 def grep(self, pattern, prune = False, field = None):
135 135 """ Return all strings matching 'pattern' (a regex or callable)
136 136
137 137 This is case-insensitive. If prune is true, return all items
138 138 NOT matching the pattern.
139 139
140 140 If field is specified, the match must occur in the specified
141 141 whitespace-separated field.
142 142
143 143 Examples::
144 144
145 145 a.grep( lambda x: x.startswith('C') )
146 146 a.grep('Cha.*log', prune=1)
147 147 a.grep('chm', field=-1)
148 148 """
149 149
150 150 def match_target(s):
151 151 if field is None:
152 152 return s
153 153 parts = s.split()
154 154 try:
155 155 tgt = parts[field]
156 156 return tgt
157 157 except IndexError:
158 158 return ""
159 159
160 160 if isinstance(pattern, py3compat.string_types):
161 161 pred = lambda x : re.search(pattern, x, re.IGNORECASE)
162 162 else:
163 163 pred = pattern
164 164 if not prune:
165 165 return SList([el for el in self if pred(match_target(el))])
166 166 else:
167 167 return SList([el for el in self if not pred(match_target(el))])
168 168
169 169 def fields(self, *fields):
170 170 """ Collect whitespace-separated fields from string list
171 171
172 172 Allows quick awk-like usage of string lists.
173 173
174 174 Example data (in var a, created by 'a = !ls -l')::
175 175
176 176 -rwxrwxrwx 1 ville None 18 Dec 14 2006 ChangeLog
177 177 drwxrwxrwx+ 6 ville None 0 Oct 24 18:05 IPython
178 178
179 179 * ``a.fields(0)`` is ``['-rwxrwxrwx', 'drwxrwxrwx+']``
180 180 * ``a.fields(1,0)`` is ``['1 -rwxrwxrwx', '6 drwxrwxrwx+']``
181 181 (note the joining by space).
182 182 * ``a.fields(-1)`` is ``['ChangeLog', 'IPython']``
183 183
184 184 IndexErrors are ignored.
185 185
186 186 Without args, fields() just split()'s the strings.
187 187 """
188 188 if len(fields) == 0:
189 189 return [el.split() for el in self]
190 190
191 191 res = SList()
192 192 for el in [f.split() for f in self]:
193 193 lineparts = []
194 194
195 195 for fd in fields:
196 196 try:
197 197 lineparts.append(el[fd])
198 198 except IndexError:
199 199 pass
200 200 if lineparts:
201 201 res.append(" ".join(lineparts))
202 202
203 203 return res
204 204
205 205 def sort(self,field= None, nums = False):
206 206 """ sort by specified fields (see fields())
207 207
208 208 Example::
209 209
210 210 a.sort(1, nums = True)
211 211
212 212 Sorts a by second field, in numerical order (so that 21 > 3)
213 213
214 214 """
215 215
216 216 #decorate, sort, undecorate
217 217 if field is not None:
218 218 dsu = [[SList([line]).fields(field), line] for line in self]
219 219 else:
220 220 dsu = [[line, line] for line in self]
221 221 if nums:
222 222 for i in range(len(dsu)):
223 223 numstr = "".join([ch for ch in dsu[i][0] if ch.isdigit()])
224 224 try:
225 225 n = int(numstr)
226 226 except ValueError:
227 227 n = 0;
228 228 dsu[i][0] = n
229 229
230 230
231 231 dsu.sort()
232 232 return SList([t[1] for t in dsu])
233 233
234 234
235 235 # FIXME: We need to reimplement type specific displayhook and then add this
236 236 # back as a custom printer. This should also be moved outside utils into the
237 237 # core.
238 238
239 239 # def print_slist(arg):
240 240 # """ Prettier (non-repr-like) and more informative printer for SList """
241 241 # print "SList (.p, .n, .l, .s, .grep(), .fields(), sort() available):"
242 242 # if hasattr(arg, 'hideonce') and arg.hideonce:
243 243 # arg.hideonce = False
244 244 # return
245 245 #
246 246 # nlprint(arg) # This was a nested list printer, now removed.
247 247 #
248 248 # print_slist = result_display.when_type(SList)(print_slist)
249 249
250 250
251 251 def indent(instr,nspaces=4, ntabs=0, flatten=False):
252 252 """Indent a string a given number of spaces or tabstops.
253 253
254 254 indent(str,nspaces=4,ntabs=0) -> indent str by ntabs+nspaces.
255 255
256 256 Parameters
257 257 ----------
258 258
259 259 instr : basestring
260 260 The string to be indented.
261 261 nspaces : int (default: 4)
262 262 The number of spaces to be indented.
263 263 ntabs : int (default: 0)
264 264 The number of tabs to be indented.
265 265 flatten : bool (default: False)
266 266 Whether to scrub existing indentation. If True, all lines will be
267 267 aligned to the same indentation. If False, existing indentation will
268 268 be strictly increased.
269 269
270 270 Returns
271 271 -------
272 272
273 273 str|unicode : string indented by ntabs and nspaces.
274 274
275 275 """
276 276 if instr is None:
277 277 return
278 278 ind = '\t'*ntabs+' '*nspaces
279 279 if flatten:
280 280 pat = re.compile(r'^\s*', re.MULTILINE)
281 281 else:
282 282 pat = re.compile(r'^', re.MULTILINE)
283 283 outstr = re.sub(pat, ind, instr)
284 284 if outstr.endswith(os.linesep+ind):
285 285 return outstr[:-len(ind)]
286 286 else:
287 287 return outstr
288 288
289 289
290 290 def list_strings(arg):
291 291 """Always return a list of strings, given a string or list of strings
292 292 as input.
293 293
294 294 Examples
295 295 --------
296 296 ::
297 297
298 298 In [7]: list_strings('A single string')
299 299 Out[7]: ['A single string']
300 300
301 301 In [8]: list_strings(['A single string in a list'])
302 302 Out[8]: ['A single string in a list']
303 303
304 304 In [9]: list_strings(['A','list','of','strings'])
305 305 Out[9]: ['A', 'list', 'of', 'strings']
306 306 """
307 307
308 308 if isinstance(arg, py3compat.string_types): return [arg]
309 309 else: return arg
310 310
311 311
312 312 def marquee(txt='',width=78,mark='*'):
313 313 """Return the input string centered in a 'marquee'.
314 314
315 315 Examples
316 316 --------
317 317 ::
318 318
319 319 In [16]: marquee('A test',40)
320 320 Out[16]: '**************** A test ****************'
321 321
322 322 In [17]: marquee('A test',40,'-')
323 323 Out[17]: '---------------- A test ----------------'
324 324
325 325 In [18]: marquee('A test',40,' ')
326 326 Out[18]: ' A test '
327 327
328 328 """
329 329 if not txt:
330 330 return (mark*width)[:width]
331 331 nmark = (width-len(txt)-2)//len(mark)//2
332 332 if nmark < 0: nmark =0
333 333 marks = mark*nmark
334 334 return '%s %s %s' % (marks,txt,marks)
335 335
336 336
337 337 ini_spaces_re = re.compile(r'^(\s+)')
338 338
339 339 def num_ini_spaces(strng):
340 340 """Return the number of initial spaces in a string"""
341 341
342 342 ini_spaces = ini_spaces_re.match(strng)
343 343 if ini_spaces:
344 344 return ini_spaces.end()
345 345 else:
346 346 return 0
347 347
348 348
349 349 def format_screen(strng):
350 350 """Format a string for screen printing.
351 351
352 352 This removes some latex-type format codes."""
353 353 # Paragraph continue
354 354 par_re = re.compile(r'\\$',re.MULTILINE)
355 355 strng = par_re.sub('',strng)
356 356 return strng
357 357
358 358
359 359 def dedent(text):
360 360 """Equivalent of textwrap.dedent that ignores unindented first line.
361 361
362 362 This means it will still dedent strings like:
363 363 '''foo
364 364 is a bar
365 365 '''
366 366
367 367 For use in wrap_paragraphs.
368 368 """
369 369
370 370 if text.startswith('\n'):
371 371 # text starts with blank line, don't ignore the first line
372 372 return textwrap.dedent(text)
373 373
374 374 # split first line
375 375 splits = text.split('\n',1)
376 376 if len(splits) == 1:
377 377 # only one line
378 378 return textwrap.dedent(text)
379 379
380 380 first, rest = splits
381 381 # dedent everything but the first line
382 382 rest = textwrap.dedent(rest)
383 383 return '\n'.join([first, rest])
384 384
385 385
386 386 def wrap_paragraphs(text, ncols=80):
387 387 """Wrap multiple paragraphs to fit a specified width.
388 388
389 389 This is equivalent to textwrap.wrap, but with support for multiple
390 390 paragraphs, as separated by empty lines.
391 391
392 392 Returns
393 393 -------
394 394
395 395 list of complete paragraphs, wrapped to fill `ncols` columns.
396 396 """
397 397 paragraph_re = re.compile(r'\n(\s*\n)+', re.MULTILINE)
398 398 text = dedent(text).strip()
399 399 paragraphs = paragraph_re.split(text)[::2] # every other entry is space
400 400 out_ps = []
401 401 indent_re = re.compile(r'\n\s+', re.MULTILINE)
402 402 for p in paragraphs:
403 403 # presume indentation that survives dedent is meaningful formatting,
404 404 # so don't fill unless text is flush.
405 405 if indent_re.search(p) is None:
406 406 # wrap paragraph
407 407 p = textwrap.fill(p, ncols)
408 408 out_ps.append(p)
409 409 return out_ps
410 410
411 411
412 412 def long_substr(data):
413 413 """Return the longest common substring in a list of strings.
414 414
415 415 Credit: http://stackoverflow.com/questions/2892931/longest-common-substring-from-more-than-two-strings-python
416 416 """
417 417 substr = ''
418 418 if len(data) > 1 and len(data[0]) > 0:
419 419 for i in range(len(data[0])):
420 420 for j in range(len(data[0])-i+1):
421 421 if j > len(substr) and all(data[0][i:i+j] in x for x in data):
422 422 substr = data[0][i:i+j]
423 423 elif len(data) == 1:
424 424 substr = data[0]
425 425 return substr
426 426
427 427
428 428 def strip_email_quotes(text):
429 429 """Strip leading email quotation characters ('>').
430 430
431 431 Removes any combination of leading '>' interspersed with whitespace that
432 432 appears *identically* in all lines of the input text.
433 433
434 434 Parameters
435 435 ----------
436 436 text : str
437 437
438 438 Examples
439 439 --------
440 440
441 441 Simple uses::
442 442
443 443 In [2]: strip_email_quotes('> > text')
444 444 Out[2]: 'text'
445 445
446 446 In [3]: strip_email_quotes('> > text\\n> > more')
447 447 Out[3]: 'text\\nmore'
448 448
449 449 Note how only the common prefix that appears in all lines is stripped::
450 450
451 451 In [4]: strip_email_quotes('> > text\\n> > more\\n> more...')
452 452 Out[4]: '> text\\n> more\\nmore...'
453 453
454 454 So if any line has no quote marks ('>') , then none are stripped from any
455 455 of them ::
456 456
457 457 In [5]: strip_email_quotes('> > text\\n> > more\\nlast different')
458 458 Out[5]: '> > text\\n> > more\\nlast different'
459 459 """
460 460 lines = text.splitlines()
461 461 matches = set()
462 462 for line in lines:
463 463 prefix = re.match(r'^(\s*>[ >]*)', line)
464 464 if prefix:
465 465 matches.add(prefix.group(1))
466 466 else:
467 467 break
468 468 else:
469 469 prefix = long_substr(list(matches))
470 470 if prefix:
471 471 strip = len(prefix)
472 472 text = '\n'.join([ ln[strip:] for ln in lines])
473 473 return text
474 474
475 475 def strip_ansi(source):
476 476 """
477 477 Remove ansi escape codes from text.
478 478
479 479 Parameters
480 480 ----------
481 481 source : str
482 482 Source to remove the ansi from
483 483 """
484 484 return re.sub(r'\033\[(\d|;)+?m', '', source)
485 485
486 486
487 487 class EvalFormatter(Formatter):
488 488 """A String Formatter that allows evaluation of simple expressions.
489 489
490 490 Note that this version interprets a : as specifying a format string (as per
491 491 standard string formatting), so if slicing is required, you must explicitly
492 492 create a slice.
493 493
494 494 This is to be used in templating cases, such as the parallel batch
495 495 script templates, where simple arithmetic on arguments is useful.
496 496
497 497 Examples
498 498 --------
499 499 ::
500 500
501 501 In [1]: f = EvalFormatter()
502 502 In [2]: f.format('{n//4}', n=8)
503 503 Out[2]: '2'
504 504
505 505 In [3]: f.format("{greeting[slice(2,4)]}", greeting="Hello")
506 506 Out[3]: 'll'
507 507 """
508 508 def get_field(self, name, args, kwargs):
509 509 v = eval(name, kwargs)
510 510 return v, name
511 511
512 512 #XXX: As of Python 3.4, the format string parsing no longer splits on a colon
513 513 # inside [], so EvalFormatter can handle slicing. Once we only support 3.4 and
514 514 # above, it should be possible to remove FullEvalFormatter.
515 515
516 516 @skip_doctest_py3
517 517 class FullEvalFormatter(Formatter):
518 518 """A String Formatter that allows evaluation of simple expressions.
519 519
520 520 Any time a format key is not found in the kwargs,
521 521 it will be tried as an expression in the kwargs namespace.
522 522
523 523 Note that this version allows slicing using [1:2], so you cannot specify
524 524 a format string. Use :class:`EvalFormatter` to permit format strings.
525 525
526 526 Examples
527 527 --------
528 528 ::
529 529
530 530 In [1]: f = FullEvalFormatter()
531 531 In [2]: f.format('{n//4}', n=8)
532 532 Out[2]: u'2'
533 533
534 534 In [3]: f.format('{list(range(5))[2:4]}')
535 535 Out[3]: u'[2, 3]'
536 536
537 537 In [4]: f.format('{3*2}')
538 538 Out[4]: u'6'
539 539 """
540 540 # copied from Formatter._vformat with minor changes to allow eval
541 541 # and replace the format_spec code with slicing
542 542 def _vformat(self, format_string, args, kwargs, used_args, recursion_depth):
543 543 if recursion_depth < 0:
544 544 raise ValueError('Max string recursion exceeded')
545 545 result = []
546 546 for literal_text, field_name, format_spec, conversion in \
547 547 self.parse(format_string):
548 548
549 549 # output the literal text
550 550 if literal_text:
551 551 result.append(literal_text)
552 552
553 553 # if there's a field, output it
554 554 if field_name is not None:
555 555 # this is some markup, find the object and do
556 556 # the formatting
557 557
558 558 if format_spec:
559 559 # override format spec, to allow slicing:
560 560 field_name = ':'.join([field_name, format_spec])
561 561
562 562 # eval the contents of the field for the object
563 563 # to be formatted
564 564 obj = eval(field_name, kwargs)
565 565
566 566 # do any conversion on the resulting object
567 567 obj = self.convert_field(obj, conversion)
568 568
569 569 # format the object and append to the result
570 570 result.append(self.format_field(obj, ''))
571 571
572 572 return u''.join(py3compat.cast_unicode(s) for s in result)
573 573
574 574
575 575 @skip_doctest_py3
576 576 class DollarFormatter(FullEvalFormatter):
577 577 """Formatter allowing Itpl style $foo replacement, for names and attribute
578 578 access only. Standard {foo} replacement also works, and allows full
579 579 evaluation of its arguments.
580 580
581 581 Examples
582 582 --------
583 583 ::
584 584
585 585 In [1]: f = DollarFormatter()
586 586 In [2]: f.format('{n//4}', n=8)
587 587 Out[2]: u'2'
588 588
589 589 In [3]: f.format('23 * 76 is $result', result=23*76)
590 590 Out[3]: u'23 * 76 is 1748'
591 591
592 592 In [4]: f.format('$a or {b}', a=1, b=2)
593 593 Out[4]: u'1 or 2'
594 594 """
595 595 _dollar_pattern = re.compile("(.*?)\$(\$?[\w\.]+)")
596 596 def parse(self, fmt_string):
597 597 for literal_txt, field_name, format_spec, conversion \
598 598 in Formatter.parse(self, fmt_string):
599 599
600 600 # Find $foo patterns in the literal text.
601 601 continue_from = 0
602 602 txt = ""
603 603 for m in self._dollar_pattern.finditer(literal_txt):
604 604 new_txt, new_field = m.group(1,2)
605 605 # $$foo --> $foo
606 606 if new_field.startswith("$"):
607 607 txt += new_txt + new_field
608 608 else:
609 609 yield (txt + new_txt, new_field, "", None)
610 610 txt = ""
611 611 continue_from = m.end()
612 612
613 613 # Re-yield the {foo} style pattern
614 614 yield (txt + literal_txt[continue_from:], field_name, format_spec, conversion)
615 615
616 616 #-----------------------------------------------------------------------------
617 617 # Utils to columnize a list of string
618 618 #-----------------------------------------------------------------------------
619 619
620 620 def _chunks(l, n):
621 621 """Yield successive n-sized chunks from l."""
622 622 for i in py3compat.xrange(0, len(l), n):
623 623 yield l[i:i+n]
624 624
625 625
626 626 def _find_optimal(rlist , separator_size=2 , displaywidth=80):
627 627 """Calculate optimal info to columnize a list of string"""
628 628 for nrow in range(1, len(rlist)+1) :
629 629 chk = list(map(max,_chunks(rlist, nrow)))
630 630 sumlength = sum(chk)
631 631 ncols = len(chk)
632 632 if sumlength+separator_size*(ncols-1) <= displaywidth :
633 633 break;
634 634 return {'columns_numbers' : ncols,
635 635 'optimal_separator_width':(displaywidth - sumlength)/(ncols-1) if (ncols -1) else 0,
636 636 'rows_numbers' : nrow,
637 637 'columns_width' : chk
638 638 }
639 639
640 640
641 641 def _get_or_default(mylist, i, default=None):
642 642 """return list item number, or default if don't exist"""
643 643 if i >= len(mylist):
644 644 return default
645 645 else :
646 646 return mylist[i]
647 647
648 648
649 @skip_doctest
650 649 def compute_item_matrix(items, empty=None, *args, **kwargs) :
651 650 """Returns a nested list, and info to columnize items
652 651
653 652 Parameters
654 653 ----------
655 654
656 655 items
657 656 list of strings to columize
658 657 empty : (default None)
659 658 default value to fill list if needed
660 659 separator_size : int (default=2)
661 660 How much caracters will be used as a separation between each columns.
662 661 displaywidth : int (default=80)
663 662 The width of the area onto wich the columns should enter
664 663
665 664 Returns
666 665 -------
667 666
668 667 strings_matrix
669 668
670 669 nested list of string, the outer most list contains as many list as
671 670 rows, the innermost lists have each as many element as colums. If the
672 671 total number of elements in `items` does not equal the product of
673 672 rows*columns, the last element of some lists are filled with `None`.
674 673
675 674 dict_info
676 675 some info to make columnize easier:
677 676
678 677 columns_numbers
679 678 number of columns
680 679 rows_numbers
681 680 number of rows
682 681 columns_width
683 682 list of with of each columns
684 683 optimal_separator_width
685 684 best separator width between columns
686 685
687 686 Examples
688 687 --------
689 688 ::
690 689
691 690 In [1]: l = ['aaa','b','cc','d','eeeee','f','g','h','i','j','k','l']
692 691 ...: compute_item_matrix(l,displaywidth=12)
693 692 Out[1]:
694 693 ([['aaa', 'f', 'k'],
695 694 ['b', 'g', 'l'],
696 695 ['cc', 'h', None],
697 696 ['d', 'i', None],
698 697 ['eeeee', 'j', None]],
699 698 {'columns_numbers': 3,
700 699 'columns_width': [5, 1, 1],
701 700 'optimal_separator_width': 2,
702 701 'rows_numbers': 5})
703 702 """
704 703 info = _find_optimal(list(map(len, items)), *args, **kwargs)
705 704 nrow, ncol = info['rows_numbers'], info['columns_numbers']
706 705 return ([[ _get_or_default(items, c*nrow+i, default=empty) for c in range(ncol) ] for i in range(nrow) ], info)
707 706
708 707
709 708 def columnize(items, separator=' ', displaywidth=80):
710 709 """ Transform a list of strings into a single string with columns.
711 710
712 711 Parameters
713 712 ----------
714 713 items : sequence of strings
715 714 The strings to process.
716 715
717 716 separator : str, optional [default is two spaces]
718 717 The string that separates columns.
719 718
720 719 displaywidth : int, optional [default is 80]
721 720 Width of the display in number of characters.
722 721
723 722 Returns
724 723 -------
725 724 The formatted string.
726 725 """
727 726 if not items :
728 727 return '\n'
729 728 matrix, info = compute_item_matrix(items, separator_size=len(separator), displaywidth=displaywidth)
730 729 fmatrix = [filter(None, x) for x in matrix]
731 730 sjoin = lambda x : separator.join([ y.ljust(w, ' ') for y, w in zip(x, info['columns_width'])])
732 731 return '\n'.join(map(sjoin, fmatrix))+'\n'
733 732
734 733
735 734 def get_text_list(list_, last_sep=' and ', sep=", ", wrap_item_with=""):
736 735 """
737 736 Return a string with a natural enumeration of items
738 737
739 738 >>> get_text_list(['a', 'b', 'c', 'd'])
740 739 'a, b, c and d'
741 740 >>> get_text_list(['a', 'b', 'c'], ' or ')
742 741 'a, b or c'
743 742 >>> get_text_list(['a', 'b', 'c'], ', ')
744 743 'a, b, c'
745 744 >>> get_text_list(['a', 'b'], ' or ')
746 745 'a or b'
747 746 >>> get_text_list(['a'])
748 747 'a'
749 748 >>> get_text_list([])
750 749 ''
751 750 >>> get_text_list(['a', 'b'], wrap_item_with="`")
752 751 '`a` and `b`'
753 752 >>> get_text_list(['a', 'b', 'c', 'd'], " = ", sep=" + ")
754 753 'a + b + c = d'
755 754 """
756 755 if len(list_) == 0:
757 756 return ''
758 757 if wrap_item_with:
759 758 list_ = ['%s%s%s' % (wrap_item_with, item, wrap_item_with) for
760 759 item in list_]
761 760 if len(list_) == 1:
762 761 return list_[0]
763 762 return '%s%s%s' % (
764 763 sep.join(i for i in list_[:-1]),
765 764 last_sep, list_[-1]) No newline at end of file
@@ -1,486 +1,484 b''
1 1 """A ZMQ-based subclass of InteractiveShell.
2 2
3 3 This code is meant to ease the refactoring of the base InteractiveShell into
4 4 something with a cleaner architecture for 2-process use, without actually
5 5 breaking InteractiveShell itself. So we're doing something a bit ugly, where
6 6 we subclass and override what we want to fix. Once this is working well, we
7 7 can go back to the base class and refactor the code for a cleaner inheritance
8 8 implementation that doesn't rely on so much monkeypatching.
9 9
10 10 But this lets us maintain a fully working IPython as we develop the new
11 11 machinery. This should thus be thought of as scaffolding.
12 12 """
13 13
14 14 # Copyright (c) IPython Development Team.
15 15 # Distributed under the terms of the Modified BSD License.
16 16
17 17 from __future__ import print_function
18 18
19 19 import os
20 20 import sys
21 21 import time
22 22
23 23 from zmq.eventloop import ioloop
24 24
25 25 from IPython.core.interactiveshell import (
26 26 InteractiveShell, InteractiveShellABC
27 27 )
28 28 from IPython.core import page
29 29 from IPython.core.autocall import ZMQExitAutocall
30 30 from IPython.core.displaypub import DisplayPublisher
31 31 from IPython.core.error import UsageError
32 32 from IPython.core.magics import MacroToEdit, CodeMagics
33 33 from IPython.core.magic import magics_class, line_magic, Magics
34 34 from IPython.core import payloadpage
35 35 from IPython.core.usage import default_gui_banner
36 36 from IPython.display import display, Javascript
37 37 from ipython_kernel.inprocess.socket import SocketABC
38 38 from ipython_kernel import (
39 39 get_connection_file, get_connection_info, connect_qtconsole
40 40 )
41 from IPython.testing.skipdoctest import skip_doctest
42 41 from IPython.utils import openpy
43 42 from jupyter_client.jsonutil import json_clean, encode_images
44 43 from IPython.utils.process import arg_split
45 44 from IPython.utils import py3compat
46 45 from IPython.utils.py3compat import unicode_type
47 46 from IPython.utils.traitlets import Instance, Type, Dict, CBool, CBytes, Any
48 47 from IPython.utils.warn import error
49 48 from ipython_kernel.displayhook import ZMQShellDisplayHook
50 49 from ipython_kernel.datapub import ZMQDataPublisher
51 50 from ipython_kernel.session import extract_header
52 51 from .session import Session
53 52
54 53 #-----------------------------------------------------------------------------
55 54 # Functions and classes
56 55 #-----------------------------------------------------------------------------
57 56
58 57 class ZMQDisplayPublisher(DisplayPublisher):
59 58 """A display publisher that publishes data using a ZeroMQ PUB socket."""
60 59
61 60 session = Instance(Session, allow_none=True)
62 61 pub_socket = Instance(SocketABC, allow_none=True)
63 62 parent_header = Dict({})
64 63 topic = CBytes(b'display_data')
65 64
66 65 def set_parent(self, parent):
67 66 """Set the parent for outbound messages."""
68 67 self.parent_header = extract_header(parent)
69 68
70 69 def _flush_streams(self):
71 70 """flush IO Streams prior to display"""
72 71 sys.stdout.flush()
73 72 sys.stderr.flush()
74 73
75 74 def publish(self, data, metadata=None, source=None):
76 75 self._flush_streams()
77 76 if metadata is None:
78 77 metadata = {}
79 78 self._validate_data(data, metadata)
80 79 content = {}
81 80 content['data'] = encode_images(data)
82 81 content['metadata'] = metadata
83 82 self.session.send(
84 83 self.pub_socket, u'display_data', json_clean(content),
85 84 parent=self.parent_header, ident=self.topic,
86 85 )
87 86
88 87 def clear_output(self, wait=False):
89 88 content = dict(wait=wait)
90 89 self._flush_streams()
91 90 self.session.send(
92 91 self.pub_socket, u'clear_output', content,
93 92 parent=self.parent_header, ident=self.topic,
94 93 )
95 94
96 95 @magics_class
97 96 class KernelMagics(Magics):
98 97 #------------------------------------------------------------------------
99 98 # Magic overrides
100 99 #------------------------------------------------------------------------
101 100 # Once the base class stops inheriting from magic, this code needs to be
102 101 # moved into a separate machinery as well. For now, at least isolate here
103 102 # the magics which this class needs to implement differently from the base
104 103 # class, or that are unique to it.
105 104
106 105 _find_edit_target = CodeMagics._find_edit_target
107 106
108 @skip_doctest
109 107 @line_magic
110 108 def edit(self, parameter_s='', last_call=['','']):
111 109 """Bring up an editor and execute the resulting code.
112 110
113 111 Usage:
114 112 %edit [options] [args]
115 113
116 114 %edit runs an external text editor. You will need to set the command for
117 115 this editor via the ``TerminalInteractiveShell.editor`` option in your
118 116 configuration file before it will work.
119 117
120 118 This command allows you to conveniently edit multi-line code right in
121 119 your IPython session.
122 120
123 121 If called without arguments, %edit opens up an empty editor with a
124 122 temporary file and will execute the contents of this file when you
125 123 close it (don't forget to save it!).
126 124
127 125 Options:
128 126
129 127 -n <number>
130 128 Open the editor at a specified line number. By default, the IPython
131 129 editor hook uses the unix syntax 'editor +N filename', but you can
132 130 configure this by providing your own modified hook if your favorite
133 131 editor supports line-number specifications with a different syntax.
134 132
135 133 -p
136 134 Call the editor with the same data as the previous time it was used,
137 135 regardless of how long ago (in your current session) it was.
138 136
139 137 -r
140 138 Use 'raw' input. This option only applies to input taken from the
141 139 user's history. By default, the 'processed' history is used, so that
142 140 magics are loaded in their transformed version to valid Python. If
143 141 this option is given, the raw input as typed as the command line is
144 142 used instead. When you exit the editor, it will be executed by
145 143 IPython's own processor.
146 144
147 145 Arguments:
148 146
149 147 If arguments are given, the following possibilites exist:
150 148
151 149 - The arguments are numbers or pairs of colon-separated numbers (like
152 150 1 4:8 9). These are interpreted as lines of previous input to be
153 151 loaded into the editor. The syntax is the same of the %macro command.
154 152
155 153 - If the argument doesn't start with a number, it is evaluated as a
156 154 variable and its contents loaded into the editor. You can thus edit
157 155 any string which contains python code (including the result of
158 156 previous edits).
159 157
160 158 - If the argument is the name of an object (other than a string),
161 159 IPython will try to locate the file where it was defined and open the
162 160 editor at the point where it is defined. You can use ``%edit function``
163 161 to load an editor exactly at the point where 'function' is defined,
164 162 edit it and have the file be executed automatically.
165 163
166 164 If the object is a macro (see %macro for details), this opens up your
167 165 specified editor with a temporary file containing the macro's data.
168 166 Upon exit, the macro is reloaded with the contents of the file.
169 167
170 168 Note: opening at an exact line is only supported under Unix, and some
171 169 editors (like kedit and gedit up to Gnome 2.8) do not understand the
172 170 '+NUMBER' parameter necessary for this feature. Good editors like
173 171 (X)Emacs, vi, jed, pico and joe all do.
174 172
175 173 - If the argument is not found as a variable, IPython will look for a
176 174 file with that name (adding .py if necessary) and load it into the
177 175 editor. It will execute its contents with execfile() when you exit,
178 176 loading any code in the file into your interactive namespace.
179 177
180 178 Unlike in the terminal, this is designed to use a GUI editor, and we do
181 179 not know when it has closed. So the file you edit will not be
182 180 automatically executed or printed.
183 181
184 182 Note that %edit is also available through the alias %ed.
185 183 """
186 184
187 185 opts,args = self.parse_options(parameter_s,'prn:')
188 186
189 187 try:
190 188 filename, lineno, _ = CodeMagics._find_edit_target(self.shell, args, opts, last_call)
191 189 except MacroToEdit as e:
192 190 # TODO: Implement macro editing over 2 processes.
193 191 print("Macro editing not yet implemented in 2-process model.")
194 192 return
195 193
196 194 # Make sure we send to the client an absolute path, in case the working
197 195 # directory of client and kernel don't match
198 196 filename = os.path.abspath(filename)
199 197
200 198 payload = {
201 199 'source' : 'edit_magic',
202 200 'filename' : filename,
203 201 'line_number' : lineno
204 202 }
205 203 self.shell.payload_manager.write_payload(payload)
206 204
207 205 # A few magics that are adapted to the specifics of using pexpect and a
208 206 # remote terminal
209 207
210 208 @line_magic
211 209 def clear(self, arg_s):
212 210 """Clear the terminal."""
213 211 if os.name == 'posix':
214 212 self.shell.system("clear")
215 213 else:
216 214 self.shell.system("cls")
217 215
218 216 if os.name == 'nt':
219 217 # This is the usual name in windows
220 218 cls = line_magic('cls')(clear)
221 219
222 220 # Terminal pagers won't work over pexpect, but we do have our own pager
223 221
224 222 @line_magic
225 223 def less(self, arg_s):
226 224 """Show a file through the pager.
227 225
228 226 Files ending in .py are syntax-highlighted."""
229 227 if not arg_s:
230 228 raise UsageError('Missing filename.')
231 229
232 230 if arg_s.endswith('.py'):
233 231 cont = self.shell.pycolorize(openpy.read_py_file(arg_s, skip_encoding_cookie=False))
234 232 else:
235 233 cont = open(arg_s).read()
236 234 page.page(cont)
237 235
238 236 more = line_magic('more')(less)
239 237
240 238 # Man calls a pager, so we also need to redefine it
241 239 if os.name == 'posix':
242 240 @line_magic
243 241 def man(self, arg_s):
244 242 """Find the man page for the given command and display in pager."""
245 243 page.page(self.shell.getoutput('man %s | col -b' % arg_s,
246 244 split=False))
247 245
248 246 @line_magic
249 247 def connect_info(self, arg_s):
250 248 """Print information for connecting other clients to this kernel
251 249
252 250 It will print the contents of this session's connection file, as well as
253 251 shortcuts for local clients.
254 252
255 253 In the simplest case, when called from the most recently launched kernel,
256 254 secondary clients can be connected, simply with:
257 255
258 256 $> ipython <app> --existing
259 257
260 258 """
261 259
262 260 from IPython.core.application import BaseIPythonApplication as BaseIPApp
263 261
264 262 if BaseIPApp.initialized():
265 263 app = BaseIPApp.instance()
266 264 security_dir = app.profile_dir.security_dir
267 265 profile = app.profile
268 266 else:
269 267 profile = 'default'
270 268 security_dir = ''
271 269
272 270 try:
273 271 connection_file = get_connection_file()
274 272 info = get_connection_info(unpack=False)
275 273 except Exception as e:
276 274 error("Could not get connection info: %r" % e)
277 275 return
278 276
279 277 # add profile flag for non-default profile
280 278 profile_flag = "--profile %s" % profile if profile != 'default' else ""
281 279
282 280 # if it's in the security dir, truncate to basename
283 281 if security_dir == os.path.dirname(connection_file):
284 282 connection_file = os.path.basename(connection_file)
285 283
286 284
287 285 print (info + '\n')
288 286 print ("Paste the above JSON into a file, and connect with:\n"
289 287 " $> ipython <app> --existing <file>\n"
290 288 "or, if you are local, you can connect with just:\n"
291 289 " $> ipython <app> --existing {0} {1}\n"
292 290 "or even just:\n"
293 291 " $> ipython <app> --existing {1}\n"
294 292 "if this is the most recent IPython session you have started.".format(
295 293 connection_file, profile_flag
296 294 )
297 295 )
298 296
299 297 @line_magic
300 298 def qtconsole(self, arg_s):
301 299 """Open a qtconsole connected to this kernel.
302 300
303 301 Useful for connecting a qtconsole to running notebooks, for better
304 302 debugging.
305 303 """
306 304
307 305 # %qtconsole should imply bind_kernel for engines:
308 306 try:
309 307 from IPython.parallel import bind_kernel
310 308 except ImportError:
311 309 # technically possible, because parallel has higher pyzmq min-version
312 310 pass
313 311 else:
314 312 bind_kernel()
315 313
316 314 try:
317 315 p = connect_qtconsole(argv=arg_split(arg_s, os.name=='posix'))
318 316 except Exception as e:
319 317 error("Could not start qtconsole: %r" % e)
320 318 return
321 319
322 320 @line_magic
323 321 def autosave(self, arg_s):
324 322 """Set the autosave interval in the notebook (in seconds).
325 323
326 324 The default value is 120, or two minutes.
327 325 ``%autosave 0`` will disable autosave.
328 326
329 327 This magic only has an effect when called from the notebook interface.
330 328 It has no effect when called in a startup file.
331 329 """
332 330
333 331 try:
334 332 interval = int(arg_s)
335 333 except ValueError:
336 334 raise UsageError("%%autosave requires an integer, got %r" % arg_s)
337 335
338 336 # javascript wants milliseconds
339 337 milliseconds = 1000 * interval
340 338 display(Javascript("IPython.notebook.set_autosave_interval(%i)" % milliseconds),
341 339 include=['application/javascript']
342 340 )
343 341 if interval:
344 342 print("Autosaving every %i seconds" % interval)
345 343 else:
346 344 print("Autosave disabled")
347 345
348 346
349 347 class ZMQInteractiveShell(InteractiveShell):
350 348 """A subclass of InteractiveShell for ZMQ."""
351 349
352 350 displayhook_class = Type(ZMQShellDisplayHook)
353 351 display_pub_class = Type(ZMQDisplayPublisher)
354 352 data_pub_class = Type(ZMQDataPublisher)
355 353 kernel = Any()
356 354 parent_header = Any()
357 355
358 356 def _banner1_default(self):
359 357 return default_gui_banner
360 358
361 359 # Override the traitlet in the parent class, because there's no point using
362 360 # readline for the kernel. Can be removed when the readline code is moved
363 361 # to the terminal frontend.
364 362 colors_force = CBool(True)
365 363 readline_use = CBool(False)
366 364 # autoindent has no meaning in a zmqshell, and attempting to enable it
367 365 # will print a warning in the absence of readline.
368 366 autoindent = CBool(False)
369 367
370 368 exiter = Instance(ZMQExitAutocall)
371 369 def _exiter_default(self):
372 370 return ZMQExitAutocall(self)
373 371
374 372 def _exit_now_changed(self, name, old, new):
375 373 """stop eventloop when exit_now fires"""
376 374 if new:
377 375 loop = ioloop.IOLoop.instance()
378 376 loop.add_timeout(time.time()+0.1, loop.stop)
379 377
380 378 keepkernel_on_exit = None
381 379
382 380 # Over ZeroMQ, GUI control isn't done with PyOS_InputHook as there is no
383 381 # interactive input being read; we provide event loop support in ipkernel
384 382 @staticmethod
385 383 def enable_gui(gui):
386 384 from .eventloops import enable_gui as real_enable_gui
387 385 try:
388 386 real_enable_gui(gui)
389 387 except ValueError as e:
390 388 raise UsageError("%s" % e)
391 389
392 390 def init_environment(self):
393 391 """Configure the user's environment."""
394 392 env = os.environ
395 393 # These two ensure 'ls' produces nice coloring on BSD-derived systems
396 394 env['TERM'] = 'xterm-color'
397 395 env['CLICOLOR'] = '1'
398 396 # Since normal pagers don't work at all (over pexpect we don't have
399 397 # single-key control of the subprocess), try to disable paging in
400 398 # subprocesses as much as possible.
401 399 env['PAGER'] = 'cat'
402 400 env['GIT_PAGER'] = 'cat'
403 401
404 402 def init_hooks(self):
405 403 super(ZMQInteractiveShell, self).init_hooks()
406 404 self.set_hook('show_in_pager', page.as_hook(payloadpage.page), 99)
407 405
408 406 def ask_exit(self):
409 407 """Engage the exit actions."""
410 408 self.exit_now = (not self.keepkernel_on_exit)
411 409 payload = dict(
412 410 source='ask_exit',
413 411 keepkernel=self.keepkernel_on_exit,
414 412 )
415 413 self.payload_manager.write_payload(payload)
416 414
417 415 def _showtraceback(self, etype, evalue, stb):
418 416 # try to preserve ordering of tracebacks and print statements
419 417 sys.stdout.flush()
420 418 sys.stderr.flush()
421 419
422 420 exc_content = {
423 421 u'traceback' : stb,
424 422 u'ename' : unicode_type(etype.__name__),
425 423 u'evalue' : py3compat.safe_unicode(evalue),
426 424 }
427 425
428 426 dh = self.displayhook
429 427 # Send exception info over pub socket for other clients than the caller
430 428 # to pick up
431 429 topic = None
432 430 if dh.topic:
433 431 topic = dh.topic.replace(b'execute_result', b'error')
434 432
435 433 exc_msg = dh.session.send(dh.pub_socket, u'error', json_clean(exc_content), dh.parent_header, ident=topic)
436 434
437 435 # FIXME - Hack: store exception info in shell object. Right now, the
438 436 # caller is reading this info after the fact, we need to fix this logic
439 437 # to remove this hack. Even uglier, we need to store the error status
440 438 # here, because in the main loop, the logic that sets it is being
441 439 # skipped because runlines swallows the exceptions.
442 440 exc_content[u'status'] = u'error'
443 441 self._reply_content = exc_content
444 442 # /FIXME
445 443
446 444 return exc_content
447 445
448 446 def set_next_input(self, text, replace=False):
449 447 """Send the specified text to the frontend to be presented at the next
450 448 input cell."""
451 449 payload = dict(
452 450 source='set_next_input',
453 451 text=text,
454 452 replace=replace,
455 453 )
456 454 self.payload_manager.write_payload(payload)
457 455
458 456 def set_parent(self, parent):
459 457 """Set the parent header for associating output with its triggering input"""
460 458 self.parent_header = parent
461 459 self.displayhook.set_parent(parent)
462 460 self.display_pub.set_parent(parent)
463 461 self.data_pub.set_parent(parent)
464 462 try:
465 463 sys.stdout.set_parent(parent)
466 464 except AttributeError:
467 465 pass
468 466 try:
469 467 sys.stderr.set_parent(parent)
470 468 except AttributeError:
471 469 pass
472 470
473 471 def get_parent(self):
474 472 return self.parent_header
475 473
476 474 #-------------------------------------------------------------------------
477 475 # Things related to magics
478 476 #-------------------------------------------------------------------------
479 477
480 478 def init_magics(self):
481 479 super(ZMQInteractiveShell, self).init_magics()
482 480 self.register_magics(KernelMagics)
483 481 self.magics_manager.register_alias('ed', 'edit')
484 482
485 483
486 484 InteractiveShellABC.register(ZMQInteractiveShell)
@@ -1,441 +1,436 b''
1 1 # encoding: utf-8
2 2 """
3 3 =============
4 4 parallelmagic
5 5 =============
6 6
7 7 Magic command interface for interactive parallel work.
8 8
9 9 Usage
10 10 =====
11 11
12 12 ``%autopx``
13 13
14 14 {AUTOPX_DOC}
15 15
16 16 ``%px``
17 17
18 18 {PX_DOC}
19 19
20 20 ``%pxresult``
21 21
22 22 {RESULT_DOC}
23 23
24 24 ``%pxconfig``
25 25
26 26 {CONFIG_DOC}
27 27
28 28 """
29 29 from __future__ import print_function
30 30
31 31 #-----------------------------------------------------------------------------
32 32 # Copyright (C) 2008 The IPython Development Team
33 33 #
34 34 # Distributed under the terms of the BSD License. The full license is in
35 35 # the file COPYING, distributed as part of this software.
36 36 #-----------------------------------------------------------------------------
37 37
38 38 #-----------------------------------------------------------------------------
39 39 # Imports
40 40 #-----------------------------------------------------------------------------
41 41
42 42 import ast
43 43 import re
44 44
45 45 from IPython.core.error import UsageError
46 46 from IPython.core.magic import Magics
47 47 from IPython.core import magic_arguments
48 from IPython.testing.skipdoctest import skip_doctest
49 48 from IPython.utils.text import dedent
50 49
51 50 #-----------------------------------------------------------------------------
52 51 # Definitions of magic functions for use with IPython
53 52 #-----------------------------------------------------------------------------
54 53
55 54
56 55 NO_LAST_RESULT = "%pxresult recalls last %px result, which has not yet been used."
57 56
58 57 def exec_args(f):
59 58 """decorator for adding block/targets args for execution
60 59
61 60 applied to %pxconfig and %%px
62 61 """
63 62 args = [
64 63 magic_arguments.argument('-b', '--block', action="store_const",
65 64 const=True, dest='block',
66 65 help="use blocking (sync) execution",
67 66 ),
68 67 magic_arguments.argument('-a', '--noblock', action="store_const",
69 68 const=False, dest='block',
70 69 help="use non-blocking (async) execution",
71 70 ),
72 71 magic_arguments.argument('-t', '--targets', type=str,
73 72 help="specify the targets on which to execute",
74 73 ),
75 74 magic_arguments.argument('--local', action="store_const",
76 75 const=True, dest="local",
77 76 help="also execute the cell in the local namespace",
78 77 ),
79 78 magic_arguments.argument('--verbose', action="store_const",
80 79 const=True, dest="set_verbose",
81 80 help="print a message at each execution",
82 81 ),
83 82 magic_arguments.argument('--no-verbose', action="store_const",
84 83 const=False, dest="set_verbose",
85 84 help="don't print any messages",
86 85 ),
87 86 ]
88 87 for a in args:
89 88 f = a(f)
90 89 return f
91 90
92 91 def output_args(f):
93 92 """decorator for output-formatting args
94 93
95 94 applied to %pxresult and %%px
96 95 """
97 96 args = [
98 97 magic_arguments.argument('-r', action="store_const", dest='groupby',
99 98 const='order',
100 99 help="collate outputs in order (same as group-outputs=order)"
101 100 ),
102 101 magic_arguments.argument('-e', action="store_const", dest='groupby',
103 102 const='engine',
104 103 help="group outputs by engine (same as group-outputs=engine)"
105 104 ),
106 105 magic_arguments.argument('--group-outputs', dest='groupby', type=str,
107 106 choices=['engine', 'order', 'type'], default='type',
108 107 help="""Group the outputs in a particular way.
109 108
110 109 Choices are:
111 110
112 111 **type**: group outputs of all engines by type (stdout, stderr, displaypub, etc.).
113 112 **engine**: display all output for each engine together.
114 113 **order**: like type, but individual displaypub output from each engine is collated.
115 114 For example, if multiple plots are generated by each engine, the first
116 115 figure of each engine will be displayed, then the second of each, etc.
117 116 """
118 117 ),
119 118 magic_arguments.argument('-o', '--out', dest='save_name', type=str,
120 119 help="""store the AsyncResult object for this computation
121 120 in the global namespace under this name.
122 121 """
123 122 ),
124 123 ]
125 124 for a in args:
126 125 f = a(f)
127 126 return f
128 127
129 128 class ParallelMagics(Magics):
130 129 """A set of magics useful when controlling a parallel IPython cluster.
131 130 """
132 131
133 132 # magic-related
134 133 magics = None
135 134 registered = True
136 135
137 136 # suffix for magics
138 137 suffix = ''
139 138 # A flag showing if autopx is activated or not
140 139 _autopx = False
141 140 # the current view used by the magics:
142 141 view = None
143 142 # last result cache for %pxresult
144 143 last_result = None
145 144 # verbose flag
146 145 verbose = False
147 146
148 147 def __init__(self, shell, view, suffix=''):
149 148 self.view = view
150 149 self.suffix = suffix
151 150
152 151 # register magics
153 152 self.magics = dict(cell={},line={})
154 153 line_magics = self.magics['line']
155 154
156 155 px = 'px' + suffix
157 156 if not suffix:
158 157 # keep %result for legacy compatibility
159 158 line_magics['result'] = self.result
160 159
161 160 line_magics['pxresult' + suffix] = self.result
162 161 line_magics[px] = self.px
163 162 line_magics['pxconfig' + suffix] = self.pxconfig
164 163 line_magics['auto' + px] = self.autopx
165 164
166 165 self.magics['cell'][px] = self.cell_px
167 166
168 167 super(ParallelMagics, self).__init__(shell=shell)
169 168
170 169 def _eval_target_str(self, ts):
171 170 if ':' in ts:
172 171 targets = eval("self.view.client.ids[%s]" % ts)
173 172 elif 'all' in ts:
174 173 targets = 'all'
175 174 else:
176 175 targets = eval(ts)
177 176 return targets
178 177
179 178 @magic_arguments.magic_arguments()
180 179 @exec_args
181 180 def pxconfig(self, line):
182 181 """configure default targets/blocking for %px magics"""
183 182 args = magic_arguments.parse_argstring(self.pxconfig, line)
184 183 if args.targets:
185 184 self.view.targets = self._eval_target_str(args.targets)
186 185 if args.block is not None:
187 186 self.view.block = args.block
188 187 if args.set_verbose is not None:
189 188 self.verbose = args.set_verbose
190 189
191 190 @magic_arguments.magic_arguments()
192 191 @output_args
193 @skip_doctest
194 192 def result(self, line=''):
195 193 """Print the result of the last asynchronous %px command.
196 194
197 195 This lets you recall the results of %px computations after
198 196 asynchronous submission (block=False).
199 197
200 198 Examples
201 199 --------
202 200 ::
203 201
204 202 In [23]: %px os.getpid()
205 203 Async parallel execution on engine(s): all
206 204
207 205 In [24]: %pxresult
208 206 Out[8:10]: 60920
209 207 Out[9:10]: 60921
210 208 Out[10:10]: 60922
211 209 Out[11:10]: 60923
212 210 """
213 211 args = magic_arguments.parse_argstring(self.result, line)
214 212
215 213 if self.last_result is None:
216 214 raise UsageError(NO_LAST_RESULT)
217 215
218 216 self.last_result.get()
219 217 self.last_result.display_outputs(groupby=args.groupby)
220 218
221 @skip_doctest
222 219 def px(self, line=''):
223 220 """Executes the given python command in parallel.
224 221
225 222 Examples
226 223 --------
227 224 ::
228 225
229 226 In [24]: %px a = os.getpid()
230 227 Parallel execution on engine(s): all
231 228
232 229 In [25]: %px print a
233 230 [stdout:0] 1234
234 231 [stdout:1] 1235
235 232 [stdout:2] 1236
236 233 [stdout:3] 1237
237 234 """
238 235 return self.parallel_execute(line)
239 236
240 237 def parallel_execute(self, cell, block=None, groupby='type', save_name=None):
241 238 """implementation used by %px and %%parallel"""
242 239
243 240 # defaults:
244 241 block = self.view.block if block is None else block
245 242
246 243 base = "Parallel" if block else "Async parallel"
247 244
248 245 targets = self.view.targets
249 246 if isinstance(targets, list) and len(targets) > 10:
250 247 str_targets = str(targets[:4])[:-1] + ', ..., ' + str(targets[-4:])[1:]
251 248 else:
252 249 str_targets = str(targets)
253 250 if self.verbose:
254 251 print(base + " execution on engine(s): %s" % str_targets)
255 252
256 253 result = self.view.execute(cell, silent=False, block=False)
257 254 self.last_result = result
258 255
259 256 if save_name:
260 257 self.shell.user_ns[save_name] = result
261 258
262 259 if block:
263 260 result.get()
264 261 result.display_outputs(groupby)
265 262 else:
266 263 # return AsyncResult only on non-blocking submission
267 264 return result
268 265
269 266 @magic_arguments.magic_arguments()
270 267 @exec_args
271 268 @output_args
272 @skip_doctest
273 269 def cell_px(self, line='', cell=None):
274 270 """Executes the cell in parallel.
275 271
276 272 Examples
277 273 --------
278 274 ::
279 275
280 276 In [24]: %%px --noblock
281 277 ....: a = os.getpid()
282 278 Async parallel execution on engine(s): all
283 279
284 280 In [25]: %%px
285 281 ....: print a
286 282 [stdout:0] 1234
287 283 [stdout:1] 1235
288 284 [stdout:2] 1236
289 285 [stdout:3] 1237
290 286 """
291 287
292 288 args = magic_arguments.parse_argstring(self.cell_px, line)
293 289
294 290 if args.targets:
295 291 save_targets = self.view.targets
296 292 self.view.targets = self._eval_target_str(args.targets)
297 293 # if running local, don't block until after local has run
298 294 block = False if args.local else args.block
299 295 try:
300 296 ar = self.parallel_execute(cell, block=block,
301 297 groupby=args.groupby,
302 298 save_name=args.save_name,
303 299 )
304 300 finally:
305 301 if args.targets:
306 302 self.view.targets = save_targets
307 303
308 304 # run locally after submitting remote
309 305 block = self.view.block if args.block is None else args.block
310 306 if args.local:
311 307 self.shell.run_cell(cell)
312 308 # now apply blocking behavor to remote execution
313 309 if block:
314 310 ar.get()
315 311 ar.display_outputs(args.groupby)
316 312 if not block:
317 313 return ar
318 314
319 @skip_doctest
320 315 def autopx(self, line=''):
321 316 """Toggles auto parallel mode.
322 317
323 318 Once this is called, all commands typed at the command line are send to
324 319 the engines to be executed in parallel. To control which engine are
325 320 used, the ``targets`` attribute of the view before
326 321 entering ``%autopx`` mode.
327 322
328 323
329 324 Then you can do the following::
330 325
331 326 In [25]: %autopx
332 327 %autopx to enabled
333 328
334 329 In [26]: a = 10
335 330 Parallel execution on engine(s): [0,1,2,3]
336 331 In [27]: print a
337 332 Parallel execution on engine(s): [0,1,2,3]
338 333 [stdout:0] 10
339 334 [stdout:1] 10
340 335 [stdout:2] 10
341 336 [stdout:3] 10
342 337
343 338
344 339 In [27]: %autopx
345 340 %autopx disabled
346 341 """
347 342 if self._autopx:
348 343 self._disable_autopx()
349 344 else:
350 345 self._enable_autopx()
351 346
352 347 def _enable_autopx(self):
353 348 """Enable %autopx mode by saving the original run_cell and installing
354 349 pxrun_cell.
355 350 """
356 351 # override run_cell
357 352 self._original_run_cell = self.shell.run_cell
358 353 self.shell.run_cell = self.pxrun_cell
359 354
360 355 self._autopx = True
361 356 print("%autopx enabled")
362 357
363 358 def _disable_autopx(self):
364 359 """Disable %autopx by restoring the original InteractiveShell.run_cell.
365 360 """
366 361 if self._autopx:
367 362 self.shell.run_cell = self._original_run_cell
368 363 self._autopx = False
369 364 print("%autopx disabled")
370 365
371 366 def pxrun_cell(self, raw_cell, store_history=False, silent=False):
372 367 """drop-in replacement for InteractiveShell.run_cell.
373 368
374 369 This executes code remotely, instead of in the local namespace.
375 370
376 371 See InteractiveShell.run_cell for details.
377 372 """
378 373
379 374 if (not raw_cell) or raw_cell.isspace():
380 375 return
381 376
382 377 ipself = self.shell
383 378
384 379 with ipself.builtin_trap:
385 380 cell = ipself.prefilter_manager.prefilter_lines(raw_cell)
386 381
387 382 # Store raw and processed history
388 383 if store_history:
389 384 ipself.history_manager.store_inputs(ipself.execution_count,
390 385 cell, raw_cell)
391 386
392 387 # ipself.logger.log(cell, raw_cell)
393 388
394 389 cell_name = ipself.compile.cache(cell, ipself.execution_count)
395 390
396 391 try:
397 392 ast.parse(cell, filename=cell_name)
398 393 except (OverflowError, SyntaxError, ValueError, TypeError,
399 394 MemoryError):
400 395 # Case 1
401 396 ipself.showsyntaxerror()
402 397 ipself.execution_count += 1
403 398 return None
404 399 except NameError:
405 400 # ignore name errors, because we don't know the remote keys
406 401 pass
407 402
408 403 if store_history:
409 404 # Write output to the database. Does nothing unless
410 405 # history output logging is enabled.
411 406 ipself.history_manager.store_output(ipself.execution_count)
412 407 # Each cell is a *single* input, regardless of how many lines it has
413 408 ipself.execution_count += 1
414 409 if re.search(r'get_ipython\(\)\.magic\(u?["\']%?autopx', cell):
415 410 self._disable_autopx()
416 411 return False
417 412 else:
418 413 try:
419 414 result = self.view.execute(cell, silent=False, block=False)
420 415 except:
421 416 ipself.showtraceback()
422 417 return True
423 418 else:
424 419 if self.view.block:
425 420 try:
426 421 result.get()
427 422 except:
428 423 self.shell.showtraceback()
429 424 return True
430 425 else:
431 426 with ipself.builtin_trap:
432 427 result.display_outputs()
433 428 return False
434 429
435 430
436 431 __doc__ = __doc__.format(
437 432 AUTOPX_DOC = dedent(ParallelMagics.autopx.__doc__),
438 433 PX_DOC = dedent(ParallelMagics.px.__doc__),
439 434 RESULT_DOC = dedent(ParallelMagics.result.__doc__),
440 435 CONFIG_DOC = dedent(ParallelMagics.pxconfig.__doc__),
441 436 )
@@ -1,276 +1,273 b''
1 1 """Remote Functions and decorators for Views."""
2 2
3 3 # Copyright (c) IPython Development Team.
4 4 # Distributed under the terms of the Modified BSD License.
5 5
6 6 from __future__ import division
7 7
8 8 import sys
9 9 import warnings
10 10
11 11 from decorator import decorator
12 from IPython.testing.skipdoctest import skip_doctest
13 12
14 13 from . import map as Map
15 14 from .asyncresult import AsyncMapResult
16 15
17 16 #-----------------------------------------------------------------------------
18 17 # Functions and Decorators
19 18 #-----------------------------------------------------------------------------
20 19
21 @skip_doctest
22 20 def remote(view, block=None, **flags):
23 21 """Turn a function into a remote function.
24 22
25 23 This method can be used for map:
26 24
27 25 In [1]: @remote(view,block=True)
28 26 ...: def func(a):
29 27 ...: pass
30 28 """
31 29
32 30 def remote_function(f):
33 31 return RemoteFunction(view, f, block=block, **flags)
34 32 return remote_function
35 33
36 @skip_doctest
37 34 def parallel(view, dist='b', block=None, ordered=True, **flags):
38 35 """Turn a function into a parallel remote function.
39 36
40 37 This method can be used for map:
41 38
42 39 In [1]: @parallel(view, block=True)
43 40 ...: def func(a):
44 41 ...: pass
45 42 """
46 43
47 44 def parallel_function(f):
48 45 return ParallelFunction(view, f, dist=dist, block=block, ordered=ordered, **flags)
49 46 return parallel_function
50 47
51 48 def getname(f):
52 49 """Get the name of an object.
53 50
54 51 For use in case of callables that are not functions, and
55 52 thus may not have __name__ defined.
56 53
57 54 Order: f.__name__ > f.name > str(f)
58 55 """
59 56 try:
60 57 return f.__name__
61 58 except:
62 59 pass
63 60 try:
64 61 return f.name
65 62 except:
66 63 pass
67 64
68 65 return str(f)
69 66
70 67 @decorator
71 68 def sync_view_results(f, self, *args, **kwargs):
72 69 """sync relevant results from self.client to our results attribute.
73 70
74 71 This is a clone of view.sync_results, but for remote functions
75 72 """
76 73 view = self.view
77 74 if view._in_sync_results:
78 75 return f(self, *args, **kwargs)
79 76 view._in_sync_results = True
80 77 try:
81 78 ret = f(self, *args, **kwargs)
82 79 finally:
83 80 view._in_sync_results = False
84 81 view._sync_results()
85 82 return ret
86 83
87 84 #--------------------------------------------------------------------------
88 85 # Classes
89 86 #--------------------------------------------------------------------------
90 87
91 88 class RemoteFunction(object):
92 89 """Turn an existing function into a remote function.
93 90
94 91 Parameters
95 92 ----------
96 93
97 94 view : View instance
98 95 The view to be used for execution
99 96 f : callable
100 97 The function to be wrapped into a remote function
101 98 block : bool [default: None]
102 99 Whether to wait for results or not. The default behavior is
103 100 to use the current `block` attribute of `view`
104 101
105 102 **flags : remaining kwargs are passed to View.temp_flags
106 103 """
107 104
108 105 view = None # the remote connection
109 106 func = None # the wrapped function
110 107 block = None # whether to block
111 108 flags = None # dict of extra kwargs for temp_flags
112 109
113 110 def __init__(self, view, f, block=None, **flags):
114 111 self.view = view
115 112 self.func = f
116 113 self.block=block
117 114 self.flags=flags
118 115
119 116 def __call__(self, *args, **kwargs):
120 117 block = self.view.block if self.block is None else self.block
121 118 with self.view.temp_flags(block=block, **self.flags):
122 119 return self.view.apply(self.func, *args, **kwargs)
123 120
124 121
125 122 class ParallelFunction(RemoteFunction):
126 123 """Class for mapping a function to sequences.
127 124
128 125 This will distribute the sequences according the a mapper, and call
129 126 the function on each sub-sequence. If called via map, then the function
130 127 will be called once on each element, rather that each sub-sequence.
131 128
132 129 Parameters
133 130 ----------
134 131
135 132 view : View instance
136 133 The view to be used for execution
137 134 f : callable
138 135 The function to be wrapped into a remote function
139 136 dist : str [default: 'b']
140 137 The key for which mapObject to use to distribute sequences
141 138 options are:
142 139
143 140 * 'b' : use contiguous chunks in order
144 141 * 'r' : use round-robin striping
145 142
146 143 block : bool [default: None]
147 144 Whether to wait for results or not. The default behavior is
148 145 to use the current `block` attribute of `view`
149 146 chunksize : int or None
150 147 The size of chunk to use when breaking up sequences in a load-balanced manner
151 148 ordered : bool [default: True]
152 149 Whether the result should be kept in order. If False,
153 150 results become available as they arrive, regardless of submission order.
154 151 **flags
155 152 remaining kwargs are passed to View.temp_flags
156 153 """
157 154
158 155 chunksize = None
159 156 ordered = None
160 157 mapObject = None
161 158 _mapping = False
162 159
163 160 def __init__(self, view, f, dist='b', block=None, chunksize=None, ordered=True, **flags):
164 161 super(ParallelFunction, self).__init__(view, f, block=block, **flags)
165 162 self.chunksize = chunksize
166 163 self.ordered = ordered
167 164
168 165 mapClass = Map.dists[dist]
169 166 self.mapObject = mapClass()
170 167
171 168 @sync_view_results
172 169 def __call__(self, *sequences):
173 170 client = self.view.client
174 171
175 172 lens = []
176 173 maxlen = minlen = -1
177 174 for i, seq in enumerate(sequences):
178 175 try:
179 176 n = len(seq)
180 177 except Exception:
181 178 seq = list(seq)
182 179 if isinstance(sequences, tuple):
183 180 # can't alter a tuple
184 181 sequences = list(sequences)
185 182 sequences[i] = seq
186 183 n = len(seq)
187 184 if n > maxlen:
188 185 maxlen = n
189 186 if minlen == -1 or n < minlen:
190 187 minlen = n
191 188 lens.append(n)
192 189
193 190 if maxlen == 0:
194 191 # nothing to iterate over
195 192 return []
196 193
197 194 # check that the length of sequences match
198 195 if not self._mapping and minlen != maxlen:
199 196 msg = 'all sequences must have equal length, but have %s' % lens
200 197 raise ValueError(msg)
201 198
202 199 balanced = 'Balanced' in self.view.__class__.__name__
203 200 if balanced:
204 201 if self.chunksize:
205 202 nparts = maxlen // self.chunksize + int(maxlen % self.chunksize > 0)
206 203 else:
207 204 nparts = maxlen
208 205 targets = [None]*nparts
209 206 else:
210 207 if self.chunksize:
211 208 warnings.warn("`chunksize` is ignored unless load balancing", UserWarning)
212 209 # multiplexed:
213 210 targets = self.view.targets
214 211 # 'all' is lazily evaluated at execution time, which is now:
215 212 if targets == 'all':
216 213 targets = client._build_targets(targets)[1]
217 214 elif isinstance(targets, int):
218 215 # single-engine view, targets must be iterable
219 216 targets = [targets]
220 217 nparts = len(targets)
221 218
222 219 msg_ids = []
223 220 for index, t in enumerate(targets):
224 221 args = []
225 222 for seq in sequences:
226 223 part = self.mapObject.getPartition(seq, index, nparts, maxlen)
227 224 args.append(part)
228 225
229 226 if sum([len(arg) for arg in args]) == 0:
230 227 continue
231 228
232 229 if self._mapping:
233 230 if sys.version_info[0] >= 3:
234 231 f = lambda f, *sequences: list(map(f, *sequences))
235 232 else:
236 233 f = map
237 234 args = [self.func] + args
238 235 else:
239 236 f=self.func
240 237
241 238 view = self.view if balanced else client[t]
242 239 with view.temp_flags(block=False, **self.flags):
243 240 ar = view.apply(f, *args)
244 241
245 242 msg_ids.extend(ar.msg_ids)
246 243
247 244 r = AsyncMapResult(self.view.client, msg_ids, self.mapObject,
248 245 fname=getname(self.func),
249 246 ordered=self.ordered
250 247 )
251 248
252 249 if self.block:
253 250 try:
254 251 return r.get()
255 252 except KeyboardInterrupt:
256 253 return r
257 254 else:
258 255 return r
259 256
260 257 def map(self, *sequences):
261 258 """call a function on each element of one or more sequence(s) remotely.
262 259 This should behave very much like the builtin map, but return an AsyncMapResult
263 260 if self.block is False.
264 261
265 262 That means it can take generators (will be cast to lists locally),
266 263 and mismatched sequence lengths will be padded with None.
267 264 """
268 265 # set _mapping as a flag for use inside self.__call__
269 266 self._mapping = True
270 267 try:
271 268 ret = self(*sequences)
272 269 finally:
273 270 self._mapping = False
274 271 return ret
275 272
276 273 __all__ = ['remote', 'parallel', 'RemoteFunction', 'ParallelFunction']
@@ -1,1125 +1,1121 b''
1 1 """Views of remote engines."""
2 2
3 3 # Copyright (c) IPython Development Team.
4 4 # Distributed under the terms of the Modified BSD License.
5 5
6 6 from __future__ import print_function
7 7
8 8 import imp
9 9 import sys
10 10 import warnings
11 11 from contextlib import contextmanager
12 12 from types import ModuleType
13 13
14 14 import zmq
15 15
16 from IPython.testing.skipdoctest import skip_doctest
17 16 from IPython.utils import pickleutil
18 17 from IPython.utils.traitlets import (
19 18 HasTraits, Any, Bool, List, Dict, Set, Instance, CFloat, Integer
20 19 )
21 20 from decorator import decorator
22 21
23 22 from ipython_parallel import util
24 23 from ipython_parallel.controller.dependency import Dependency, dependent
25 24 from IPython.utils.py3compat import string_types, iteritems, PY3
26 25
27 26 from . import map as Map
28 27 from .asyncresult import AsyncResult, AsyncMapResult
29 28 from .remotefunction import ParallelFunction, parallel, remote, getname
30 29
31 30 #-----------------------------------------------------------------------------
32 31 # Decorators
33 32 #-----------------------------------------------------------------------------
34 33
35 34 @decorator
36 35 def save_ids(f, self, *args, **kwargs):
37 36 """Keep our history and outstanding attributes up to date after a method call."""
38 37 n_previous = len(self.client.history)
39 38 try:
40 39 ret = f(self, *args, **kwargs)
41 40 finally:
42 41 nmsgs = len(self.client.history) - n_previous
43 42 msg_ids = self.client.history[-nmsgs:]
44 43 self.history.extend(msg_ids)
45 44 self.outstanding.update(msg_ids)
46 45 return ret
47 46
48 47 @decorator
49 48 def sync_results(f, self, *args, **kwargs):
50 49 """sync relevant results from self.client to our results attribute."""
51 50 if self._in_sync_results:
52 51 return f(self, *args, **kwargs)
53 52 self._in_sync_results = True
54 53 try:
55 54 ret = f(self, *args, **kwargs)
56 55 finally:
57 56 self._in_sync_results = False
58 57 self._sync_results()
59 58 return ret
60 59
61 60 @decorator
62 61 def spin_after(f, self, *args, **kwargs):
63 62 """call spin after the method."""
64 63 ret = f(self, *args, **kwargs)
65 64 self.spin()
66 65 return ret
67 66
68 67 #-----------------------------------------------------------------------------
69 68 # Classes
70 69 #-----------------------------------------------------------------------------
71 70
72 @skip_doctest
73 71 class View(HasTraits):
74 72 """Base View class for more convenint apply(f,*args,**kwargs) syntax via attributes.
75 73
76 74 Don't use this class, use subclasses.
77 75
78 76 Methods
79 77 -------
80 78
81 79 spin
82 80 flushes incoming results and registration state changes
83 81 control methods spin, and requesting `ids` also ensures up to date
84 82
85 83 wait
86 84 wait on one or more msg_ids
87 85
88 86 execution methods
89 87 apply
90 88 legacy: execute, run
91 89
92 90 data movement
93 91 push, pull, scatter, gather
94 92
95 93 query methods
96 94 get_result, queue_status, purge_results, result_status
97 95
98 96 control methods
99 97 abort, shutdown
100 98
101 99 """
102 100 # flags
103 101 block=Bool(False)
104 102 track=Bool(True)
105 103 targets = Any()
106 104
107 105 history=List()
108 106 outstanding = Set()
109 107 results = Dict()
110 108 client = Instance('ipython_parallel.Client', allow_none=True)
111 109
112 110 _socket = Instance('zmq.Socket', allow_none=True)
113 111 _flag_names = List(['targets', 'block', 'track'])
114 112 _in_sync_results = Bool(False)
115 113 _targets = Any()
116 114 _idents = Any()
117 115
118 116 def __init__(self, client=None, socket=None, **flags):
119 117 super(View, self).__init__(client=client, _socket=socket)
120 118 self.results = client.results
121 119 self.block = client.block
122 120
123 121 self.set_flags(**flags)
124 122
125 123 assert not self.__class__ is View, "Don't use base View objects, use subclasses"
126 124
127 125 def __repr__(self):
128 126 strtargets = str(self.targets)
129 127 if len(strtargets) > 16:
130 128 strtargets = strtargets[:12]+'...]'
131 129 return "<%s %s>"%(self.__class__.__name__, strtargets)
132 130
133 131 def __len__(self):
134 132 if isinstance(self.targets, list):
135 133 return len(self.targets)
136 134 elif isinstance(self.targets, int):
137 135 return 1
138 136 else:
139 137 return len(self.client)
140 138
141 139 def set_flags(self, **kwargs):
142 140 """set my attribute flags by keyword.
143 141
144 142 Views determine behavior with a few attributes (`block`, `track`, etc.).
145 143 These attributes can be set all at once by name with this method.
146 144
147 145 Parameters
148 146 ----------
149 147
150 148 block : bool
151 149 whether to wait for results
152 150 track : bool
153 151 whether to create a MessageTracker to allow the user to
154 152 safely edit after arrays and buffers during non-copying
155 153 sends.
156 154 """
157 155 for name, value in iteritems(kwargs):
158 156 if name not in self._flag_names:
159 157 raise KeyError("Invalid name: %r"%name)
160 158 else:
161 159 setattr(self, name, value)
162 160
163 161 @contextmanager
164 162 def temp_flags(self, **kwargs):
165 163 """temporarily set flags, for use in `with` statements.
166 164
167 165 See set_flags for permanent setting of flags
168 166
169 167 Examples
170 168 --------
171 169
172 170 >>> view.track=False
173 171 ...
174 172 >>> with view.temp_flags(track=True):
175 173 ... ar = view.apply(dostuff, my_big_array)
176 174 ... ar.tracker.wait() # wait for send to finish
177 175 >>> view.track
178 176 False
179 177
180 178 """
181 179 # preflight: save flags, and set temporaries
182 180 saved_flags = {}
183 181 for f in self._flag_names:
184 182 saved_flags[f] = getattr(self, f)
185 183 self.set_flags(**kwargs)
186 184 # yield to the with-statement block
187 185 try:
188 186 yield
189 187 finally:
190 188 # postflight: restore saved flags
191 189 self.set_flags(**saved_flags)
192 190
193 191
194 192 #----------------------------------------------------------------
195 193 # apply
196 194 #----------------------------------------------------------------
197 195
198 196 def _sync_results(self):
199 197 """to be called by @sync_results decorator
200 198
201 199 after submitting any tasks.
202 200 """
203 201 delta = self.outstanding.difference(self.client.outstanding)
204 202 completed = self.outstanding.intersection(delta)
205 203 self.outstanding = self.outstanding.difference(completed)
206 204
207 205 @sync_results
208 206 @save_ids
209 207 def _really_apply(self, f, args, kwargs, block=None, **options):
210 208 """wrapper for client.send_apply_request"""
211 209 raise NotImplementedError("Implement in subclasses")
212 210
213 211 def apply(self, f, *args, **kwargs):
214 212 """calls ``f(*args, **kwargs)`` on remote engines, returning the result.
215 213
216 214 This method sets all apply flags via this View's attributes.
217 215
218 216 Returns :class:`~IPython.parallel.client.asyncresult.AsyncResult`
219 217 instance if ``self.block`` is False, otherwise the return value of
220 218 ``f(*args, **kwargs)``.
221 219 """
222 220 return self._really_apply(f, args, kwargs)
223 221
224 222 def apply_async(self, f, *args, **kwargs):
225 223 """calls ``f(*args, **kwargs)`` on remote engines in a nonblocking manner.
226 224
227 225 Returns :class:`~IPython.parallel.client.asyncresult.AsyncResult` instance.
228 226 """
229 227 return self._really_apply(f, args, kwargs, block=False)
230 228
231 229 @spin_after
232 230 def apply_sync(self, f, *args, **kwargs):
233 231 """calls ``f(*args, **kwargs)`` on remote engines in a blocking manner,
234 232 returning the result.
235 233 """
236 234 return self._really_apply(f, args, kwargs, block=True)
237 235
238 236 #----------------------------------------------------------------
239 237 # wrappers for client and control methods
240 238 #----------------------------------------------------------------
241 239 @sync_results
242 240 def spin(self):
243 241 """spin the client, and sync"""
244 242 self.client.spin()
245 243
246 244 @sync_results
247 245 def wait(self, jobs=None, timeout=-1):
248 246 """waits on one or more `jobs`, for up to `timeout` seconds.
249 247
250 248 Parameters
251 249 ----------
252 250
253 251 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
254 252 ints are indices to self.history
255 253 strs are msg_ids
256 254 default: wait on all outstanding messages
257 255 timeout : float
258 256 a time in seconds, after which to give up.
259 257 default is -1, which means no timeout
260 258
261 259 Returns
262 260 -------
263 261
264 262 True : when all msg_ids are done
265 263 False : timeout reached, some msg_ids still outstanding
266 264 """
267 265 if jobs is None:
268 266 jobs = self.history
269 267 return self.client.wait(jobs, timeout)
270 268
271 269 def abort(self, jobs=None, targets=None, block=None):
272 270 """Abort jobs on my engines.
273 271
274 272 Parameters
275 273 ----------
276 274
277 275 jobs : None, str, list of strs, optional
278 276 if None: abort all jobs.
279 277 else: abort specific msg_id(s).
280 278 """
281 279 block = block if block is not None else self.block
282 280 targets = targets if targets is not None else self.targets
283 281 jobs = jobs if jobs is not None else list(self.outstanding)
284 282
285 283 return self.client.abort(jobs=jobs, targets=targets, block=block)
286 284
287 285 def queue_status(self, targets=None, verbose=False):
288 286 """Fetch the Queue status of my engines"""
289 287 targets = targets if targets is not None else self.targets
290 288 return self.client.queue_status(targets=targets, verbose=verbose)
291 289
292 290 def purge_results(self, jobs=[], targets=[]):
293 291 """Instruct the controller to forget specific results."""
294 292 if targets is None or targets == 'all':
295 293 targets = self.targets
296 294 return self.client.purge_results(jobs=jobs, targets=targets)
297 295
298 296 def shutdown(self, targets=None, restart=False, hub=False, block=None):
299 297 """Terminates one or more engine processes, optionally including the hub.
300 298 """
301 299 block = self.block if block is None else block
302 300 if targets is None or targets == 'all':
303 301 targets = self.targets
304 302 return self.client.shutdown(targets=targets, restart=restart, hub=hub, block=block)
305 303
306 304 @spin_after
307 305 def get_result(self, indices_or_msg_ids=None, block=None, owner=True):
308 306 """return one or more results, specified by history index or msg_id.
309 307
310 308 See :meth:`IPython.parallel.client.client.Client.get_result` for details.
311 309 """
312 310
313 311 if indices_or_msg_ids is None:
314 312 indices_or_msg_ids = -1
315 313 if isinstance(indices_or_msg_ids, int):
316 314 indices_or_msg_ids = self.history[indices_or_msg_ids]
317 315 elif isinstance(indices_or_msg_ids, (list,tuple,set)):
318 316 indices_or_msg_ids = list(indices_or_msg_ids)
319 317 for i,index in enumerate(indices_or_msg_ids):
320 318 if isinstance(index, int):
321 319 indices_or_msg_ids[i] = self.history[index]
322 320 return self.client.get_result(indices_or_msg_ids, block=block, owner=owner)
323 321
324 322 #-------------------------------------------------------------------
325 323 # Map
326 324 #-------------------------------------------------------------------
327 325
328 326 @sync_results
329 327 def map(self, f, *sequences, **kwargs):
330 328 """override in subclasses"""
331 329 raise NotImplementedError
332 330
333 331 def map_async(self, f, *sequences, **kwargs):
334 332 """Parallel version of builtin :func:`python:map`, using this view's engines.
335 333
336 334 This is equivalent to ``map(...block=False)``.
337 335
338 336 See `self.map` for details.
339 337 """
340 338 if 'block' in kwargs:
341 339 raise TypeError("map_async doesn't take a `block` keyword argument.")
342 340 kwargs['block'] = False
343 341 return self.map(f,*sequences,**kwargs)
344 342
345 343 def map_sync(self, f, *sequences, **kwargs):
346 344 """Parallel version of builtin :func:`python:map`, using this view's engines.
347 345
348 346 This is equivalent to ``map(...block=True)``.
349 347
350 348 See `self.map` for details.
351 349 """
352 350 if 'block' in kwargs:
353 351 raise TypeError("map_sync doesn't take a `block` keyword argument.")
354 352 kwargs['block'] = True
355 353 return self.map(f,*sequences,**kwargs)
356 354
357 355 def imap(self, f, *sequences, **kwargs):
358 356 """Parallel version of :func:`itertools.imap`.
359 357
360 358 See `self.map` for details.
361 359
362 360 """
363 361
364 362 return iter(self.map_async(f,*sequences, **kwargs))
365 363
366 364 #-------------------------------------------------------------------
367 365 # Decorators
368 366 #-------------------------------------------------------------------
369 367
370 368 def remote(self, block=None, **flags):
371 369 """Decorator for making a RemoteFunction"""
372 370 block = self.block if block is None else block
373 371 return remote(self, block=block, **flags)
374 372
375 373 def parallel(self, dist='b', block=None, **flags):
376 374 """Decorator for making a ParallelFunction"""
377 375 block = self.block if block is None else block
378 376 return parallel(self, dist=dist, block=block, **flags)
379 377
380 @skip_doctest
381 378 class DirectView(View):
382 379 """Direct Multiplexer View of one or more engines.
383 380
384 381 These are created via indexed access to a client:
385 382
386 383 >>> dv_1 = client[1]
387 384 >>> dv_all = client[:]
388 385 >>> dv_even = client[::2]
389 386 >>> dv_some = client[1:3]
390 387
391 388 This object provides dictionary access to engine namespaces:
392 389
393 390 # push a=5:
394 391 >>> dv['a'] = 5
395 392 # pull 'foo':
396 393 >>> dv['foo']
397 394
398 395 """
399 396
400 397 def __init__(self, client=None, socket=None, targets=None):
401 398 super(DirectView, self).__init__(client=client, socket=socket, targets=targets)
402 399
403 400 @property
404 401 def importer(self):
405 402 """sync_imports(local=True) as a property.
406 403
407 404 See sync_imports for details.
408 405
409 406 """
410 407 return self.sync_imports(True)
411 408
412 409 @contextmanager
413 410 def sync_imports(self, local=True, quiet=False):
414 411 """Context Manager for performing simultaneous local and remote imports.
415 412
416 413 'import x as y' will *not* work. The 'as y' part will simply be ignored.
417 414
418 415 If `local=True`, then the package will also be imported locally.
419 416
420 417 If `quiet=True`, no output will be produced when attempting remote
421 418 imports.
422 419
423 420 Note that remote-only (`local=False`) imports have not been implemented.
424 421
425 422 >>> with view.sync_imports():
426 423 ... from numpy import recarray
427 424 importing recarray from numpy on engine(s)
428 425
429 426 """
430 427 from IPython.utils.py3compat import builtin_mod
431 428 local_import = builtin_mod.__import__
432 429 modules = set()
433 430 results = []
434 431 @util.interactive
435 432 def remote_import(name, fromlist, level):
436 433 """the function to be passed to apply, that actually performs the import
437 434 on the engine, and loads up the user namespace.
438 435 """
439 436 import sys
440 437 user_ns = globals()
441 438 mod = __import__(name, fromlist=fromlist, level=level)
442 439 if fromlist:
443 440 for key in fromlist:
444 441 user_ns[key] = getattr(mod, key)
445 442 else:
446 443 user_ns[name] = sys.modules[name]
447 444
448 445 def view_import(name, globals={}, locals={}, fromlist=[], level=0):
449 446 """the drop-in replacement for __import__, that optionally imports
450 447 locally as well.
451 448 """
452 449 # don't override nested imports
453 450 save_import = builtin_mod.__import__
454 451 builtin_mod.__import__ = local_import
455 452
456 453 if imp.lock_held():
457 454 # this is a side-effect import, don't do it remotely, or even
458 455 # ignore the local effects
459 456 return local_import(name, globals, locals, fromlist, level)
460 457
461 458 imp.acquire_lock()
462 459 if local:
463 460 mod = local_import(name, globals, locals, fromlist, level)
464 461 else:
465 462 raise NotImplementedError("remote-only imports not yet implemented")
466 463 imp.release_lock()
467 464
468 465 key = name+':'+','.join(fromlist or [])
469 466 if level <= 0 and key not in modules:
470 467 modules.add(key)
471 468 if not quiet:
472 469 if fromlist:
473 470 print("importing %s from %s on engine(s)"%(','.join(fromlist), name))
474 471 else:
475 472 print("importing %s on engine(s)"%name)
476 473 results.append(self.apply_async(remote_import, name, fromlist, level))
477 474 # restore override
478 475 builtin_mod.__import__ = save_import
479 476
480 477 return mod
481 478
482 479 # override __import__
483 480 builtin_mod.__import__ = view_import
484 481 try:
485 482 # enter the block
486 483 yield
487 484 except ImportError:
488 485 if local:
489 486 raise
490 487 else:
491 488 # ignore import errors if not doing local imports
492 489 pass
493 490 finally:
494 491 # always restore __import__
495 492 builtin_mod.__import__ = local_import
496 493
497 494 for r in results:
498 495 # raise possible remote ImportErrors here
499 496 r.get()
500 497
501 498 def use_dill(self):
502 499 """Expand serialization support with dill
503 500
504 501 adds support for closures, etc.
505 502
506 503 This calls ipython_kernel.pickleutil.use_dill() here and on each engine.
507 504 """
508 505 pickleutil.use_dill()
509 506 return self.apply(pickleutil.use_dill)
510 507
511 508 def use_cloudpickle(self):
512 509 """Expand serialization support with cloudpickle.
513 510 """
514 511 pickleutil.use_cloudpickle()
515 512 return self.apply(pickleutil.use_cloudpickle)
516 513
517 514
518 515 @sync_results
519 516 @save_ids
520 517 def _really_apply(self, f, args=None, kwargs=None, targets=None, block=None, track=None):
521 518 """calls f(*args, **kwargs) on remote engines, returning the result.
522 519
523 520 This method sets all of `apply`'s flags via this View's attributes.
524 521
525 522 Parameters
526 523 ----------
527 524
528 525 f : callable
529 526
530 527 args : list [default: empty]
531 528
532 529 kwargs : dict [default: empty]
533 530
534 531 targets : target list [default: self.targets]
535 532 where to run
536 533 block : bool [default: self.block]
537 534 whether to block
538 535 track : bool [default: self.track]
539 536 whether to ask zmq to track the message, for safe non-copying sends
540 537
541 538 Returns
542 539 -------
543 540
544 541 if self.block is False:
545 542 returns AsyncResult
546 543 else:
547 544 returns actual result of f(*args, **kwargs) on the engine(s)
548 545 This will be a list of self.targets is also a list (even length 1), or
549 546 the single result if self.targets is an integer engine id
550 547 """
551 548 args = [] if args is None else args
552 549 kwargs = {} if kwargs is None else kwargs
553 550 block = self.block if block is None else block
554 551 track = self.track if track is None else track
555 552 targets = self.targets if targets is None else targets
556 553
557 554 _idents, _targets = self.client._build_targets(targets)
558 555 msg_ids = []
559 556 trackers = []
560 557 for ident in _idents:
561 558 msg = self.client.send_apply_request(self._socket, f, args, kwargs, track=track,
562 559 ident=ident)
563 560 if track:
564 561 trackers.append(msg['tracker'])
565 562 msg_ids.append(msg['header']['msg_id'])
566 563 if isinstance(targets, int):
567 564 msg_ids = msg_ids[0]
568 565 tracker = None if track is False else zmq.MessageTracker(*trackers)
569 566 ar = AsyncResult(self.client, msg_ids, fname=getname(f), targets=_targets,
570 567 tracker=tracker, owner=True,
571 568 )
572 569 if block:
573 570 try:
574 571 return ar.get()
575 572 except KeyboardInterrupt:
576 573 pass
577 574 return ar
578 575
579 576
580 577 @sync_results
581 578 def map(self, f, *sequences, **kwargs):
582 579 """``view.map(f, *sequences, block=self.block)`` => list|AsyncMapResult
583 580
584 581 Parallel version of builtin `map`, using this View's `targets`.
585 582
586 583 There will be one task per target, so work will be chunked
587 584 if the sequences are longer than `targets`.
588 585
589 586 Results can be iterated as they are ready, but will become available in chunks.
590 587
591 588 Parameters
592 589 ----------
593 590
594 591 f : callable
595 592 function to be mapped
596 593 *sequences: one or more sequences of matching length
597 594 the sequences to be distributed and passed to `f`
598 595 block : bool
599 596 whether to wait for the result or not [default self.block]
600 597
601 598 Returns
602 599 -------
603 600
604 601
605 602 If block=False
606 603 An :class:`~ipython_parallel.client.asyncresult.AsyncMapResult` instance.
607 604 An object like AsyncResult, but which reassembles the sequence of results
608 605 into a single list. AsyncMapResults can be iterated through before all
609 606 results are complete.
610 607 else
611 608 A list, the result of ``map(f,*sequences)``
612 609 """
613 610
614 611 block = kwargs.pop('block', self.block)
615 612 for k in kwargs.keys():
616 613 if k not in ['block', 'track']:
617 614 raise TypeError("invalid keyword arg, %r"%k)
618 615
619 616 assert len(sequences) > 0, "must have some sequences to map onto!"
620 617 pf = ParallelFunction(self, f, block=block, **kwargs)
621 618 return pf.map(*sequences)
622 619
623 620 @sync_results
624 621 @save_ids
625 622 def execute(self, code, silent=True, targets=None, block=None):
626 623 """Executes `code` on `targets` in blocking or nonblocking manner.
627 624
628 625 ``execute`` is always `bound` (affects engine namespace)
629 626
630 627 Parameters
631 628 ----------
632 629
633 630 code : str
634 631 the code string to be executed
635 632 block : bool
636 633 whether or not to wait until done to return
637 634 default: self.block
638 635 """
639 636 block = self.block if block is None else block
640 637 targets = self.targets if targets is None else targets
641 638
642 639 _idents, _targets = self.client._build_targets(targets)
643 640 msg_ids = []
644 641 trackers = []
645 642 for ident in _idents:
646 643 msg = self.client.send_execute_request(self._socket, code, silent=silent, ident=ident)
647 644 msg_ids.append(msg['header']['msg_id'])
648 645 if isinstance(targets, int):
649 646 msg_ids = msg_ids[0]
650 647 ar = AsyncResult(self.client, msg_ids, fname='execute', targets=_targets, owner=True)
651 648 if block:
652 649 try:
653 650 ar.get()
654 651 except KeyboardInterrupt:
655 652 pass
656 653 return ar
657 654
658 655 def run(self, filename, targets=None, block=None):
659 656 """Execute contents of `filename` on my engine(s).
660 657
661 658 This simply reads the contents of the file and calls `execute`.
662 659
663 660 Parameters
664 661 ----------
665 662
666 663 filename : str
667 664 The path to the file
668 665 targets : int/str/list of ints/strs
669 666 the engines on which to execute
670 667 default : all
671 668 block : bool
672 669 whether or not to wait until done
673 670 default: self.block
674 671
675 672 """
676 673 with open(filename, 'r') as f:
677 674 # add newline in case of trailing indented whitespace
678 675 # which will cause SyntaxError
679 676 code = f.read()+'\n'
680 677 return self.execute(code, block=block, targets=targets)
681 678
682 679 def update(self, ns):
683 680 """update remote namespace with dict `ns`
684 681
685 682 See `push` for details.
686 683 """
687 684 return self.push(ns, block=self.block, track=self.track)
688 685
689 686 def push(self, ns, targets=None, block=None, track=None):
690 687 """update remote namespace with dict `ns`
691 688
692 689 Parameters
693 690 ----------
694 691
695 692 ns : dict
696 693 dict of keys with which to update engine namespace(s)
697 694 block : bool [default : self.block]
698 695 whether to wait to be notified of engine receipt
699 696
700 697 """
701 698
702 699 block = block if block is not None else self.block
703 700 track = track if track is not None else self.track
704 701 targets = targets if targets is not None else self.targets
705 702 # applier = self.apply_sync if block else self.apply_async
706 703 if not isinstance(ns, dict):
707 704 raise TypeError("Must be a dict, not %s"%type(ns))
708 705 return self._really_apply(util._push, kwargs=ns, block=block, track=track, targets=targets)
709 706
710 707 def get(self, key_s):
711 708 """get object(s) by `key_s` from remote namespace
712 709
713 710 see `pull` for details.
714 711 """
715 712 # block = block if block is not None else self.block
716 713 return self.pull(key_s, block=True)
717 714
718 715 def pull(self, names, targets=None, block=None):
719 716 """get object(s) by `name` from remote namespace
720 717
721 718 will return one object if it is a key.
722 719 can also take a list of keys, in which case it will return a list of objects.
723 720 """
724 721 block = block if block is not None else self.block
725 722 targets = targets if targets is not None else self.targets
726 723 applier = self.apply_sync if block else self.apply_async
727 724 if isinstance(names, string_types):
728 725 pass
729 726 elif isinstance(names, (list,tuple,set)):
730 727 for key in names:
731 728 if not isinstance(key, string_types):
732 729 raise TypeError("keys must be str, not type %r"%type(key))
733 730 else:
734 731 raise TypeError("names must be strs, not %r"%names)
735 732 return self._really_apply(util._pull, (names,), block=block, targets=targets)
736 733
737 734 def scatter(self, key, seq, dist='b', flatten=False, targets=None, block=None, track=None):
738 735 """
739 736 Partition a Python sequence and send the partitions to a set of engines.
740 737 """
741 738 block = block if block is not None else self.block
742 739 track = track if track is not None else self.track
743 740 targets = targets if targets is not None else self.targets
744 741
745 742 # construct integer ID list:
746 743 targets = self.client._build_targets(targets)[1]
747 744
748 745 mapObject = Map.dists[dist]()
749 746 nparts = len(targets)
750 747 msg_ids = []
751 748 trackers = []
752 749 for index, engineid in enumerate(targets):
753 750 partition = mapObject.getPartition(seq, index, nparts)
754 751 if flatten and len(partition) == 1:
755 752 ns = {key: partition[0]}
756 753 else:
757 754 ns = {key: partition}
758 755 r = self.push(ns, block=False, track=track, targets=engineid)
759 756 msg_ids.extend(r.msg_ids)
760 757 if track:
761 758 trackers.append(r._tracker)
762 759
763 760 if track:
764 761 tracker = zmq.MessageTracker(*trackers)
765 762 else:
766 763 tracker = None
767 764
768 765 r = AsyncResult(self.client, msg_ids, fname='scatter', targets=targets,
769 766 tracker=tracker, owner=True,
770 767 )
771 768 if block:
772 769 r.wait()
773 770 else:
774 771 return r
775 772
776 773 @sync_results
777 774 @save_ids
778 775 def gather(self, key, dist='b', targets=None, block=None):
779 776 """
780 777 Gather a partitioned sequence on a set of engines as a single local seq.
781 778 """
782 779 block = block if block is not None else self.block
783 780 targets = targets if targets is not None else self.targets
784 781 mapObject = Map.dists[dist]()
785 782 msg_ids = []
786 783
787 784 # construct integer ID list:
788 785 targets = self.client._build_targets(targets)[1]
789 786
790 787 for index, engineid in enumerate(targets):
791 788 msg_ids.extend(self.pull(key, block=False, targets=engineid).msg_ids)
792 789
793 790 r = AsyncMapResult(self.client, msg_ids, mapObject, fname='gather')
794 791
795 792 if block:
796 793 try:
797 794 return r.get()
798 795 except KeyboardInterrupt:
799 796 pass
800 797 return r
801 798
802 799 def __getitem__(self, key):
803 800 return self.get(key)
804 801
805 802 def __setitem__(self,key, value):
806 803 self.update({key:value})
807 804
808 805 def clear(self, targets=None, block=None):
809 806 """Clear the remote namespaces on my engines."""
810 807 block = block if block is not None else self.block
811 808 targets = targets if targets is not None else self.targets
812 809 return self.client.clear(targets=targets, block=block)
813 810
814 811 #----------------------------------------
815 812 # activate for %px, %autopx, etc. magics
816 813 #----------------------------------------
817 814
818 815 def activate(self, suffix=''):
819 816 """Activate IPython magics associated with this View
820 817
821 818 Defines the magics `%px, %autopx, %pxresult, %%px, %pxconfig`
822 819
823 820 Parameters
824 821 ----------
825 822
826 823 suffix: str [default: '']
827 824 The suffix, if any, for the magics. This allows you to have
828 825 multiple views associated with parallel magics at the same time.
829 826
830 827 e.g. ``rc[::2].activate(suffix='_even')`` will give you
831 828 the magics ``%px_even``, ``%pxresult_even``, etc. for running magics
832 829 on the even engines.
833 830 """
834 831
835 832 from IPython.parallel.client.magics import ParallelMagics
836 833
837 834 try:
838 835 # This is injected into __builtins__.
839 836 ip = get_ipython()
840 837 except NameError:
841 838 print("The IPython parallel magics (%px, etc.) only work within IPython.")
842 839 return
843 840
844 841 M = ParallelMagics(ip, self, suffix)
845 842 ip.magics_manager.register(M)
846 843
847 844
848 @skip_doctest
849 845 class LoadBalancedView(View):
850 846 """An load-balancing View that only executes via the Task scheduler.
851 847
852 848 Load-balanced views can be created with the client's `view` method:
853 849
854 850 >>> v = client.load_balanced_view()
855 851
856 852 or targets can be specified, to restrict the potential destinations:
857 853
858 854 >>> v = client.load_balanced_view([1,3])
859 855
860 856 which would restrict loadbalancing to between engines 1 and 3.
861 857
862 858 """
863 859
864 860 follow=Any()
865 861 after=Any()
866 862 timeout=CFloat()
867 863 retries = Integer(0)
868 864
869 865 _task_scheme = Any()
870 866 _flag_names = List(['targets', 'block', 'track', 'follow', 'after', 'timeout', 'retries'])
871 867
872 868 def __init__(self, client=None, socket=None, **flags):
873 869 super(LoadBalancedView, self).__init__(client=client, socket=socket, **flags)
874 870 self._task_scheme=client._task_scheme
875 871
876 872 def _validate_dependency(self, dep):
877 873 """validate a dependency.
878 874
879 875 For use in `set_flags`.
880 876 """
881 877 if dep is None or isinstance(dep, string_types + (AsyncResult, Dependency)):
882 878 return True
883 879 elif isinstance(dep, (list,set, tuple)):
884 880 for d in dep:
885 881 if not isinstance(d, string_types + (AsyncResult,)):
886 882 return False
887 883 elif isinstance(dep, dict):
888 884 if set(dep.keys()) != set(Dependency().as_dict().keys()):
889 885 return False
890 886 if not isinstance(dep['msg_ids'], list):
891 887 return False
892 888 for d in dep['msg_ids']:
893 889 if not isinstance(d, string_types):
894 890 return False
895 891 else:
896 892 return False
897 893
898 894 return True
899 895
900 896 def _render_dependency(self, dep):
901 897 """helper for building jsonable dependencies from various input forms."""
902 898 if isinstance(dep, Dependency):
903 899 return dep.as_dict()
904 900 elif isinstance(dep, AsyncResult):
905 901 return dep.msg_ids
906 902 elif dep is None:
907 903 return []
908 904 else:
909 905 # pass to Dependency constructor
910 906 return list(Dependency(dep))
911 907
912 908 def set_flags(self, **kwargs):
913 909 """set my attribute flags by keyword.
914 910
915 911 A View is a wrapper for the Client's apply method, but with attributes
916 912 that specify keyword arguments, those attributes can be set by keyword
917 913 argument with this method.
918 914
919 915 Parameters
920 916 ----------
921 917
922 918 block : bool
923 919 whether to wait for results
924 920 track : bool
925 921 whether to create a MessageTracker to allow the user to
926 922 safely edit after arrays and buffers during non-copying
927 923 sends.
928 924
929 925 after : Dependency or collection of msg_ids
930 926 Only for load-balanced execution (targets=None)
931 927 Specify a list of msg_ids as a time-based dependency.
932 928 This job will only be run *after* the dependencies
933 929 have been met.
934 930
935 931 follow : Dependency or collection of msg_ids
936 932 Only for load-balanced execution (targets=None)
937 933 Specify a list of msg_ids as a location-based dependency.
938 934 This job will only be run on an engine where this dependency
939 935 is met.
940 936
941 937 timeout : float/int or None
942 938 Only for load-balanced execution (targets=None)
943 939 Specify an amount of time (in seconds) for the scheduler to
944 940 wait for dependencies to be met before failing with a
945 941 DependencyTimeout.
946 942
947 943 retries : int
948 944 Number of times a task will be retried on failure.
949 945 """
950 946
951 947 super(LoadBalancedView, self).set_flags(**kwargs)
952 948 for name in ('follow', 'after'):
953 949 if name in kwargs:
954 950 value = kwargs[name]
955 951 if self._validate_dependency(value):
956 952 setattr(self, name, value)
957 953 else:
958 954 raise ValueError("Invalid dependency: %r"%value)
959 955 if 'timeout' in kwargs:
960 956 t = kwargs['timeout']
961 957 if not isinstance(t, (int, float, type(None))):
962 958 if (not PY3) and (not isinstance(t, long)):
963 959 raise TypeError("Invalid type for timeout: %r"%type(t))
964 960 if t is not None:
965 961 if t < 0:
966 962 raise ValueError("Invalid timeout: %s"%t)
967 963 self.timeout = t
968 964
969 965 @sync_results
970 966 @save_ids
971 967 def _really_apply(self, f, args=None, kwargs=None, block=None, track=None,
972 968 after=None, follow=None, timeout=None,
973 969 targets=None, retries=None):
974 970 """calls f(*args, **kwargs) on a remote engine, returning the result.
975 971
976 972 This method temporarily sets all of `apply`'s flags for a single call.
977 973
978 974 Parameters
979 975 ----------
980 976
981 977 f : callable
982 978
983 979 args : list [default: empty]
984 980
985 981 kwargs : dict [default: empty]
986 982
987 983 block : bool [default: self.block]
988 984 whether to block
989 985 track : bool [default: self.track]
990 986 whether to ask zmq to track the message, for safe non-copying sends
991 987
992 988 !!!!!! TODO: THE REST HERE !!!!
993 989
994 990 Returns
995 991 -------
996 992
997 993 if self.block is False:
998 994 returns AsyncResult
999 995 else:
1000 996 returns actual result of f(*args, **kwargs) on the engine(s)
1001 997 This will be a list of self.targets is also a list (even length 1), or
1002 998 the single result if self.targets is an integer engine id
1003 999 """
1004 1000
1005 1001 # validate whether we can run
1006 1002 if self._socket.closed:
1007 1003 msg = "Task farming is disabled"
1008 1004 if self._task_scheme == 'pure':
1009 1005 msg += " because the pure ZMQ scheduler cannot handle"
1010 1006 msg += " disappearing engines."
1011 1007 raise RuntimeError(msg)
1012 1008
1013 1009 if self._task_scheme == 'pure':
1014 1010 # pure zmq scheme doesn't support extra features
1015 1011 msg = "Pure ZMQ scheduler doesn't support the following flags:"
1016 1012 "follow, after, retries, targets, timeout"
1017 1013 if (follow or after or retries or targets or timeout):
1018 1014 # hard fail on Scheduler flags
1019 1015 raise RuntimeError(msg)
1020 1016 if isinstance(f, dependent):
1021 1017 # soft warn on functional dependencies
1022 1018 warnings.warn(msg, RuntimeWarning)
1023 1019
1024 1020 # build args
1025 1021 args = [] if args is None else args
1026 1022 kwargs = {} if kwargs is None else kwargs
1027 1023 block = self.block if block is None else block
1028 1024 track = self.track if track is None else track
1029 1025 after = self.after if after is None else after
1030 1026 retries = self.retries if retries is None else retries
1031 1027 follow = self.follow if follow is None else follow
1032 1028 timeout = self.timeout if timeout is None else timeout
1033 1029 targets = self.targets if targets is None else targets
1034 1030
1035 1031 if not isinstance(retries, int):
1036 1032 raise TypeError('retries must be int, not %r'%type(retries))
1037 1033
1038 1034 if targets is None:
1039 1035 idents = []
1040 1036 else:
1041 1037 idents = self.client._build_targets(targets)[0]
1042 1038 # ensure *not* bytes
1043 1039 idents = [ ident.decode() for ident in idents ]
1044 1040
1045 1041 after = self._render_dependency(after)
1046 1042 follow = self._render_dependency(follow)
1047 1043 metadata = dict(after=after, follow=follow, timeout=timeout, targets=idents, retries=retries)
1048 1044
1049 1045 msg = self.client.send_apply_request(self._socket, f, args, kwargs, track=track,
1050 1046 metadata=metadata)
1051 1047 tracker = None if track is False else msg['tracker']
1052 1048
1053 1049 ar = AsyncResult(self.client, msg['header']['msg_id'], fname=getname(f),
1054 1050 targets=None, tracker=tracker, owner=True,
1055 1051 )
1056 1052 if block:
1057 1053 try:
1058 1054 return ar.get()
1059 1055 except KeyboardInterrupt:
1060 1056 pass
1061 1057 return ar
1062 1058
1063 1059 @sync_results
1064 1060 @save_ids
1065 1061 def map(self, f, *sequences, **kwargs):
1066 1062 """``view.map(f, *sequences, block=self.block, chunksize=1, ordered=True)`` => list|AsyncMapResult
1067 1063
1068 1064 Parallel version of builtin `map`, load-balanced by this View.
1069 1065
1070 1066 `block`, and `chunksize` can be specified by keyword only.
1071 1067
1072 1068 Each `chunksize` elements will be a separate task, and will be
1073 1069 load-balanced. This lets individual elements be available for iteration
1074 1070 as soon as they arrive.
1075 1071
1076 1072 Parameters
1077 1073 ----------
1078 1074
1079 1075 f : callable
1080 1076 function to be mapped
1081 1077 *sequences: one or more sequences of matching length
1082 1078 the sequences to be distributed and passed to `f`
1083 1079 block : bool [default self.block]
1084 1080 whether to wait for the result or not
1085 1081 track : bool
1086 1082 whether to create a MessageTracker to allow the user to
1087 1083 safely edit after arrays and buffers during non-copying
1088 1084 sends.
1089 1085 chunksize : int [default 1]
1090 1086 how many elements should be in each task.
1091 1087 ordered : bool [default True]
1092 1088 Whether the results should be gathered as they arrive, or enforce
1093 1089 the order of submission.
1094 1090
1095 1091 Only applies when iterating through AsyncMapResult as results arrive.
1096 1092 Has no effect when block=True.
1097 1093
1098 1094 Returns
1099 1095 -------
1100 1096
1101 1097 if block=False
1102 1098 An :class:`~ipython_parallel.client.asyncresult.AsyncMapResult` instance.
1103 1099 An object like AsyncResult, but which reassembles the sequence of results
1104 1100 into a single list. AsyncMapResults can be iterated through before all
1105 1101 results are complete.
1106 1102 else
1107 1103 A list, the result of ``map(f,*sequences)``
1108 1104 """
1109 1105
1110 1106 # default
1111 1107 block = kwargs.get('block', self.block)
1112 1108 chunksize = kwargs.get('chunksize', 1)
1113 1109 ordered = kwargs.get('ordered', True)
1114 1110
1115 1111 keyset = set(kwargs.keys())
1116 1112 extra_keys = keyset.difference_update(set(['block', 'chunksize']))
1117 1113 if extra_keys:
1118 1114 raise TypeError("Invalid kwargs: %s"%list(extra_keys))
1119 1115
1120 1116 assert len(sequences) > 0, "must have some sequences to map onto!"
1121 1117
1122 1118 pf = ParallelFunction(self, f, block=block, chunksize=chunksize, ordered=ordered)
1123 1119 return pf.map(*sequences)
1124 1120
1125 1121 __all__ = ['LoadBalancedView', 'DirectView']
@@ -1,1878 +1,1875 b''
1 1 # encoding: utf-8
2 2 """
3 3 A lightweight Traits like module.
4 4
5 5 This is designed to provide a lightweight, simple, pure Python version of
6 6 many of the capabilities of enthought.traits. This includes:
7 7
8 8 * Validation
9 9 * Type specification with defaults
10 10 * Static and dynamic notification
11 11 * Basic predefined types
12 12 * An API that is similar to enthought.traits
13 13
14 14 We don't support:
15 15
16 16 * Delegation
17 17 * Automatic GUI generation
18 18 * A full set of trait types. Most importantly, we don't provide container
19 19 traits (list, dict, tuple) that can trigger notifications if their
20 20 contents change.
21 21 * API compatibility with enthought.traits
22 22
23 23 There are also some important difference in our design:
24 24
25 25 * enthought.traits does not validate default values. We do.
26 26
27 27 We choose to create this module because we need these capabilities, but
28 28 we need them to be pure Python so they work in all Python implementations,
29 29 including Jython and IronPython.
30 30
31 31 Inheritance diagram:
32 32
33 33 .. inheritance-diagram:: IPython.utils.traitlets
34 34 :parts: 3
35 35 """
36 36
37 37 # Copyright (c) IPython Development Team.
38 38 # Distributed under the terms of the Modified BSD License.
39 39 #
40 40 # Adapted from enthought.traits, Copyright (c) Enthought, Inc.,
41 41 # also under the terms of the Modified BSD License.
42 42
43 43 import contextlib
44 44 import inspect
45 45 import re
46 46 import sys
47 47 import types
48 48 from types import FunctionType
49 49 try:
50 50 from types import ClassType, InstanceType
51 51 ClassTypes = (ClassType, type)
52 52 except:
53 53 ClassTypes = (type,)
54 54 from warnings import warn
55 55
56 56 from IPython.utils import py3compat
57 57 from IPython.utils import eventful
58 58 from IPython.utils.getargspec import getargspec
59 59 from IPython.utils.importstring import import_item
60 60 from IPython.utils.py3compat import iteritems, string_types
61 from IPython.testing.skipdoctest import skip_doctest
62 61
63 62 from .sentinel import Sentinel
64 63 SequenceTypes = (list, tuple, set, frozenset)
65 64
66 65 #-----------------------------------------------------------------------------
67 66 # Basic classes
68 67 #-----------------------------------------------------------------------------
69 68
70 69
71 70 NoDefaultSpecified = Sentinel('NoDefaultSpecified', __name__,
72 71 '''
73 72 Used in Traitlets to specify that no defaults are set in kwargs
74 73 '''
75 74 )
76 75
77 76
78 77 class Undefined ( object ): pass
79 78 Undefined = Undefined()
80 79
81 80 class TraitError(Exception):
82 81 pass
83 82
84 83 #-----------------------------------------------------------------------------
85 84 # Utilities
86 85 #-----------------------------------------------------------------------------
87 86
88 87
89 88 def class_of ( object ):
90 89 """ Returns a string containing the class name of an object with the
91 90 correct indefinite article ('a' or 'an') preceding it (e.g., 'an Image',
92 91 'a PlotValue').
93 92 """
94 93 if isinstance( object, py3compat.string_types ):
95 94 return add_article( object )
96 95
97 96 return add_article( object.__class__.__name__ )
98 97
99 98
100 99 def add_article ( name ):
101 100 """ Returns a string containing the correct indefinite article ('a' or 'an')
102 101 prefixed to the specified string.
103 102 """
104 103 if name[:1].lower() in 'aeiou':
105 104 return 'an ' + name
106 105
107 106 return 'a ' + name
108 107
109 108
110 109 def repr_type(obj):
111 110 """ Return a string representation of a value and its type for readable
112 111 error messages.
113 112 """
114 113 the_type = type(obj)
115 114 if (not py3compat.PY3) and the_type is InstanceType:
116 115 # Old-style class.
117 116 the_type = obj.__class__
118 117 msg = '%r %r' % (obj, the_type)
119 118 return msg
120 119
121 120
122 121 def is_trait(t):
123 122 """ Returns whether the given value is an instance or subclass of TraitType.
124 123 """
125 124 return (isinstance(t, TraitType) or
126 125 (isinstance(t, type) and issubclass(t, TraitType)))
127 126
128 127
129 128 def parse_notifier_name(name):
130 129 """Convert the name argument to a list of names.
131 130
132 131 Examples
133 132 --------
134 133
135 134 >>> parse_notifier_name('a')
136 135 ['a']
137 136 >>> parse_notifier_name(['a','b'])
138 137 ['a', 'b']
139 138 >>> parse_notifier_name(None)
140 139 ['anytrait']
141 140 """
142 141 if isinstance(name, string_types):
143 142 return [name]
144 143 elif name is None:
145 144 return ['anytrait']
146 145 elif isinstance(name, (list, tuple)):
147 146 for n in name:
148 147 assert isinstance(n, string_types), "names must be strings"
149 148 return name
150 149
151 150
152 151 class _SimpleTest:
153 152 def __init__ ( self, value ): self.value = value
154 153 def __call__ ( self, test ):
155 154 return test == self.value
156 155 def __repr__(self):
157 156 return "<SimpleTest(%r)" % self.value
158 157 def __str__(self):
159 158 return self.__repr__()
160 159
161 160
162 161 def getmembers(object, predicate=None):
163 162 """A safe version of inspect.getmembers that handles missing attributes.
164 163
165 164 This is useful when there are descriptor based attributes that for
166 165 some reason raise AttributeError even though they exist. This happens
167 166 in zope.inteface with the __provides__ attribute.
168 167 """
169 168 results = []
170 169 for key in dir(object):
171 170 try:
172 171 value = getattr(object, key)
173 172 except AttributeError:
174 173 pass
175 174 else:
176 175 if not predicate or predicate(value):
177 176 results.append((key, value))
178 177 results.sort()
179 178 return results
180 179
181 180 def _validate_link(*tuples):
182 181 """Validate arguments for traitlet link functions"""
183 182 for t in tuples:
184 183 if not len(t) == 2:
185 184 raise TypeError("Each linked traitlet must be specified as (HasTraits, 'trait_name'), not %r" % t)
186 185 obj, trait_name = t
187 186 if not isinstance(obj, HasTraits):
188 187 raise TypeError("Each object must be HasTraits, not %r" % type(obj))
189 188 if not trait_name in obj.traits():
190 189 raise TypeError("%r has no trait %r" % (obj, trait_name))
191 190
192 @skip_doctest
193 191 class link(object):
194 192 """Link traits from different objects together so they remain in sync.
195 193
196 194 Parameters
197 195 ----------
198 196 *args : pairs of objects/attributes
199 197
200 198 Examples
201 199 --------
202 200
203 201 >>> c = link((obj1, 'value'), (obj2, 'value'), (obj3, 'value'))
204 202 >>> obj1.value = 5 # updates other objects as well
205 203 """
206 204 updating = False
207 205 def __init__(self, *args):
208 206 if len(args) < 2:
209 207 raise TypeError('At least two traitlets must be provided.')
210 208 _validate_link(*args)
211 209
212 210 self.objects = {}
213 211
214 212 initial = getattr(args[0][0], args[0][1])
215 213 for obj, attr in args:
216 214 setattr(obj, attr, initial)
217 215
218 216 callback = self._make_closure(obj, attr)
219 217 obj.on_trait_change(callback, attr)
220 218 self.objects[(obj, attr)] = callback
221 219
222 220 @contextlib.contextmanager
223 221 def _busy_updating(self):
224 222 self.updating = True
225 223 try:
226 224 yield
227 225 finally:
228 226 self.updating = False
229 227
230 228 def _make_closure(self, sending_obj, sending_attr):
231 229 def update(name, old, new):
232 230 self._update(sending_obj, sending_attr, new)
233 231 return update
234 232
235 233 def _update(self, sending_obj, sending_attr, new):
236 234 if self.updating:
237 235 return
238 236 with self._busy_updating():
239 237 for obj, attr in self.objects.keys():
240 238 setattr(obj, attr, new)
241 239
242 240 def unlink(self):
243 241 for key, callback in self.objects.items():
244 242 (obj, attr) = key
245 243 obj.on_trait_change(callback, attr, remove=True)
246 244
247 @skip_doctest
248 245 class directional_link(object):
249 246 """Link the trait of a source object with traits of target objects.
250 247
251 248 Parameters
252 249 ----------
253 250 source : pair of object, name
254 251 targets : pairs of objects/attributes
255 252
256 253 Examples
257 254 --------
258 255
259 256 >>> c = directional_link((src, 'value'), (tgt1, 'value'), (tgt2, 'value'))
260 257 >>> src.value = 5 # updates target objects
261 258 >>> tgt1.value = 6 # does not update other objects
262 259 """
263 260 updating = False
264 261
265 262 def __init__(self, source, *targets):
266 263 if len(targets) < 1:
267 264 raise TypeError('At least two traitlets must be provided.')
268 265 _validate_link(source, *targets)
269 266 self.source = source
270 267 self.targets = targets
271 268
272 269 # Update current value
273 270 src_attr_value = getattr(source[0], source[1])
274 271 for obj, attr in targets:
275 272 setattr(obj, attr, src_attr_value)
276 273
277 274 # Wire
278 275 self.source[0].on_trait_change(self._update, self.source[1])
279 276
280 277 @contextlib.contextmanager
281 278 def _busy_updating(self):
282 279 self.updating = True
283 280 try:
284 281 yield
285 282 finally:
286 283 self.updating = False
287 284
288 285 def _update(self, name, old, new):
289 286 if self.updating:
290 287 return
291 288 with self._busy_updating():
292 289 for obj, attr in self.targets:
293 290 setattr(obj, attr, new)
294 291
295 292 def unlink(self):
296 293 self.source[0].on_trait_change(self._update, self.source[1], remove=True)
297 294 self.source = None
298 295 self.targets = []
299 296
300 297 dlink = directional_link
301 298
302 299
303 300 #-----------------------------------------------------------------------------
304 301 # Base TraitType for all traits
305 302 #-----------------------------------------------------------------------------
306 303
307 304
308 305 class TraitType(object):
309 306 """A base class for all trait descriptors.
310 307
311 308 Notes
312 309 -----
313 310 Our implementation of traits is based on Python's descriptor
314 311 prototol. This class is the base class for all such descriptors. The
315 312 only magic we use is a custom metaclass for the main :class:`HasTraits`
316 313 class that does the following:
317 314
318 315 1. Sets the :attr:`name` attribute of every :class:`TraitType`
319 316 instance in the class dict to the name of the attribute.
320 317 2. Sets the :attr:`this_class` attribute of every :class:`TraitType`
321 318 instance in the class dict to the *class* that declared the trait.
322 319 This is used by the :class:`This` trait to allow subclasses to
323 320 accept superclasses for :class:`This` values.
324 321 """
325 322
326 323 metadata = {}
327 324 default_value = Undefined
328 325 allow_none = False
329 326 info_text = 'any value'
330 327
331 328 def __init__(self, default_value=NoDefaultSpecified, allow_none=None, **metadata):
332 329 """Create a TraitType.
333 330 """
334 331 if default_value is not NoDefaultSpecified:
335 332 self.default_value = default_value
336 333 if allow_none is not None:
337 334 self.allow_none = allow_none
338 335
339 336 if 'default' in metadata:
340 337 # Warn the user that they probably meant default_value.
341 338 warn(
342 339 "Parameter 'default' passed to TraitType. "
343 340 "Did you mean 'default_value'?"
344 341 )
345 342
346 343 if len(metadata) > 0:
347 344 if len(self.metadata) > 0:
348 345 self._metadata = self.metadata.copy()
349 346 self._metadata.update(metadata)
350 347 else:
351 348 self._metadata = metadata
352 349 else:
353 350 self._metadata = self.metadata
354 351
355 352 self.init()
356 353
357 354 def init(self):
358 355 pass
359 356
360 357 def get_default_value(self):
361 358 """Create a new instance of the default value."""
362 359 return self.default_value
363 360
364 361 def instance_init(self):
365 362 """Part of the initialization which may depends on the underlying
366 363 HasTraits instance.
367 364
368 365 It is typically overloaded for specific trait types.
369 366
370 367 This method is called by :meth:`HasTraits.__new__` and in the
371 368 :meth:`TraitType.instance_init` method of trait types holding
372 369 other trait types.
373 370 """
374 371 pass
375 372
376 373 def init_default_value(self, obj):
377 374 """Instantiate the default value for the trait type.
378 375
379 376 This method is called by :meth:`TraitType.set_default_value` in the
380 377 case a default value is provided at construction time or later when
381 378 accessing the trait value for the first time in
382 379 :meth:`HasTraits.__get__`.
383 380 """
384 381 value = self.get_default_value()
385 382 value = self._validate(obj, value)
386 383 obj._trait_values[self.name] = value
387 384 return value
388 385
389 386 def set_default_value(self, obj):
390 387 """Set the default value on a per instance basis.
391 388
392 389 This method is called by :meth:`HasTraits.__new__` to instantiate and
393 390 validate the default value. The creation and validation of
394 391 default values must be delayed until the parent :class:`HasTraits`
395 392 class has been instantiated.
396 393 Parameters
397 394 ----------
398 395 obj : :class:`HasTraits` instance
399 396 The parent :class:`HasTraits` instance that has just been
400 397 created.
401 398 """
402 399 # Check for a deferred initializer defined in the same class as the
403 400 # trait declaration or above.
404 401 mro = type(obj).mro()
405 402 meth_name = '_%s_default' % self.name
406 403 for cls in mro[:mro.index(self.this_class)+1]:
407 404 if meth_name in cls.__dict__:
408 405 break
409 406 else:
410 407 # We didn't find one. Do static initialization.
411 408 self.init_default_value(obj)
412 409 return
413 410 # Complete the dynamic initialization.
414 411 obj._trait_dyn_inits[self.name] = meth_name
415 412
416 413 def __get__(self, obj, cls=None):
417 414 """Get the value of the trait by self.name for the instance.
418 415
419 416 Default values are instantiated when :meth:`HasTraits.__new__`
420 417 is called. Thus by the time this method gets called either the
421 418 default value or a user defined value (they called :meth:`__set__`)
422 419 is in the :class:`HasTraits` instance.
423 420 """
424 421 if obj is None:
425 422 return self
426 423 else:
427 424 try:
428 425 value = obj._trait_values[self.name]
429 426 except KeyError:
430 427 # Check for a dynamic initializer.
431 428 if self.name in obj._trait_dyn_inits:
432 429 method = getattr(obj, obj._trait_dyn_inits[self.name])
433 430 value = method()
434 431 # FIXME: Do we really validate here?
435 432 value = self._validate(obj, value)
436 433 obj._trait_values[self.name] = value
437 434 return value
438 435 else:
439 436 return self.init_default_value(obj)
440 437 except Exception:
441 438 # HasTraits should call set_default_value to populate
442 439 # this. So this should never be reached.
443 440 raise TraitError('Unexpected error in TraitType: '
444 441 'default value not set properly')
445 442 else:
446 443 return value
447 444
448 445 def __set__(self, obj, value):
449 446 new_value = self._validate(obj, value)
450 447 try:
451 448 old_value = obj._trait_values[self.name]
452 449 except KeyError:
453 450 old_value = Undefined
454 451
455 452 obj._trait_values[self.name] = new_value
456 453 try:
457 454 silent = bool(old_value == new_value)
458 455 except:
459 456 # if there is an error in comparing, default to notify
460 457 silent = False
461 458 if silent is not True:
462 459 # we explicitly compare silent to True just in case the equality
463 460 # comparison above returns something other than True/False
464 461 obj._notify_trait(self.name, old_value, new_value)
465 462
466 463 def _validate(self, obj, value):
467 464 if value is None and self.allow_none:
468 465 return value
469 466 if hasattr(self, 'validate'):
470 467 value = self.validate(obj, value)
471 468 if obj._cross_validation_lock is False:
472 469 value = self._cross_validate(obj, value)
473 470 return value
474 471
475 472 def _cross_validate(self, obj, value):
476 473 if hasattr(obj, '_%s_validate' % self.name):
477 474 cross_validate = getattr(obj, '_%s_validate' % self.name)
478 475 value = cross_validate(value, self)
479 476 return value
480 477
481 478 def __or__(self, other):
482 479 if isinstance(other, Union):
483 480 return Union([self] + other.trait_types)
484 481 else:
485 482 return Union([self, other])
486 483
487 484 def info(self):
488 485 return self.info_text
489 486
490 487 def error(self, obj, value):
491 488 if obj is not None:
492 489 e = "The '%s' trait of %s instance must be %s, but a value of %s was specified." \
493 490 % (self.name, class_of(obj),
494 491 self.info(), repr_type(value))
495 492 else:
496 493 e = "The '%s' trait must be %s, but a value of %r was specified." \
497 494 % (self.name, self.info(), repr_type(value))
498 495 raise TraitError(e)
499 496
500 497 def get_metadata(self, key, default=None):
501 498 return getattr(self, '_metadata', {}).get(key, default)
502 499
503 500 def set_metadata(self, key, value):
504 501 getattr(self, '_metadata', {})[key] = value
505 502
506 503
507 504 #-----------------------------------------------------------------------------
508 505 # The HasTraits implementation
509 506 #-----------------------------------------------------------------------------
510 507
511 508
512 509 class MetaHasTraits(type):
513 510 """A metaclass for HasTraits.
514 511
515 512 This metaclass makes sure that any TraitType class attributes are
516 513 instantiated and sets their name attribute.
517 514 """
518 515
519 516 def __new__(mcls, name, bases, classdict):
520 517 """Create the HasTraits class.
521 518
522 519 This instantiates all TraitTypes in the class dict and sets their
523 520 :attr:`name` attribute.
524 521 """
525 522 # print "MetaHasTraitlets (mcls, name): ", mcls, name
526 523 # print "MetaHasTraitlets (bases): ", bases
527 524 # print "MetaHasTraitlets (classdict): ", classdict
528 525 for k,v in iteritems(classdict):
529 526 if isinstance(v, TraitType):
530 527 v.name = k
531 528 elif inspect.isclass(v):
532 529 if issubclass(v, TraitType):
533 530 vinst = v()
534 531 vinst.name = k
535 532 classdict[k] = vinst
536 533 return super(MetaHasTraits, mcls).__new__(mcls, name, bases, classdict)
537 534
538 535 def __init__(cls, name, bases, classdict):
539 536 """Finish initializing the HasTraits class.
540 537
541 538 This sets the :attr:`this_class` attribute of each TraitType in the
542 539 class dict to the newly created class ``cls``.
543 540 """
544 541 for k, v in iteritems(classdict):
545 542 if isinstance(v, TraitType):
546 543 v.this_class = cls
547 544 super(MetaHasTraits, cls).__init__(name, bases, classdict)
548 545
549 546
550 547 class HasTraits(py3compat.with_metaclass(MetaHasTraits, object)):
551 548
552 549 def __new__(cls, *args, **kw):
553 550 # This is needed because object.__new__ only accepts
554 551 # the cls argument.
555 552 new_meth = super(HasTraits, cls).__new__
556 553 if new_meth is object.__new__:
557 554 inst = new_meth(cls)
558 555 else:
559 556 inst = new_meth(cls, **kw)
560 557 inst._trait_values = {}
561 558 inst._trait_notifiers = {}
562 559 inst._trait_dyn_inits = {}
563 560 inst._cross_validation_lock = True
564 561 # Here we tell all the TraitType instances to set their default
565 562 # values on the instance.
566 563 for key in dir(cls):
567 564 # Some descriptors raise AttributeError like zope.interface's
568 565 # __provides__ attributes even though they exist. This causes
569 566 # AttributeErrors even though they are listed in dir(cls).
570 567 try:
571 568 value = getattr(cls, key)
572 569 except AttributeError:
573 570 pass
574 571 else:
575 572 if isinstance(value, TraitType):
576 573 value.instance_init()
577 574 if key not in kw:
578 575 value.set_default_value(inst)
579 576 inst._cross_validation_lock = False
580 577 return inst
581 578
582 579 def __init__(self, *args, **kw):
583 580 # Allow trait values to be set using keyword arguments.
584 581 # We need to use setattr for this to trigger validation and
585 582 # notifications.
586 583 with self.hold_trait_notifications():
587 584 for key, value in iteritems(kw):
588 585 setattr(self, key, value)
589 586
590 587 @contextlib.contextmanager
591 588 def hold_trait_notifications(self):
592 589 """Context manager for bundling trait change notifications and cross
593 590 validation.
594 591
595 592 Use this when doing multiple trait assignments (init, config), to avoid
596 593 race conditions in trait notifiers requesting other trait values.
597 594 All trait notifications will fire after all values have been assigned.
598 595 """
599 596 if self._cross_validation_lock is True:
600 597 yield
601 598 return
602 599 else:
603 600 self._cross_validation_lock = True
604 601 cache = {}
605 602 notifications = {}
606 603 _notify_trait = self._notify_trait
607 604
608 605 def cache_values(*a):
609 606 cache[a[0]] = a
610 607
611 608 def hold_notifications(*a):
612 609 notifications[a[0]] = a
613 610
614 611 self._notify_trait = cache_values
615 612
616 613 try:
617 614 yield
618 615 finally:
619 616 try:
620 617 self._notify_trait = hold_notifications
621 618 for name in cache:
622 619 if hasattr(self, '_%s_validate' % name):
623 620 cross_validate = getattr(self, '_%s_validate' % name)
624 621 setattr(self, name, cross_validate(getattr(self, name), self))
625 622 except TraitError as e:
626 623 self._notify_trait = lambda *x: None
627 624 for name in cache:
628 625 if cache[name][1] is not Undefined:
629 626 setattr(self, name, cache[name][1])
630 627 else:
631 628 delattr(self, name)
632 629 cache = {}
633 630 notifications = {}
634 631 raise e
635 632 finally:
636 633 self._notify_trait = _notify_trait
637 634 self._cross_validation_lock = False
638 635 if isinstance(_notify_trait, types.MethodType):
639 636 # FIXME: remove when support is bumped to 3.4.
640 637 # when original method is restored,
641 638 # remove the redundant value from __dict__
642 639 # (only used to preserve pickleability on Python < 3.4)
643 640 self.__dict__.pop('_notify_trait', None)
644 641 # trigger delayed notifications
645 642 for v in dict(cache, **notifications).values():
646 643 self._notify_trait(*v)
647 644
648 645 def _notify_trait(self, name, old_value, new_value):
649 646
650 647 # First dynamic ones
651 648 callables = []
652 649 callables.extend(self._trait_notifiers.get(name,[]))
653 650 callables.extend(self._trait_notifiers.get('anytrait',[]))
654 651
655 652 # Now static ones
656 653 try:
657 654 cb = getattr(self, '_%s_changed' % name)
658 655 except:
659 656 pass
660 657 else:
661 658 callables.append(cb)
662 659
663 660 # Call them all now
664 661 for c in callables:
665 662 # Traits catches and logs errors here. I allow them to raise
666 663 if callable(c):
667 664 argspec = getargspec(c)
668 665
669 666 nargs = len(argspec[0])
670 667 # Bound methods have an additional 'self' argument
671 668 # I don't know how to treat unbound methods, but they
672 669 # can't really be used for callbacks.
673 670 if isinstance(c, types.MethodType):
674 671 offset = -1
675 672 else:
676 673 offset = 0
677 674 if nargs + offset == 0:
678 675 c()
679 676 elif nargs + offset == 1:
680 677 c(name)
681 678 elif nargs + offset == 2:
682 679 c(name, new_value)
683 680 elif nargs + offset == 3:
684 681 c(name, old_value, new_value)
685 682 else:
686 683 raise TraitError('a trait changed callback '
687 684 'must have 0-3 arguments.')
688 685 else:
689 686 raise TraitError('a trait changed callback '
690 687 'must be callable.')
691 688
692 689
693 690 def _add_notifiers(self, handler, name):
694 691 if name not in self._trait_notifiers:
695 692 nlist = []
696 693 self._trait_notifiers[name] = nlist
697 694 else:
698 695 nlist = self._trait_notifiers[name]
699 696 if handler not in nlist:
700 697 nlist.append(handler)
701 698
702 699 def _remove_notifiers(self, handler, name):
703 700 if name in self._trait_notifiers:
704 701 nlist = self._trait_notifiers[name]
705 702 try:
706 703 index = nlist.index(handler)
707 704 except ValueError:
708 705 pass
709 706 else:
710 707 del nlist[index]
711 708
712 709 def on_trait_change(self, handler, name=None, remove=False):
713 710 """Setup a handler to be called when a trait changes.
714 711
715 712 This is used to setup dynamic notifications of trait changes.
716 713
717 714 Static handlers can be created by creating methods on a HasTraits
718 715 subclass with the naming convention '_[traitname]_changed'. Thus,
719 716 to create static handler for the trait 'a', create the method
720 717 _a_changed(self, name, old, new) (fewer arguments can be used, see
721 718 below).
722 719
723 720 Parameters
724 721 ----------
725 722 handler : callable
726 723 A callable that is called when a trait changes. Its
727 724 signature can be handler(), handler(name), handler(name, new)
728 725 or handler(name, old, new).
729 726 name : list, str, None
730 727 If None, the handler will apply to all traits. If a list
731 728 of str, handler will apply to all names in the list. If a
732 729 str, the handler will apply just to that name.
733 730 remove : bool
734 731 If False (the default), then install the handler. If True
735 732 then unintall it.
736 733 """
737 734 if remove:
738 735 names = parse_notifier_name(name)
739 736 for n in names:
740 737 self._remove_notifiers(handler, n)
741 738 else:
742 739 names = parse_notifier_name(name)
743 740 for n in names:
744 741 self._add_notifiers(handler, n)
745 742
746 743 @classmethod
747 744 def class_trait_names(cls, **metadata):
748 745 """Get a list of all the names of this class' traits.
749 746
750 747 This method is just like the :meth:`trait_names` method,
751 748 but is unbound.
752 749 """
753 750 return cls.class_traits(**metadata).keys()
754 751
755 752 @classmethod
756 753 def class_traits(cls, **metadata):
757 754 """Get a `dict` of all the traits of this class. The dictionary
758 755 is keyed on the name and the values are the TraitType objects.
759 756
760 757 This method is just like the :meth:`traits` method, but is unbound.
761 758
762 759 The TraitTypes returned don't know anything about the values
763 760 that the various HasTrait's instances are holding.
764 761
765 762 The metadata kwargs allow functions to be passed in which
766 763 filter traits based on metadata values. The functions should
767 764 take a single value as an argument and return a boolean. If
768 765 any function returns False, then the trait is not included in
769 766 the output. This does not allow for any simple way of
770 767 testing that a metadata name exists and has any
771 768 value because get_metadata returns None if a metadata key
772 769 doesn't exist.
773 770 """
774 771 traits = dict([memb for memb in getmembers(cls) if
775 772 isinstance(memb[1], TraitType)])
776 773
777 774 if len(metadata) == 0:
778 775 return traits
779 776
780 777 for meta_name, meta_eval in metadata.items():
781 778 if type(meta_eval) is not FunctionType:
782 779 metadata[meta_name] = _SimpleTest(meta_eval)
783 780
784 781 result = {}
785 782 for name, trait in traits.items():
786 783 for meta_name, meta_eval in metadata.items():
787 784 if not meta_eval(trait.get_metadata(meta_name)):
788 785 break
789 786 else:
790 787 result[name] = trait
791 788
792 789 return result
793 790
794 791 def trait_names(self, **metadata):
795 792 """Get a list of all the names of this class' traits."""
796 793 return self.traits(**metadata).keys()
797 794
798 795 def traits(self, **metadata):
799 796 """Get a `dict` of all the traits of this class. The dictionary
800 797 is keyed on the name and the values are the TraitType objects.
801 798
802 799 The TraitTypes returned don't know anything about the values
803 800 that the various HasTrait's instances are holding.
804 801
805 802 The metadata kwargs allow functions to be passed in which
806 803 filter traits based on metadata values. The functions should
807 804 take a single value as an argument and return a boolean. If
808 805 any function returns False, then the trait is not included in
809 806 the output. This does not allow for any simple way of
810 807 testing that a metadata name exists and has any
811 808 value because get_metadata returns None if a metadata key
812 809 doesn't exist.
813 810 """
814 811 traits = dict([memb for memb in getmembers(self.__class__) if
815 812 isinstance(memb[1], TraitType)])
816 813
817 814 if len(metadata) == 0:
818 815 return traits
819 816
820 817 for meta_name, meta_eval in metadata.items():
821 818 if type(meta_eval) is not FunctionType:
822 819 metadata[meta_name] = _SimpleTest(meta_eval)
823 820
824 821 result = {}
825 822 for name, trait in traits.items():
826 823 for meta_name, meta_eval in metadata.items():
827 824 if not meta_eval(trait.get_metadata(meta_name)):
828 825 break
829 826 else:
830 827 result[name] = trait
831 828
832 829 return result
833 830
834 831 def trait_metadata(self, traitname, key, default=None):
835 832 """Get metadata values for trait by key."""
836 833 try:
837 834 trait = getattr(self.__class__, traitname)
838 835 except AttributeError:
839 836 raise TraitError("Class %s does not have a trait named %s" %
840 837 (self.__class__.__name__, traitname))
841 838 else:
842 839 return trait.get_metadata(key, default)
843 840
844 841 def add_trait(self, traitname, trait):
845 842 """Dynamically add a trait attribute to the HasTraits instance."""
846 843 self.__class__ = type(self.__class__.__name__, (self.__class__,),
847 844 {traitname: trait})
848 845 trait.set_default_value(self)
849 846
850 847 #-----------------------------------------------------------------------------
851 848 # Actual TraitTypes implementations/subclasses
852 849 #-----------------------------------------------------------------------------
853 850
854 851 #-----------------------------------------------------------------------------
855 852 # TraitTypes subclasses for handling classes and instances of classes
856 853 #-----------------------------------------------------------------------------
857 854
858 855
859 856 class ClassBasedTraitType(TraitType):
860 857 """
861 858 A trait with error reporting and string -> type resolution for Type,
862 859 Instance and This.
863 860 """
864 861
865 862 def _resolve_string(self, string):
866 863 """
867 864 Resolve a string supplied for a type into an actual object.
868 865 """
869 866 return import_item(string)
870 867
871 868 def error(self, obj, value):
872 869 kind = type(value)
873 870 if (not py3compat.PY3) and kind is InstanceType:
874 871 msg = 'class %s' % value.__class__.__name__
875 872 else:
876 873 msg = '%s (i.e. %s)' % ( str( kind )[1:-1], repr( value ) )
877 874
878 875 if obj is not None:
879 876 e = "The '%s' trait of %s instance must be %s, but a value of %s was specified." \
880 877 % (self.name, class_of(obj),
881 878 self.info(), msg)
882 879 else:
883 880 e = "The '%s' trait must be %s, but a value of %r was specified." \
884 881 % (self.name, self.info(), msg)
885 882
886 883 raise TraitError(e)
887 884
888 885
889 886 class Type(ClassBasedTraitType):
890 887 """A trait whose value must be a subclass of a specified class."""
891 888
892 889 def __init__ (self, default_value=None, klass=None, allow_none=False,
893 890 **metadata):
894 891 """Construct a Type trait
895 892
896 893 A Type trait specifies that its values must be subclasses of
897 894 a particular class.
898 895
899 896 If only ``default_value`` is given, it is used for the ``klass`` as
900 897 well.
901 898
902 899 Parameters
903 900 ----------
904 901 default_value : class, str or None
905 902 The default value must be a subclass of klass. If an str,
906 903 the str must be a fully specified class name, like 'foo.bar.Bah'.
907 904 The string is resolved into real class, when the parent
908 905 :class:`HasTraits` class is instantiated.
909 906 klass : class, str, None
910 907 Values of this trait must be a subclass of klass. The klass
911 908 may be specified in a string like: 'foo.bar.MyClass'.
912 909 The string is resolved into real class, when the parent
913 910 :class:`HasTraits` class is instantiated.
914 911 allow_none : bool [ default True ]
915 912 Indicates whether None is allowed as an assignable value. Even if
916 913 ``False``, the default value may be ``None``.
917 914 """
918 915 if default_value is None:
919 916 if klass is None:
920 917 klass = object
921 918 elif klass is None:
922 919 klass = default_value
923 920
924 921 if not (inspect.isclass(klass) or isinstance(klass, py3compat.string_types)):
925 922 raise TraitError("A Type trait must specify a class.")
926 923
927 924 self.klass = klass
928 925
929 926 super(Type, self).__init__(default_value, allow_none=allow_none, **metadata)
930 927
931 928 def validate(self, obj, value):
932 929 """Validates that the value is a valid object instance."""
933 930 if isinstance(value, py3compat.string_types):
934 931 try:
935 932 value = self._resolve_string(value)
936 933 except ImportError:
937 934 raise TraitError("The '%s' trait of %s instance must be a type, but "
938 935 "%r could not be imported" % (self.name, obj, value))
939 936 try:
940 937 if issubclass(value, self.klass):
941 938 return value
942 939 except:
943 940 pass
944 941
945 942 self.error(obj, value)
946 943
947 944 def info(self):
948 945 """ Returns a description of the trait."""
949 946 if isinstance(self.klass, py3compat.string_types):
950 947 klass = self.klass
951 948 else:
952 949 klass = self.klass.__name__
953 950 result = 'a subclass of ' + klass
954 951 if self.allow_none:
955 952 return result + ' or None'
956 953 return result
957 954
958 955 def instance_init(self):
959 956 self._resolve_classes()
960 957 super(Type, self).instance_init()
961 958
962 959 def _resolve_classes(self):
963 960 if isinstance(self.klass, py3compat.string_types):
964 961 self.klass = self._resolve_string(self.klass)
965 962 if isinstance(self.default_value, py3compat.string_types):
966 963 self.default_value = self._resolve_string(self.default_value)
967 964
968 965 def get_default_value(self):
969 966 return self.default_value
970 967
971 968
972 969 class DefaultValueGenerator(object):
973 970 """A class for generating new default value instances."""
974 971
975 972 def __init__(self, *args, **kw):
976 973 self.args = args
977 974 self.kw = kw
978 975
979 976 def generate(self, klass):
980 977 return klass(*self.args, **self.kw)
981 978
982 979
983 980 class Instance(ClassBasedTraitType):
984 981 """A trait whose value must be an instance of a specified class.
985 982
986 983 The value can also be an instance of a subclass of the specified class.
987 984
988 985 Subclasses can declare default classes by overriding the klass attribute
989 986 """
990 987
991 988 klass = None
992 989
993 990 def __init__(self, klass=None, args=None, kw=None, allow_none=False,
994 991 **metadata ):
995 992 """Construct an Instance trait.
996 993
997 994 This trait allows values that are instances of a particular
998 995 class or its subclasses. Our implementation is quite different
999 996 from that of enthough.traits as we don't allow instances to be used
1000 997 for klass and we handle the ``args`` and ``kw`` arguments differently.
1001 998
1002 999 Parameters
1003 1000 ----------
1004 1001 klass : class, str
1005 1002 The class that forms the basis for the trait. Class names
1006 1003 can also be specified as strings, like 'foo.bar.Bar'.
1007 1004 args : tuple
1008 1005 Positional arguments for generating the default value.
1009 1006 kw : dict
1010 1007 Keyword arguments for generating the default value.
1011 1008 allow_none : bool [default True]
1012 1009 Indicates whether None is allowed as a value.
1013 1010
1014 1011 Notes
1015 1012 -----
1016 1013 If both ``args`` and ``kw`` are None, then the default value is None.
1017 1014 If ``args`` is a tuple and ``kw`` is a dict, then the default is
1018 1015 created as ``klass(*args, **kw)``. If exactly one of ``args`` or ``kw`` is
1019 1016 None, the None is replaced by ``()`` or ``{}``, respectively.
1020 1017 """
1021 1018 if klass is None:
1022 1019 klass = self.klass
1023 1020
1024 1021 if (klass is not None) and (inspect.isclass(klass) or isinstance(klass, py3compat.string_types)):
1025 1022 self.klass = klass
1026 1023 else:
1027 1024 raise TraitError('The klass attribute must be a class'
1028 1025 ' not: %r' % klass)
1029 1026
1030 1027 # self.klass is a class, so handle default_value
1031 1028 if args is None and kw is None:
1032 1029 default_value = None
1033 1030 else:
1034 1031 if args is None:
1035 1032 # kw is not None
1036 1033 args = ()
1037 1034 elif kw is None:
1038 1035 # args is not None
1039 1036 kw = {}
1040 1037
1041 1038 if not isinstance(kw, dict):
1042 1039 raise TraitError("The 'kw' argument must be a dict or None.")
1043 1040 if not isinstance(args, tuple):
1044 1041 raise TraitError("The 'args' argument must be a tuple or None.")
1045 1042
1046 1043 default_value = DefaultValueGenerator(*args, **kw)
1047 1044
1048 1045 super(Instance, self).__init__(default_value, allow_none=allow_none, **metadata)
1049 1046
1050 1047 def validate(self, obj, value):
1051 1048 if isinstance(value, self.klass):
1052 1049 return value
1053 1050 else:
1054 1051 self.error(obj, value)
1055 1052
1056 1053 def info(self):
1057 1054 if isinstance(self.klass, py3compat.string_types):
1058 1055 klass = self.klass
1059 1056 else:
1060 1057 klass = self.klass.__name__
1061 1058 result = class_of(klass)
1062 1059 if self.allow_none:
1063 1060 return result + ' or None'
1064 1061
1065 1062 return result
1066 1063
1067 1064 def instance_init(self):
1068 1065 self._resolve_classes()
1069 1066 super(Instance, self).instance_init()
1070 1067
1071 1068 def _resolve_classes(self):
1072 1069 if isinstance(self.klass, py3compat.string_types):
1073 1070 self.klass = self._resolve_string(self.klass)
1074 1071
1075 1072 def get_default_value(self):
1076 1073 """Instantiate a default value instance.
1077 1074
1078 1075 This is called when the containing HasTraits classes'
1079 1076 :meth:`__new__` method is called to ensure that a unique instance
1080 1077 is created for each HasTraits instance.
1081 1078 """
1082 1079 dv = self.default_value
1083 1080 if isinstance(dv, DefaultValueGenerator):
1084 1081 return dv.generate(self.klass)
1085 1082 else:
1086 1083 return dv
1087 1084
1088 1085
1089 1086 class ForwardDeclaredMixin(object):
1090 1087 """
1091 1088 Mixin for forward-declared versions of Instance and Type.
1092 1089 """
1093 1090 def _resolve_string(self, string):
1094 1091 """
1095 1092 Find the specified class name by looking for it in the module in which
1096 1093 our this_class attribute was defined.
1097 1094 """
1098 1095 modname = self.this_class.__module__
1099 1096 return import_item('.'.join([modname, string]))
1100 1097
1101 1098
1102 1099 class ForwardDeclaredType(ForwardDeclaredMixin, Type):
1103 1100 """
1104 1101 Forward-declared version of Type.
1105 1102 """
1106 1103 pass
1107 1104
1108 1105
1109 1106 class ForwardDeclaredInstance(ForwardDeclaredMixin, Instance):
1110 1107 """
1111 1108 Forward-declared version of Instance.
1112 1109 """
1113 1110 pass
1114 1111
1115 1112
1116 1113 class This(ClassBasedTraitType):
1117 1114 """A trait for instances of the class containing this trait.
1118 1115
1119 1116 Because how how and when class bodies are executed, the ``This``
1120 1117 trait can only have a default value of None. This, and because we
1121 1118 always validate default values, ``allow_none`` is *always* true.
1122 1119 """
1123 1120
1124 1121 info_text = 'an instance of the same type as the receiver or None'
1125 1122
1126 1123 def __init__(self, **metadata):
1127 1124 super(This, self).__init__(None, **metadata)
1128 1125
1129 1126 def validate(self, obj, value):
1130 1127 # What if value is a superclass of obj.__class__? This is
1131 1128 # complicated if it was the superclass that defined the This
1132 1129 # trait.
1133 1130 if isinstance(value, self.this_class) or (value is None):
1134 1131 return value
1135 1132 else:
1136 1133 self.error(obj, value)
1137 1134
1138 1135
1139 1136 class Union(TraitType):
1140 1137 """A trait type representing a Union type."""
1141 1138
1142 1139 def __init__(self, trait_types, **metadata):
1143 1140 """Construct a Union trait.
1144 1141
1145 1142 This trait allows values that are allowed by at least one of the
1146 1143 specified trait types. A Union traitlet cannot have metadata on
1147 1144 its own, besides the metadata of the listed types.
1148 1145
1149 1146 Parameters
1150 1147 ----------
1151 1148 trait_types: sequence
1152 1149 The list of trait types of length at least 1.
1153 1150
1154 1151 Notes
1155 1152 -----
1156 1153 Union([Float(), Bool(), Int()]) attempts to validate the provided values
1157 1154 with the validation function of Float, then Bool, and finally Int.
1158 1155 """
1159 1156 self.trait_types = trait_types
1160 1157 self.info_text = " or ".join([tt.info_text for tt in self.trait_types])
1161 1158 self.default_value = self.trait_types[0].get_default_value()
1162 1159 super(Union, self).__init__(**metadata)
1163 1160
1164 1161 def instance_init(self):
1165 1162 for trait_type in self.trait_types:
1166 1163 trait_type.name = self.name
1167 1164 trait_type.this_class = self.this_class
1168 1165 trait_type.instance_init()
1169 1166 super(Union, self).instance_init()
1170 1167
1171 1168 def validate(self, obj, value):
1172 1169 for trait_type in self.trait_types:
1173 1170 try:
1174 1171 v = trait_type._validate(obj, value)
1175 1172 self._metadata = trait_type._metadata
1176 1173 return v
1177 1174 except TraitError:
1178 1175 continue
1179 1176 self.error(obj, value)
1180 1177
1181 1178 def __or__(self, other):
1182 1179 if isinstance(other, Union):
1183 1180 return Union(self.trait_types + other.trait_types)
1184 1181 else:
1185 1182 return Union(self.trait_types + [other])
1186 1183
1187 1184 #-----------------------------------------------------------------------------
1188 1185 # Basic TraitTypes implementations/subclasses
1189 1186 #-----------------------------------------------------------------------------
1190 1187
1191 1188
1192 1189 class Any(TraitType):
1193 1190 default_value = None
1194 1191 info_text = 'any value'
1195 1192
1196 1193
1197 1194 class Int(TraitType):
1198 1195 """An int trait."""
1199 1196
1200 1197 default_value = 0
1201 1198 info_text = 'an int'
1202 1199
1203 1200 def validate(self, obj, value):
1204 1201 if isinstance(value, int):
1205 1202 return value
1206 1203 self.error(obj, value)
1207 1204
1208 1205 class CInt(Int):
1209 1206 """A casting version of the int trait."""
1210 1207
1211 1208 def validate(self, obj, value):
1212 1209 try:
1213 1210 return int(value)
1214 1211 except:
1215 1212 self.error(obj, value)
1216 1213
1217 1214 if py3compat.PY3:
1218 1215 Long, CLong = Int, CInt
1219 1216 Integer = Int
1220 1217 else:
1221 1218 class Long(TraitType):
1222 1219 """A long integer trait."""
1223 1220
1224 1221 default_value = 0
1225 1222 info_text = 'a long'
1226 1223
1227 1224 def validate(self, obj, value):
1228 1225 if isinstance(value, long):
1229 1226 return value
1230 1227 if isinstance(value, int):
1231 1228 return long(value)
1232 1229 self.error(obj, value)
1233 1230
1234 1231
1235 1232 class CLong(Long):
1236 1233 """A casting version of the long integer trait."""
1237 1234
1238 1235 def validate(self, obj, value):
1239 1236 try:
1240 1237 return long(value)
1241 1238 except:
1242 1239 self.error(obj, value)
1243 1240
1244 1241 class Integer(TraitType):
1245 1242 """An integer trait.
1246 1243
1247 1244 Longs that are unnecessary (<= sys.maxint) are cast to ints."""
1248 1245
1249 1246 default_value = 0
1250 1247 info_text = 'an integer'
1251 1248
1252 1249 def validate(self, obj, value):
1253 1250 if isinstance(value, int):
1254 1251 return value
1255 1252 if isinstance(value, long):
1256 1253 # downcast longs that fit in int:
1257 1254 # note that int(n > sys.maxint) returns a long, so
1258 1255 # we don't need a condition on this cast
1259 1256 return int(value)
1260 1257 if sys.platform == "cli":
1261 1258 from System import Int64
1262 1259 if isinstance(value, Int64):
1263 1260 return int(value)
1264 1261 self.error(obj, value)
1265 1262
1266 1263
1267 1264 class Float(TraitType):
1268 1265 """A float trait."""
1269 1266
1270 1267 default_value = 0.0
1271 1268 info_text = 'a float'
1272 1269
1273 1270 def validate(self, obj, value):
1274 1271 if isinstance(value, float):
1275 1272 return value
1276 1273 if isinstance(value, int):
1277 1274 return float(value)
1278 1275 self.error(obj, value)
1279 1276
1280 1277
1281 1278 class CFloat(Float):
1282 1279 """A casting version of the float trait."""
1283 1280
1284 1281 def validate(self, obj, value):
1285 1282 try:
1286 1283 return float(value)
1287 1284 except:
1288 1285 self.error(obj, value)
1289 1286
1290 1287 class Complex(TraitType):
1291 1288 """A trait for complex numbers."""
1292 1289
1293 1290 default_value = 0.0 + 0.0j
1294 1291 info_text = 'a complex number'
1295 1292
1296 1293 def validate(self, obj, value):
1297 1294 if isinstance(value, complex):
1298 1295 return value
1299 1296 if isinstance(value, (float, int)):
1300 1297 return complex(value)
1301 1298 self.error(obj, value)
1302 1299
1303 1300
1304 1301 class CComplex(Complex):
1305 1302 """A casting version of the complex number trait."""
1306 1303
1307 1304 def validate (self, obj, value):
1308 1305 try:
1309 1306 return complex(value)
1310 1307 except:
1311 1308 self.error(obj, value)
1312 1309
1313 1310 # We should always be explicit about whether we're using bytes or unicode, both
1314 1311 # for Python 3 conversion and for reliable unicode behaviour on Python 2. So
1315 1312 # we don't have a Str type.
1316 1313 class Bytes(TraitType):
1317 1314 """A trait for byte strings."""
1318 1315
1319 1316 default_value = b''
1320 1317 info_text = 'a bytes object'
1321 1318
1322 1319 def validate(self, obj, value):
1323 1320 if isinstance(value, bytes):
1324 1321 return value
1325 1322 self.error(obj, value)
1326 1323
1327 1324
1328 1325 class CBytes(Bytes):
1329 1326 """A casting version of the byte string trait."""
1330 1327
1331 1328 def validate(self, obj, value):
1332 1329 try:
1333 1330 return bytes(value)
1334 1331 except:
1335 1332 self.error(obj, value)
1336 1333
1337 1334
1338 1335 class Unicode(TraitType):
1339 1336 """A trait for unicode strings."""
1340 1337
1341 1338 default_value = u''
1342 1339 info_text = 'a unicode string'
1343 1340
1344 1341 def validate(self, obj, value):
1345 1342 if isinstance(value, py3compat.unicode_type):
1346 1343 return value
1347 1344 if isinstance(value, bytes):
1348 1345 try:
1349 1346 return value.decode('ascii', 'strict')
1350 1347 except UnicodeDecodeError:
1351 1348 msg = "Could not decode {!r} for unicode trait '{}' of {} instance."
1352 1349 raise TraitError(msg.format(value, self.name, class_of(obj)))
1353 1350 self.error(obj, value)
1354 1351
1355 1352
1356 1353 class CUnicode(Unicode):
1357 1354 """A casting version of the unicode trait."""
1358 1355
1359 1356 def validate(self, obj, value):
1360 1357 try:
1361 1358 return py3compat.unicode_type(value)
1362 1359 except:
1363 1360 self.error(obj, value)
1364 1361
1365 1362
1366 1363 class ObjectName(TraitType):
1367 1364 """A string holding a valid object name in this version of Python.
1368 1365
1369 1366 This does not check that the name exists in any scope."""
1370 1367 info_text = "a valid object identifier in Python"
1371 1368
1372 1369 if py3compat.PY3:
1373 1370 # Python 3:
1374 1371 coerce_str = staticmethod(lambda _,s: s)
1375 1372
1376 1373 else:
1377 1374 # Python 2:
1378 1375 def coerce_str(self, obj, value):
1379 1376 "In Python 2, coerce ascii-only unicode to str"
1380 1377 if isinstance(value, unicode):
1381 1378 try:
1382 1379 return str(value)
1383 1380 except UnicodeEncodeError:
1384 1381 self.error(obj, value)
1385 1382 return value
1386 1383
1387 1384 def validate(self, obj, value):
1388 1385 value = self.coerce_str(obj, value)
1389 1386
1390 1387 if isinstance(value, string_types) and py3compat.isidentifier(value):
1391 1388 return value
1392 1389 self.error(obj, value)
1393 1390
1394 1391 class DottedObjectName(ObjectName):
1395 1392 """A string holding a valid dotted object name in Python, such as A.b3._c"""
1396 1393 def validate(self, obj, value):
1397 1394 value = self.coerce_str(obj, value)
1398 1395
1399 1396 if isinstance(value, string_types) and py3compat.isidentifier(value, dotted=True):
1400 1397 return value
1401 1398 self.error(obj, value)
1402 1399
1403 1400
1404 1401 class Bool(TraitType):
1405 1402 """A boolean (True, False) trait."""
1406 1403
1407 1404 default_value = False
1408 1405 info_text = 'a boolean'
1409 1406
1410 1407 def validate(self, obj, value):
1411 1408 if isinstance(value, bool):
1412 1409 return value
1413 1410 self.error(obj, value)
1414 1411
1415 1412
1416 1413 class CBool(Bool):
1417 1414 """A casting version of the boolean trait."""
1418 1415
1419 1416 def validate(self, obj, value):
1420 1417 try:
1421 1418 return bool(value)
1422 1419 except:
1423 1420 self.error(obj, value)
1424 1421
1425 1422
1426 1423 class Enum(TraitType):
1427 1424 """An enum that whose value must be in a given sequence."""
1428 1425
1429 1426 def __init__(self, values, default_value=None, **metadata):
1430 1427 self.values = values
1431 1428 super(Enum, self).__init__(default_value, **metadata)
1432 1429
1433 1430 def validate(self, obj, value):
1434 1431 if value in self.values:
1435 1432 return value
1436 1433 self.error(obj, value)
1437 1434
1438 1435 def info(self):
1439 1436 """ Returns a description of the trait."""
1440 1437 result = 'any of ' + repr(self.values)
1441 1438 if self.allow_none:
1442 1439 return result + ' or None'
1443 1440 return result
1444 1441
1445 1442 class CaselessStrEnum(Enum):
1446 1443 """An enum of strings that are caseless in validate."""
1447 1444
1448 1445 def validate(self, obj, value):
1449 1446 if not isinstance(value, py3compat.string_types):
1450 1447 self.error(obj, value)
1451 1448
1452 1449 for v in self.values:
1453 1450 if v.lower() == value.lower():
1454 1451 return v
1455 1452 self.error(obj, value)
1456 1453
1457 1454 class Container(Instance):
1458 1455 """An instance of a container (list, set, etc.)
1459 1456
1460 1457 To be subclassed by overriding klass.
1461 1458 """
1462 1459 klass = None
1463 1460 _cast_types = ()
1464 1461 _valid_defaults = SequenceTypes
1465 1462 _trait = None
1466 1463
1467 1464 def __init__(self, trait=None, default_value=None, allow_none=False,
1468 1465 **metadata):
1469 1466 """Create a container trait type from a list, set, or tuple.
1470 1467
1471 1468 The default value is created by doing ``List(default_value)``,
1472 1469 which creates a copy of the ``default_value``.
1473 1470
1474 1471 ``trait`` can be specified, which restricts the type of elements
1475 1472 in the container to that TraitType.
1476 1473
1477 1474 If only one arg is given and it is not a Trait, it is taken as
1478 1475 ``default_value``:
1479 1476
1480 1477 ``c = List([1,2,3])``
1481 1478
1482 1479 Parameters
1483 1480 ----------
1484 1481
1485 1482 trait : TraitType [ optional ]
1486 1483 the type for restricting the contents of the Container. If unspecified,
1487 1484 types are not checked.
1488 1485
1489 1486 default_value : SequenceType [ optional ]
1490 1487 The default value for the Trait. Must be list/tuple/set, and
1491 1488 will be cast to the container type.
1492 1489
1493 1490 allow_none : bool [ default False ]
1494 1491 Whether to allow the value to be None
1495 1492
1496 1493 **metadata : any
1497 1494 further keys for extensions to the Trait (e.g. config)
1498 1495
1499 1496 """
1500 1497 # allow List([values]):
1501 1498 if default_value is None and not is_trait(trait):
1502 1499 default_value = trait
1503 1500 trait = None
1504 1501
1505 1502 if default_value is None:
1506 1503 args = ()
1507 1504 elif isinstance(default_value, self._valid_defaults):
1508 1505 args = (default_value,)
1509 1506 else:
1510 1507 raise TypeError('default value of %s was %s' %(self.__class__.__name__, default_value))
1511 1508
1512 1509 if is_trait(trait):
1513 1510 self._trait = trait() if isinstance(trait, type) else trait
1514 1511 self._trait.name = 'element'
1515 1512 elif trait is not None:
1516 1513 raise TypeError("`trait` must be a Trait or None, got %s"%repr_type(trait))
1517 1514
1518 1515 super(Container,self).__init__(klass=self.klass, args=args,
1519 1516 allow_none=allow_none, **metadata)
1520 1517
1521 1518 def element_error(self, obj, element, validator):
1522 1519 e = "Element of the '%s' trait of %s instance must be %s, but a value of %s was specified." \
1523 1520 % (self.name, class_of(obj), validator.info(), repr_type(element))
1524 1521 raise TraitError(e)
1525 1522
1526 1523 def validate(self, obj, value):
1527 1524 if isinstance(value, self._cast_types):
1528 1525 value = self.klass(value)
1529 1526 value = super(Container, self).validate(obj, value)
1530 1527 if value is None:
1531 1528 return value
1532 1529
1533 1530 value = self.validate_elements(obj, value)
1534 1531
1535 1532 return value
1536 1533
1537 1534 def validate_elements(self, obj, value):
1538 1535 validated = []
1539 1536 if self._trait is None or isinstance(self._trait, Any):
1540 1537 return value
1541 1538 for v in value:
1542 1539 try:
1543 1540 v = self._trait._validate(obj, v)
1544 1541 except TraitError:
1545 1542 self.element_error(obj, v, self._trait)
1546 1543 else:
1547 1544 validated.append(v)
1548 1545 return self.klass(validated)
1549 1546
1550 1547 def instance_init(self):
1551 1548 if isinstance(self._trait, TraitType):
1552 1549 self._trait.this_class = self.this_class
1553 1550 self._trait.instance_init()
1554 1551 super(Container, self).instance_init()
1555 1552
1556 1553
1557 1554 class List(Container):
1558 1555 """An instance of a Python list."""
1559 1556 klass = list
1560 1557 _cast_types = (tuple,)
1561 1558
1562 1559 def __init__(self, trait=None, default_value=None, minlen=0, maxlen=sys.maxsize, **metadata):
1563 1560 """Create a List trait type from a list, set, or tuple.
1564 1561
1565 1562 The default value is created by doing ``List(default_value)``,
1566 1563 which creates a copy of the ``default_value``.
1567 1564
1568 1565 ``trait`` can be specified, which restricts the type of elements
1569 1566 in the container to that TraitType.
1570 1567
1571 1568 If only one arg is given and it is not a Trait, it is taken as
1572 1569 ``default_value``:
1573 1570
1574 1571 ``c = List([1,2,3])``
1575 1572
1576 1573 Parameters
1577 1574 ----------
1578 1575
1579 1576 trait : TraitType [ optional ]
1580 1577 the type for restricting the contents of the Container. If unspecified,
1581 1578 types are not checked.
1582 1579
1583 1580 default_value : SequenceType [ optional ]
1584 1581 The default value for the Trait. Must be list/tuple/set, and
1585 1582 will be cast to the container type.
1586 1583
1587 1584 minlen : Int [ default 0 ]
1588 1585 The minimum length of the input list
1589 1586
1590 1587 maxlen : Int [ default sys.maxsize ]
1591 1588 The maximum length of the input list
1592 1589
1593 1590 allow_none : bool [ default False ]
1594 1591 Whether to allow the value to be None
1595 1592
1596 1593 **metadata : any
1597 1594 further keys for extensions to the Trait (e.g. config)
1598 1595
1599 1596 """
1600 1597 self._minlen = minlen
1601 1598 self._maxlen = maxlen
1602 1599 super(List, self).__init__(trait=trait, default_value=default_value,
1603 1600 **metadata)
1604 1601
1605 1602 def length_error(self, obj, value):
1606 1603 e = "The '%s' trait of %s instance must be of length %i <= L <= %i, but a value of %s was specified." \
1607 1604 % (self.name, class_of(obj), self._minlen, self._maxlen, value)
1608 1605 raise TraitError(e)
1609 1606
1610 1607 def validate_elements(self, obj, value):
1611 1608 length = len(value)
1612 1609 if length < self._minlen or length > self._maxlen:
1613 1610 self.length_error(obj, value)
1614 1611
1615 1612 return super(List, self).validate_elements(obj, value)
1616 1613
1617 1614 def validate(self, obj, value):
1618 1615 value = super(List, self).validate(obj, value)
1619 1616 value = self.validate_elements(obj, value)
1620 1617 return value
1621 1618
1622 1619
1623 1620 class Set(List):
1624 1621 """An instance of a Python set."""
1625 1622 klass = set
1626 1623 _cast_types = (tuple, list)
1627 1624
1628 1625
1629 1626 class Tuple(Container):
1630 1627 """An instance of a Python tuple."""
1631 1628 klass = tuple
1632 1629 _cast_types = (list,)
1633 1630
1634 1631 def __init__(self, *traits, **metadata):
1635 1632 """Tuple(*traits, default_value=None, **medatata)
1636 1633
1637 1634 Create a tuple from a list, set, or tuple.
1638 1635
1639 1636 Create a fixed-type tuple with Traits:
1640 1637
1641 1638 ``t = Tuple(Int, Str, CStr)``
1642 1639
1643 1640 would be length 3, with Int,Str,CStr for each element.
1644 1641
1645 1642 If only one arg is given and it is not a Trait, it is taken as
1646 1643 default_value:
1647 1644
1648 1645 ``t = Tuple((1,2,3))``
1649 1646
1650 1647 Otherwise, ``default_value`` *must* be specified by keyword.
1651 1648
1652 1649 Parameters
1653 1650 ----------
1654 1651
1655 1652 *traits : TraitTypes [ optional ]
1656 1653 the types for restricting the contents of the Tuple. If unspecified,
1657 1654 types are not checked. If specified, then each positional argument
1658 1655 corresponds to an element of the tuple. Tuples defined with traits
1659 1656 are of fixed length.
1660 1657
1661 1658 default_value : SequenceType [ optional ]
1662 1659 The default value for the Tuple. Must be list/tuple/set, and
1663 1660 will be cast to a tuple. If `traits` are specified, the
1664 1661 `default_value` must conform to the shape and type they specify.
1665 1662
1666 1663 allow_none : bool [ default False ]
1667 1664 Whether to allow the value to be None
1668 1665
1669 1666 **metadata : any
1670 1667 further keys for extensions to the Trait (e.g. config)
1671 1668
1672 1669 """
1673 1670 default_value = metadata.pop('default_value', None)
1674 1671 allow_none = metadata.pop('allow_none', True)
1675 1672
1676 1673 # allow Tuple((values,)):
1677 1674 if len(traits) == 1 and default_value is None and not is_trait(traits[0]):
1678 1675 default_value = traits[0]
1679 1676 traits = ()
1680 1677
1681 1678 if default_value is None:
1682 1679 args = ()
1683 1680 elif isinstance(default_value, self._valid_defaults):
1684 1681 args = (default_value,)
1685 1682 else:
1686 1683 raise TypeError('default value of %s was %s' %(self.__class__.__name__, default_value))
1687 1684
1688 1685 self._traits = []
1689 1686 for trait in traits:
1690 1687 t = trait() if isinstance(trait, type) else trait
1691 1688 t.name = 'element'
1692 1689 self._traits.append(t)
1693 1690
1694 1691 if self._traits and default_value is None:
1695 1692 # don't allow default to be an empty container if length is specified
1696 1693 args = None
1697 1694 super(Container,self).__init__(klass=self.klass, args=args, allow_none=allow_none, **metadata)
1698 1695
1699 1696 def validate_elements(self, obj, value):
1700 1697 if not self._traits:
1701 1698 # nothing to validate
1702 1699 return value
1703 1700 if len(value) != len(self._traits):
1704 1701 e = "The '%s' trait of %s instance requires %i elements, but a value of %s was specified." \
1705 1702 % (self.name, class_of(obj), len(self._traits), repr_type(value))
1706 1703 raise TraitError(e)
1707 1704
1708 1705 validated = []
1709 1706 for t, v in zip(self._traits, value):
1710 1707 try:
1711 1708 v = t._validate(obj, v)
1712 1709 except TraitError:
1713 1710 self.element_error(obj, v, t)
1714 1711 else:
1715 1712 validated.append(v)
1716 1713 return tuple(validated)
1717 1714
1718 1715 def instance_init(self):
1719 1716 for trait in self._traits:
1720 1717 if isinstance(trait, TraitType):
1721 1718 trait.this_class = self.this_class
1722 1719 trait.instance_init()
1723 1720 super(Container, self).instance_init()
1724 1721
1725 1722
1726 1723 class Dict(Instance):
1727 1724 """An instance of a Python dict."""
1728 1725 _trait = None
1729 1726
1730 1727 def __init__(self, trait=None, default_value=NoDefaultSpecified, allow_none=False, **metadata):
1731 1728 """Create a dict trait type from a dict.
1732 1729
1733 1730 The default value is created by doing ``dict(default_value)``,
1734 1731 which creates a copy of the ``default_value``.
1735 1732
1736 1733 trait : TraitType [ optional ]
1737 1734 the type for restricting the contents of the Container. If unspecified,
1738 1735 types are not checked.
1739 1736
1740 1737 default_value : SequenceType [ optional ]
1741 1738 The default value for the Dict. Must be dict, tuple, or None, and
1742 1739 will be cast to a dict if not None. If `trait` is specified, the
1743 1740 `default_value` must conform to the constraints it specifies.
1744 1741
1745 1742 allow_none : bool [ default False ]
1746 1743 Whether to allow the value to be None
1747 1744
1748 1745 """
1749 1746 if default_value is NoDefaultSpecified and trait is not None:
1750 1747 if not is_trait(trait):
1751 1748 default_value = trait
1752 1749 trait = None
1753 1750 if default_value is NoDefaultSpecified:
1754 1751 default_value = {}
1755 1752 if default_value is None:
1756 1753 args = None
1757 1754 elif isinstance(default_value, dict):
1758 1755 args = (default_value,)
1759 1756 elif isinstance(default_value, SequenceTypes):
1760 1757 args = (default_value,)
1761 1758 else:
1762 1759 raise TypeError('default value of Dict was %s' % default_value)
1763 1760
1764 1761 if is_trait(trait):
1765 1762 self._trait = trait() if isinstance(trait, type) else trait
1766 1763 self._trait.name = 'element'
1767 1764 elif trait is not None:
1768 1765 raise TypeError("`trait` must be a Trait or None, got %s"%repr_type(trait))
1769 1766
1770 1767 super(Dict,self).__init__(klass=dict, args=args,
1771 1768 allow_none=allow_none, **metadata)
1772 1769
1773 1770 def element_error(self, obj, element, validator):
1774 1771 e = "Element of the '%s' trait of %s instance must be %s, but a value of %s was specified." \
1775 1772 % (self.name, class_of(obj), validator.info(), repr_type(element))
1776 1773 raise TraitError(e)
1777 1774
1778 1775 def validate(self, obj, value):
1779 1776 value = super(Dict, self).validate(obj, value)
1780 1777 if value is None:
1781 1778 return value
1782 1779 value = self.validate_elements(obj, value)
1783 1780 return value
1784 1781
1785 1782 def validate_elements(self, obj, value):
1786 1783 if self._trait is None or isinstance(self._trait, Any):
1787 1784 return value
1788 1785 validated = {}
1789 1786 for key in value:
1790 1787 v = value[key]
1791 1788 try:
1792 1789 v = self._trait._validate(obj, v)
1793 1790 except TraitError:
1794 1791 self.element_error(obj, v, self._trait)
1795 1792 else:
1796 1793 validated[key] = v
1797 1794 return self.klass(validated)
1798 1795
1799 1796 def instance_init(self):
1800 1797 if isinstance(self._trait, TraitType):
1801 1798 self._trait.this_class = self.this_class
1802 1799 self._trait.instance_init()
1803 1800 super(Dict, self).instance_init()
1804 1801
1805 1802
1806 1803 class EventfulDict(Instance):
1807 1804 """An instance of an EventfulDict."""
1808 1805
1809 1806 def __init__(self, default_value={}, allow_none=False, **metadata):
1810 1807 """Create a EventfulDict trait type from a dict.
1811 1808
1812 1809 The default value is created by doing
1813 1810 ``eventful.EvenfulDict(default_value)``, which creates a copy of the
1814 1811 ``default_value``.
1815 1812 """
1816 1813 if default_value is None:
1817 1814 args = None
1818 1815 elif isinstance(default_value, dict):
1819 1816 args = (default_value,)
1820 1817 elif isinstance(default_value, SequenceTypes):
1821 1818 args = (default_value,)
1822 1819 else:
1823 1820 raise TypeError('default value of EventfulDict was %s' % default_value)
1824 1821
1825 1822 super(EventfulDict, self).__init__(klass=eventful.EventfulDict, args=args,
1826 1823 allow_none=allow_none, **metadata)
1827 1824
1828 1825
1829 1826 class EventfulList(Instance):
1830 1827 """An instance of an EventfulList."""
1831 1828
1832 1829 def __init__(self, default_value=None, allow_none=False, **metadata):
1833 1830 """Create a EventfulList trait type from a dict.
1834 1831
1835 1832 The default value is created by doing
1836 1833 ``eventful.EvenfulList(default_value)``, which creates a copy of the
1837 1834 ``default_value``.
1838 1835 """
1839 1836 if default_value is None:
1840 1837 args = ((),)
1841 1838 else:
1842 1839 args = (default_value,)
1843 1840
1844 1841 super(EventfulList, self).__init__(klass=eventful.EventfulList, args=args,
1845 1842 allow_none=allow_none, **metadata)
1846 1843
1847 1844
1848 1845 class TCPAddress(TraitType):
1849 1846 """A trait for an (ip, port) tuple.
1850 1847
1851 1848 This allows for both IPv4 IP addresses as well as hostnames.
1852 1849 """
1853 1850
1854 1851 default_value = ('127.0.0.1', 0)
1855 1852 info_text = 'an (ip, port) tuple'
1856 1853
1857 1854 def validate(self, obj, value):
1858 1855 if isinstance(value, tuple):
1859 1856 if len(value) == 2:
1860 1857 if isinstance(value[0], py3compat.string_types) and isinstance(value[1], int):
1861 1858 port = value[1]
1862 1859 if port >= 0 and port <= 65535:
1863 1860 return value
1864 1861 self.error(obj, value)
1865 1862
1866 1863 class CRegExp(TraitType):
1867 1864 """A casting compiled regular expression trait.
1868 1865
1869 1866 Accepts both strings and compiled regular expressions. The resulting
1870 1867 attribute will be a compiled regular expression."""
1871 1868
1872 1869 info_text = 'a regular expression'
1873 1870
1874 1871 def validate(self, obj, value):
1875 1872 try:
1876 1873 return re.compile(value)
1877 1874 except:
1878 1875 self.error(obj, value)
General Comments 0
You need to be logged in to leave comments. Login now