##// END OF EJS Templates
remove now-obsolete use of skip_doctest outside core
Min RK -
Show More
@@ -1,707 +1,702 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2 """
2 """
3 ======
3 ======
4 Rmagic
4 Rmagic
5 ======
5 ======
6
6
7 Magic command interface for interactive work with R via rpy2
7 Magic command interface for interactive work with R via rpy2
8
8
9 .. note::
9 .. note::
10
10
11 The ``rpy2`` package needs to be installed separately. It
11 The ``rpy2`` package needs to be installed separately. It
12 can be obtained using ``easy_install`` or ``pip``.
12 can be obtained using ``easy_install`` or ``pip``.
13
13
14 You will also need a working copy of R.
14 You will also need a working copy of R.
15
15
16 Usage
16 Usage
17 =====
17 =====
18
18
19 To enable the magics below, execute ``%load_ext rmagic``.
19 To enable the magics below, execute ``%load_ext rmagic``.
20
20
21 ``%R``
21 ``%R``
22
22
23 {R_DOC}
23 {R_DOC}
24
24
25 ``%Rpush``
25 ``%Rpush``
26
26
27 {RPUSH_DOC}
27 {RPUSH_DOC}
28
28
29 ``%Rpull``
29 ``%Rpull``
30
30
31 {RPULL_DOC}
31 {RPULL_DOC}
32
32
33 ``%Rget``
33 ``%Rget``
34
34
35 {RGET_DOC}
35 {RGET_DOC}
36
36
37 """
37 """
38 from __future__ import print_function
38 from __future__ import print_function
39
39
40 #-----------------------------------------------------------------------------
40 #-----------------------------------------------------------------------------
41 # Copyright (C) 2012 The IPython Development Team
41 # Copyright (C) 2012 The IPython Development Team
42 #
42 #
43 # Distributed under the terms of the BSD License. The full license is in
43 # Distributed under the terms of the BSD License. The full license is in
44 # the file COPYING, distributed as part of this software.
44 # the file COPYING, distributed as part of this software.
45 #-----------------------------------------------------------------------------
45 #-----------------------------------------------------------------------------
46
46
47 import sys
47 import sys
48 import tempfile
48 import tempfile
49 from glob import glob
49 from glob import glob
50 from shutil import rmtree
50 from shutil import rmtree
51 import warnings
51 import warnings
52
52
53 # numpy and rpy2 imports
53 # numpy and rpy2 imports
54
54
55 import numpy as np
55 import numpy as np
56
56
57 import rpy2.rinterface as ri
57 import rpy2.rinterface as ri
58 import rpy2.robjects as ro
58 import rpy2.robjects as ro
59 try:
59 try:
60 from rpy2.robjects import pandas2ri
60 from rpy2.robjects import pandas2ri
61 pandas2ri.activate()
61 pandas2ri.activate()
62 except ImportError:
62 except ImportError:
63 pandas2ri = None
63 pandas2ri = None
64 from rpy2.robjects import numpy2ri
64 from rpy2.robjects import numpy2ri
65 numpy2ri.activate()
65 numpy2ri.activate()
66
66
67 # IPython imports
67 # IPython imports
68
68
69 from IPython.core.displaypub import publish_display_data
69 from IPython.core.displaypub import publish_display_data
70 from IPython.core.magic import (Magics, magics_class, line_magic,
70 from IPython.core.magic import (Magics, magics_class, line_magic,
71 line_cell_magic, needs_local_scope)
71 line_cell_magic, needs_local_scope)
72 from IPython.testing.skipdoctest import skip_doctest
73 from IPython.core.magic_arguments import (
72 from IPython.core.magic_arguments import (
74 argument, magic_arguments, parse_argstring
73 argument, magic_arguments, parse_argstring
75 )
74 )
76 from simplegeneric import generic
75 from simplegeneric import generic
77 from IPython.utils.py3compat import (str_to_unicode, unicode_to_str, PY3,
76 from IPython.utils.py3compat import (str_to_unicode, unicode_to_str, PY3,
78 unicode_type)
77 unicode_type)
79 from IPython.utils.text import dedent
78 from IPython.utils.text import dedent
80
79
81 class RInterpreterError(ri.RRuntimeError):
80 class RInterpreterError(ri.RRuntimeError):
82 """An error when running R code in a %%R magic cell."""
81 """An error when running R code in a %%R magic cell."""
83 def __init__(self, line, err, stdout):
82 def __init__(self, line, err, stdout):
84 self.line = line
83 self.line = line
85 self.err = err.rstrip()
84 self.err = err.rstrip()
86 self.stdout = stdout.rstrip()
85 self.stdout = stdout.rstrip()
87
86
88 def __unicode__(self):
87 def __unicode__(self):
89 s = 'Failed to parse and evaluate line %r.\nR error message: %r' % \
88 s = 'Failed to parse and evaluate line %r.\nR error message: %r' % \
90 (self.line, self.err)
89 (self.line, self.err)
91 if self.stdout and (self.stdout != self.err):
90 if self.stdout and (self.stdout != self.err):
92 s += '\nR stdout:\n' + self.stdout
91 s += '\nR stdout:\n' + self.stdout
93 return s
92 return s
94
93
95 if PY3:
94 if PY3:
96 __str__ = __unicode__
95 __str__ = __unicode__
97 else:
96 else:
98 def __str__(self):
97 def __str__(self):
99 return unicode_to_str(unicode(self), 'utf-8')
98 return unicode_to_str(unicode(self), 'utf-8')
100
99
101 def Rconverter(Robj, dataframe=False):
100 def Rconverter(Robj, dataframe=False):
102 """
101 """
103 Convert an object in R's namespace to one suitable
102 Convert an object in R's namespace to one suitable
104 for ipython's namespace.
103 for ipython's namespace.
105
104
106 For a data.frame, it tries to return a structured array.
105 For a data.frame, it tries to return a structured array.
107 It first checks for colnames, then names.
106 It first checks for colnames, then names.
108 If all are NULL, it returns np.asarray(Robj), else
107 If all are NULL, it returns np.asarray(Robj), else
109 it tries to construct a recarray
108 it tries to construct a recarray
110
109
111 Parameters
110 Parameters
112 ----------
111 ----------
113
112
114 Robj: an R object returned from rpy2
113 Robj: an R object returned from rpy2
115 """
114 """
116 is_data_frame = ro.r('is.data.frame')
115 is_data_frame = ro.r('is.data.frame')
117 colnames = ro.r('colnames')
116 colnames = ro.r('colnames')
118 rownames = ro.r('rownames') # with pandas, these could be used for the index
117 rownames = ro.r('rownames') # with pandas, these could be used for the index
119 names = ro.r('names')
118 names = ro.r('names')
120
119
121 if dataframe:
120 if dataframe:
122 as_data_frame = ro.r('as.data.frame')
121 as_data_frame = ro.r('as.data.frame')
123 cols = colnames(Robj)
122 cols = colnames(Robj)
124 _names = names(Robj)
123 _names = names(Robj)
125 if cols != ri.NULL:
124 if cols != ri.NULL:
126 Robj = as_data_frame(Robj)
125 Robj = as_data_frame(Robj)
127 names = tuple(np.array(cols))
126 names = tuple(np.array(cols))
128 elif _names != ri.NULL:
127 elif _names != ri.NULL:
129 names = tuple(np.array(_names))
128 names = tuple(np.array(_names))
130 else: # failed to find names
129 else: # failed to find names
131 return np.asarray(Robj)
130 return np.asarray(Robj)
132 Robj = np.rec.fromarrays(Robj, names = names)
131 Robj = np.rec.fromarrays(Robj, names = names)
133 return np.asarray(Robj)
132 return np.asarray(Robj)
134
133
135 @generic
134 @generic
136 def pyconverter(pyobj):
135 def pyconverter(pyobj):
137 """Convert Python objects to R objects. Add types using the decorator:
136 """Convert Python objects to R objects. Add types using the decorator:
138
137
139 @pyconverter.when_type
138 @pyconverter.when_type
140 """
139 """
141 return pyobj
140 return pyobj
142
141
143 # The default conversion for lists seems to make them a nested list. That has
142 # The default conversion for lists seems to make them a nested list. That has
144 # some advantages, but is rarely convenient, so for interactive use, we convert
143 # some advantages, but is rarely convenient, so for interactive use, we convert
145 # lists to a numpy array, which becomes an R vector.
144 # lists to a numpy array, which becomes an R vector.
146 @pyconverter.when_type(list)
145 @pyconverter.when_type(list)
147 def pyconverter_list(pyobj):
146 def pyconverter_list(pyobj):
148 return np.asarray(pyobj)
147 return np.asarray(pyobj)
149
148
150 if pandas2ri is None:
149 if pandas2ri is None:
151 # pandas2ri was new in rpy2 2.3.3, so for now we'll fallback to pandas'
150 # pandas2ri was new in rpy2 2.3.3, so for now we'll fallback to pandas'
152 # conversion function.
151 # conversion function.
153 try:
152 try:
154 from pandas import DataFrame
153 from pandas import DataFrame
155 from pandas.rpy.common import convert_to_r_dataframe
154 from pandas.rpy.common import convert_to_r_dataframe
156 @pyconverter.when_type(DataFrame)
155 @pyconverter.when_type(DataFrame)
157 def pyconverter_dataframe(pyobj):
156 def pyconverter_dataframe(pyobj):
158 return convert_to_r_dataframe(pyobj, strings_as_factors=True)
157 return convert_to_r_dataframe(pyobj, strings_as_factors=True)
159 except ImportError:
158 except ImportError:
160 pass
159 pass
161
160
162 @magics_class
161 @magics_class
163 class RMagics(Magics):
162 class RMagics(Magics):
164 """A set of magics useful for interactive work with R via rpy2.
163 """A set of magics useful for interactive work with R via rpy2.
165 """
164 """
166
165
167 def __init__(self, shell, Rconverter=Rconverter,
166 def __init__(self, shell, Rconverter=Rconverter,
168 pyconverter=pyconverter,
167 pyconverter=pyconverter,
169 cache_display_data=False):
168 cache_display_data=False):
170 """
169 """
171 Parameters
170 Parameters
172 ----------
171 ----------
173
172
174 shell : IPython shell
173 shell : IPython shell
175
174
176 Rconverter : callable
175 Rconverter : callable
177 To be called on values taken from R before putting them in the
176 To be called on values taken from R before putting them in the
178 IPython namespace.
177 IPython namespace.
179
178
180 pyconverter : callable
179 pyconverter : callable
181 To be called on values in ipython namespace before
180 To be called on values in ipython namespace before
182 assigning to variables in rpy2.
181 assigning to variables in rpy2.
183
182
184 cache_display_data : bool
183 cache_display_data : bool
185 If True, the published results of the final call to R are
184 If True, the published results of the final call to R are
186 cached in the variable 'display_cache'.
185 cached in the variable 'display_cache'.
187
186
188 """
187 """
189 super(RMagics, self).__init__(shell)
188 super(RMagics, self).__init__(shell)
190 self.cache_display_data = cache_display_data
189 self.cache_display_data = cache_display_data
191
190
192 self.r = ro.R()
191 self.r = ro.R()
193
192
194 self.Rstdout_cache = []
193 self.Rstdout_cache = []
195 self.pyconverter = pyconverter
194 self.pyconverter = pyconverter
196 self.Rconverter = Rconverter
195 self.Rconverter = Rconverter
197
196
198 def eval(self, line):
197 def eval(self, line):
199 '''
198 '''
200 Parse and evaluate a line of R code with rpy2.
199 Parse and evaluate a line of R code with rpy2.
201 Returns the output to R's stdout() connection,
200 Returns the output to R's stdout() connection,
202 the value generated by evaluating the code, and a
201 the value generated by evaluating the code, and a
203 boolean indicating whether the return value would be
202 boolean indicating whether the return value would be
204 visible if the line of code were evaluated in an R REPL.
203 visible if the line of code were evaluated in an R REPL.
205
204
206 R Code evaluation and visibility determination are
205 R Code evaluation and visibility determination are
207 done via an R call of the form withVisible({<code>})
206 done via an R call of the form withVisible({<code>})
208
207
209 '''
208 '''
210 old_writeconsole = ri.get_writeconsole()
209 old_writeconsole = ri.get_writeconsole()
211 ri.set_writeconsole(self.write_console)
210 ri.set_writeconsole(self.write_console)
212 try:
211 try:
213 res = ro.r("withVisible({%s\n})" % line)
212 res = ro.r("withVisible({%s\n})" % line)
214 value = res[0] #value (R object)
213 value = res[0] #value (R object)
215 visible = ro.conversion.ri2py(res[1])[0] #visible (boolean)
214 visible = ro.conversion.ri2py(res[1])[0] #visible (boolean)
216 except (ri.RRuntimeError, ValueError) as exception:
215 except (ri.RRuntimeError, ValueError) as exception:
217 warning_or_other_msg = self.flush() # otherwise next return seems to have copy of error
216 warning_or_other_msg = self.flush() # otherwise next return seems to have copy of error
218 raise RInterpreterError(line, str_to_unicode(str(exception)), warning_or_other_msg)
217 raise RInterpreterError(line, str_to_unicode(str(exception)), warning_or_other_msg)
219 text_output = self.flush()
218 text_output = self.flush()
220 ri.set_writeconsole(old_writeconsole)
219 ri.set_writeconsole(old_writeconsole)
221 return text_output, value, visible
220 return text_output, value, visible
222
221
223 def write_console(self, output):
222 def write_console(self, output):
224 '''
223 '''
225 A hook to capture R's stdout in a cache.
224 A hook to capture R's stdout in a cache.
226 '''
225 '''
227 self.Rstdout_cache.append(output)
226 self.Rstdout_cache.append(output)
228
227
229 def flush(self):
228 def flush(self):
230 '''
229 '''
231 Flush R's stdout cache to a string, returning the string.
230 Flush R's stdout cache to a string, returning the string.
232 '''
231 '''
233 value = ''.join([str_to_unicode(s, 'utf-8') for s in self.Rstdout_cache])
232 value = ''.join([str_to_unicode(s, 'utf-8') for s in self.Rstdout_cache])
234 self.Rstdout_cache = []
233 self.Rstdout_cache = []
235 return value
234 return value
236
235
237 @skip_doctest
238 @needs_local_scope
236 @needs_local_scope
239 @line_magic
237 @line_magic
240 def Rpush(self, line, local_ns=None):
238 def Rpush(self, line, local_ns=None):
241 '''
239 '''
242 A line-level magic for R that pushes
240 A line-level magic for R that pushes
243 variables from python to rpy2. The line should be made up
241 variables from python to rpy2. The line should be made up
244 of whitespace separated variable names in the IPython
242 of whitespace separated variable names in the IPython
245 namespace::
243 namespace::
246
244
247 In [7]: import numpy as np
245 In [7]: import numpy as np
248
246
249 In [8]: X = np.array([4.5,6.3,7.9])
247 In [8]: X = np.array([4.5,6.3,7.9])
250
248
251 In [9]: X.mean()
249 In [9]: X.mean()
252 Out[9]: 6.2333333333333343
250 Out[9]: 6.2333333333333343
253
251
254 In [10]: %Rpush X
252 In [10]: %Rpush X
255
253
256 In [11]: %R mean(X)
254 In [11]: %R mean(X)
257 Out[11]: array([ 6.23333333])
255 Out[11]: array([ 6.23333333])
258
256
259 '''
257 '''
260 if local_ns is None:
258 if local_ns is None:
261 local_ns = {}
259 local_ns = {}
262
260
263 inputs = line.split(' ')
261 inputs = line.split(' ')
264 for input in inputs:
262 for input in inputs:
265 try:
263 try:
266 val = local_ns[input]
264 val = local_ns[input]
267 except KeyError:
265 except KeyError:
268 try:
266 try:
269 val = self.shell.user_ns[input]
267 val = self.shell.user_ns[input]
270 except KeyError:
268 except KeyError:
271 # reraise the KeyError as a NameError so that it looks like
269 # reraise the KeyError as a NameError so that it looks like
272 # the standard python behavior when you use an unnamed
270 # the standard python behavior when you use an unnamed
273 # variable
271 # variable
274 raise NameError("name '%s' is not defined" % input)
272 raise NameError("name '%s' is not defined" % input)
275
273
276 self.r.assign(input, self.pyconverter(val))
274 self.r.assign(input, self.pyconverter(val))
277
275
278 @skip_doctest
279 @magic_arguments()
276 @magic_arguments()
280 @argument(
277 @argument(
281 '-d', '--as_dataframe', action='store_true',
278 '-d', '--as_dataframe', action='store_true',
282 default=False,
279 default=False,
283 help='Convert objects to data.frames before returning to ipython.'
280 help='Convert objects to data.frames before returning to ipython.'
284 )
281 )
285 @argument(
282 @argument(
286 'outputs',
283 'outputs',
287 nargs='*',
284 nargs='*',
288 )
285 )
289 @line_magic
286 @line_magic
290 def Rpull(self, line):
287 def Rpull(self, line):
291 '''
288 '''
292 A line-level magic for R that pulls
289 A line-level magic for R that pulls
293 variables from python to rpy2::
290 variables from python to rpy2::
294
291
295 In [18]: _ = %R x = c(3,4,6.7); y = c(4,6,7); z = c('a',3,4)
292 In [18]: _ = %R x = c(3,4,6.7); y = c(4,6,7); z = c('a',3,4)
296
293
297 In [19]: %Rpull x y z
294 In [19]: %Rpull x y z
298
295
299 In [20]: x
296 In [20]: x
300 Out[20]: array([ 3. , 4. , 6.7])
297 Out[20]: array([ 3. , 4. , 6.7])
301
298
302 In [21]: y
299 In [21]: y
303 Out[21]: array([ 4., 6., 7.])
300 Out[21]: array([ 4., 6., 7.])
304
301
305 In [22]: z
302 In [22]: z
306 Out[22]:
303 Out[22]:
307 array(['a', '3', '4'],
304 array(['a', '3', '4'],
308 dtype='|S1')
305 dtype='|S1')
309
306
310
307
311 If --as_dataframe, then each object is returned as a structured array
308 If --as_dataframe, then each object is returned as a structured array
312 after first passed through "as.data.frame" in R before
309 after first passed through "as.data.frame" in R before
313 being calling self.Rconverter.
310 being calling self.Rconverter.
314 This is useful when a structured array is desired as output, or
311 This is useful when a structured array is desired as output, or
315 when the object in R has mixed data types.
312 when the object in R has mixed data types.
316 See the %%R docstring for more examples.
313 See the %%R docstring for more examples.
317
314
318 Notes
315 Notes
319 -----
316 -----
320
317
321 Beware that R names can have '.' so this is not fool proof.
318 Beware that R names can have '.' so this is not fool proof.
322 To avoid this, don't name your R objects with '.'s...
319 To avoid this, don't name your R objects with '.'s...
323
320
324 '''
321 '''
325 args = parse_argstring(self.Rpull, line)
322 args = parse_argstring(self.Rpull, line)
326 outputs = args.outputs
323 outputs = args.outputs
327 for output in outputs:
324 for output in outputs:
328 self.shell.push({output:self.Rconverter(self.r(output),dataframe=args.as_dataframe)})
325 self.shell.push({output:self.Rconverter(self.r(output),dataframe=args.as_dataframe)})
329
326
330 @skip_doctest
331 @magic_arguments()
327 @magic_arguments()
332 @argument(
328 @argument(
333 '-d', '--as_dataframe', action='store_true',
329 '-d', '--as_dataframe', action='store_true',
334 default=False,
330 default=False,
335 help='Convert objects to data.frames before returning to ipython.'
331 help='Convert objects to data.frames before returning to ipython.'
336 )
332 )
337 @argument(
333 @argument(
338 'output',
334 'output',
339 nargs=1,
335 nargs=1,
340 type=str,
336 type=str,
341 )
337 )
342 @line_magic
338 @line_magic
343 def Rget(self, line):
339 def Rget(self, line):
344 '''
340 '''
345 Return an object from rpy2, possibly as a structured array (if possible).
341 Return an object from rpy2, possibly as a structured array (if possible).
346 Similar to Rpull except only one argument is accepted and the value is
342 Similar to Rpull except only one argument is accepted and the value is
347 returned rather than pushed to self.shell.user_ns::
343 returned rather than pushed to self.shell.user_ns::
348
344
349 In [3]: dtype=[('x', '<i4'), ('y', '<f8'), ('z', '|S1')]
345 In [3]: dtype=[('x', '<i4'), ('y', '<f8'), ('z', '|S1')]
350
346
351 In [4]: datapy = np.array([(1, 2.9, 'a'), (2, 3.5, 'b'), (3, 2.1, 'c'), (4, 5, 'e')], dtype=dtype)
347 In [4]: datapy = np.array([(1, 2.9, 'a'), (2, 3.5, 'b'), (3, 2.1, 'c'), (4, 5, 'e')], dtype=dtype)
352
348
353 In [5]: %R -i datapy
349 In [5]: %R -i datapy
354
350
355 In [6]: %Rget datapy
351 In [6]: %Rget datapy
356 Out[6]:
352 Out[6]:
357 array([['1', '2', '3', '4'],
353 array([['1', '2', '3', '4'],
358 ['2', '3', '2', '5'],
354 ['2', '3', '2', '5'],
359 ['a', 'b', 'c', 'e']],
355 ['a', 'b', 'c', 'e']],
360 dtype='|S1')
356 dtype='|S1')
361
357
362 In [7]: %Rget -d datapy
358 In [7]: %Rget -d datapy
363 Out[7]:
359 Out[7]:
364 array([(1, 2.9, 'a'), (2, 3.5, 'b'), (3, 2.1, 'c'), (4, 5.0, 'e')],
360 array([(1, 2.9, 'a'), (2, 3.5, 'b'), (3, 2.1, 'c'), (4, 5.0, 'e')],
365 dtype=[('x', '<i4'), ('y', '<f8'), ('z', '|S1')])
361 dtype=[('x', '<i4'), ('y', '<f8'), ('z', '|S1')])
366
362
367 '''
363 '''
368 args = parse_argstring(self.Rget, line)
364 args = parse_argstring(self.Rget, line)
369 output = args.output
365 output = args.output
370 return self.Rconverter(self.r(output[0]),dataframe=args.as_dataframe)
366 return self.Rconverter(self.r(output[0]),dataframe=args.as_dataframe)
371
367
372
368
373 @skip_doctest
374 @magic_arguments()
369 @magic_arguments()
375 @argument(
370 @argument(
376 '-i', '--input', action='append',
371 '-i', '--input', action='append',
377 help='Names of input variable from shell.user_ns to be assigned to R variables of the same names after calling self.pyconverter. Multiple names can be passed separated only by commas with no whitespace.'
372 help='Names of input variable from shell.user_ns to be assigned to R variables of the same names after calling self.pyconverter. Multiple names can be passed separated only by commas with no whitespace.'
378 )
373 )
379 @argument(
374 @argument(
380 '-o', '--output', action='append',
375 '-o', '--output', action='append',
381 help='Names of variables to be pushed from rpy2 to shell.user_ns after executing cell body and applying self.Rconverter. Multiple names can be passed separated only by commas with no whitespace.'
376 help='Names of variables to be pushed from rpy2 to shell.user_ns after executing cell body and applying self.Rconverter. Multiple names can be passed separated only by commas with no whitespace.'
382 )
377 )
383 @argument(
378 @argument(
384 '-w', '--width', type=int,
379 '-w', '--width', type=int,
385 help='Width of png plotting device sent as an argument to *png* in R.'
380 help='Width of png plotting device sent as an argument to *png* in R.'
386 )
381 )
387 @argument(
382 @argument(
388 '-h', '--height', type=int,
383 '-h', '--height', type=int,
389 help='Height of png plotting device sent as an argument to *png* in R.'
384 help='Height of png plotting device sent as an argument to *png* in R.'
390 )
385 )
391
386
392 @argument(
387 @argument(
393 '-d', '--dataframe', action='append',
388 '-d', '--dataframe', action='append',
394 help='Convert these objects to data.frames and return as structured arrays.'
389 help='Convert these objects to data.frames and return as structured arrays.'
395 )
390 )
396 @argument(
391 @argument(
397 '-u', '--units', type=unicode_type, choices=["px", "in", "cm", "mm"],
392 '-u', '--units', type=unicode_type, choices=["px", "in", "cm", "mm"],
398 help='Units of png plotting device sent as an argument to *png* in R. One of ["px", "in", "cm", "mm"].'
393 help='Units of png plotting device sent as an argument to *png* in R. One of ["px", "in", "cm", "mm"].'
399 )
394 )
400 @argument(
395 @argument(
401 '-r', '--res', type=int,
396 '-r', '--res', type=int,
402 help='Resolution of png plotting device sent as an argument to *png* in R. Defaults to 72 if *units* is one of ["in", "cm", "mm"].'
397 help='Resolution of png plotting device sent as an argument to *png* in R. Defaults to 72 if *units* is one of ["in", "cm", "mm"].'
403 )
398 )
404 @argument(
399 @argument(
405 '-p', '--pointsize', type=int,
400 '-p', '--pointsize', type=int,
406 help='Pointsize of png plotting device sent as an argument to *png* in R.'
401 help='Pointsize of png plotting device sent as an argument to *png* in R.'
407 )
402 )
408 @argument(
403 @argument(
409 '-b', '--bg',
404 '-b', '--bg',
410 help='Background of png plotting device sent as an argument to *png* in R.'
405 help='Background of png plotting device sent as an argument to *png* in R.'
411 )
406 )
412 @argument(
407 @argument(
413 '-n', '--noreturn',
408 '-n', '--noreturn',
414 help='Force the magic to not return anything.',
409 help='Force the magic to not return anything.',
415 action='store_true',
410 action='store_true',
416 default=False
411 default=False
417 )
412 )
418 @argument(
413 @argument(
419 'code',
414 'code',
420 nargs='*',
415 nargs='*',
421 )
416 )
422 @needs_local_scope
417 @needs_local_scope
423 @line_cell_magic
418 @line_cell_magic
424 def R(self, line, cell=None, local_ns=None):
419 def R(self, line, cell=None, local_ns=None):
425 '''
420 '''
426 Execute code in R, and pull some of the results back into the Python namespace.
421 Execute code in R, and pull some of the results back into the Python namespace.
427
422
428 In line mode, this will evaluate an expression and convert the returned value to a Python object.
423 In line mode, this will evaluate an expression and convert the returned value to a Python object.
429 The return value is determined by rpy2's behaviour of returning the result of evaluating the
424 The return value is determined by rpy2's behaviour of returning the result of evaluating the
430 final line.
425 final line.
431
426
432 Multiple R lines can be executed by joining them with semicolons::
427 Multiple R lines can be executed by joining them with semicolons::
433
428
434 In [9]: %R X=c(1,4,5,7); sd(X); mean(X)
429 In [9]: %R X=c(1,4,5,7); sd(X); mean(X)
435 Out[9]: array([ 4.25])
430 Out[9]: array([ 4.25])
436
431
437 In cell mode, this will run a block of R code. The resulting value
432 In cell mode, this will run a block of R code. The resulting value
438 is printed if it would printed be when evaluating the same code
433 is printed if it would printed be when evaluating the same code
439 within a standard R REPL.
434 within a standard R REPL.
440
435
441 Nothing is returned to python by default in cell mode::
436 Nothing is returned to python by default in cell mode::
442
437
443 In [10]: %%R
438 In [10]: %%R
444 ....: Y = c(2,4,3,9)
439 ....: Y = c(2,4,3,9)
445 ....: summary(lm(Y~X))
440 ....: summary(lm(Y~X))
446
441
447 Call:
442 Call:
448 lm(formula = Y ~ X)
443 lm(formula = Y ~ X)
449
444
450 Residuals:
445 Residuals:
451 1 2 3 4
446 1 2 3 4
452 0.88 -0.24 -2.28 1.64
447 0.88 -0.24 -2.28 1.64
453
448
454 Coefficients:
449 Coefficients:
455 Estimate Std. Error t value Pr(>|t|)
450 Estimate Std. Error t value Pr(>|t|)
456 (Intercept) 0.0800 2.3000 0.035 0.975
451 (Intercept) 0.0800 2.3000 0.035 0.975
457 X 1.0400 0.4822 2.157 0.164
452 X 1.0400 0.4822 2.157 0.164
458
453
459 Residual standard error: 2.088 on 2 degrees of freedom
454 Residual standard error: 2.088 on 2 degrees of freedom
460 Multiple R-squared: 0.6993,Adjusted R-squared: 0.549
455 Multiple R-squared: 0.6993,Adjusted R-squared: 0.549
461 F-statistic: 4.651 on 1 and 2 DF, p-value: 0.1638
456 F-statistic: 4.651 on 1 and 2 DF, p-value: 0.1638
462
457
463 In the notebook, plots are published as the output of the cell::
458 In the notebook, plots are published as the output of the cell::
464
459
465 %R plot(X, Y)
460 %R plot(X, Y)
466
461
467 will create a scatter plot of X bs Y.
462 will create a scatter plot of X bs Y.
468
463
469 If cell is not None and line has some R code, it is prepended to
464 If cell is not None and line has some R code, it is prepended to
470 the R code in cell.
465 the R code in cell.
471
466
472 Objects can be passed back and forth between rpy2 and python via the -i -o flags in line::
467 Objects can be passed back and forth between rpy2 and python via the -i -o flags in line::
473
468
474 In [14]: Z = np.array([1,4,5,10])
469 In [14]: Z = np.array([1,4,5,10])
475
470
476 In [15]: %R -i Z mean(Z)
471 In [15]: %R -i Z mean(Z)
477 Out[15]: array([ 5.])
472 Out[15]: array([ 5.])
478
473
479
474
480 In [16]: %R -o W W=Z*mean(Z)
475 In [16]: %R -o W W=Z*mean(Z)
481 Out[16]: array([ 5., 20., 25., 50.])
476 Out[16]: array([ 5., 20., 25., 50.])
482
477
483 In [17]: W
478 In [17]: W
484 Out[17]: array([ 5., 20., 25., 50.])
479 Out[17]: array([ 5., 20., 25., 50.])
485
480
486 The return value is determined by these rules:
481 The return value is determined by these rules:
487
482
488 * If the cell is not None, the magic returns None.
483 * If the cell is not None, the magic returns None.
489
484
490 * If the cell evaluates as False, the resulting value is returned
485 * If the cell evaluates as False, the resulting value is returned
491 unless the final line prints something to the console, in
486 unless the final line prints something to the console, in
492 which case None is returned.
487 which case None is returned.
493
488
494 * If the final line results in a NULL value when evaluated
489 * If the final line results in a NULL value when evaluated
495 by rpy2, then None is returned.
490 by rpy2, then None is returned.
496
491
497 * No attempt is made to convert the final value to a structured array.
492 * No attempt is made to convert the final value to a structured array.
498 Use the --dataframe flag or %Rget to push / return a structured array.
493 Use the --dataframe flag or %Rget to push / return a structured array.
499
494
500 * If the -n flag is present, there is no return value.
495 * If the -n flag is present, there is no return value.
501
496
502 * A trailing ';' will also result in no return value as the last
497 * A trailing ';' will also result in no return value as the last
503 value in the line is an empty string.
498 value in the line is an empty string.
504
499
505 The --dataframe argument will attempt to return structured arrays.
500 The --dataframe argument will attempt to return structured arrays.
506 This is useful for dataframes with
501 This is useful for dataframes with
507 mixed data types. Note also that for a data.frame,
502 mixed data types. Note also that for a data.frame,
508 if it is returned as an ndarray, it is transposed::
503 if it is returned as an ndarray, it is transposed::
509
504
510 In [18]: dtype=[('x', '<i4'), ('y', '<f8'), ('z', '|S1')]
505 In [18]: dtype=[('x', '<i4'), ('y', '<f8'), ('z', '|S1')]
511
506
512 In [19]: datapy = np.array([(1, 2.9, 'a'), (2, 3.5, 'b'), (3, 2.1, 'c'), (4, 5, 'e')], dtype=dtype)
507 In [19]: datapy = np.array([(1, 2.9, 'a'), (2, 3.5, 'b'), (3, 2.1, 'c'), (4, 5, 'e')], dtype=dtype)
513
508
514 In [20]: %%R -o datar
509 In [20]: %%R -o datar
515 datar = datapy
510 datar = datapy
516 ....:
511 ....:
517
512
518 In [21]: datar
513 In [21]: datar
519 Out[21]:
514 Out[21]:
520 array([['1', '2', '3', '4'],
515 array([['1', '2', '3', '4'],
521 ['2', '3', '2', '5'],
516 ['2', '3', '2', '5'],
522 ['a', 'b', 'c', 'e']],
517 ['a', 'b', 'c', 'e']],
523 dtype='|S1')
518 dtype='|S1')
524
519
525 In [22]: %%R -d datar
520 In [22]: %%R -d datar
526 datar = datapy
521 datar = datapy
527 ....:
522 ....:
528
523
529 In [23]: datar
524 In [23]: datar
530 Out[23]:
525 Out[23]:
531 array([(1, 2.9, 'a'), (2, 3.5, 'b'), (3, 2.1, 'c'), (4, 5.0, 'e')],
526 array([(1, 2.9, 'a'), (2, 3.5, 'b'), (3, 2.1, 'c'), (4, 5.0, 'e')],
532 dtype=[('x', '<i4'), ('y', '<f8'), ('z', '|S1')])
527 dtype=[('x', '<i4'), ('y', '<f8'), ('z', '|S1')])
533
528
534 The --dataframe argument first tries colnames, then names.
529 The --dataframe argument first tries colnames, then names.
535 If both are NULL, it returns an ndarray (i.e. unstructured)::
530 If both are NULL, it returns an ndarray (i.e. unstructured)::
536
531
537 In [1]: %R mydata=c(4,6,8.3); NULL
532 In [1]: %R mydata=c(4,6,8.3); NULL
538
533
539 In [2]: %R -d mydata
534 In [2]: %R -d mydata
540
535
541 In [3]: mydata
536 In [3]: mydata
542 Out[3]: array([ 4. , 6. , 8.3])
537 Out[3]: array([ 4. , 6. , 8.3])
543
538
544 In [4]: %R names(mydata) = c('a','b','c'); NULL
539 In [4]: %R names(mydata) = c('a','b','c'); NULL
545
540
546 In [5]: %R -d mydata
541 In [5]: %R -d mydata
547
542
548 In [6]: mydata
543 In [6]: mydata
549 Out[6]:
544 Out[6]:
550 array((4.0, 6.0, 8.3),
545 array((4.0, 6.0, 8.3),
551 dtype=[('a', '<f8'), ('b', '<f8'), ('c', '<f8')])
546 dtype=[('a', '<f8'), ('b', '<f8'), ('c', '<f8')])
552
547
553 In [7]: %R -o mydata
548 In [7]: %R -o mydata
554
549
555 In [8]: mydata
550 In [8]: mydata
556 Out[8]: array([ 4. , 6. , 8.3])
551 Out[8]: array([ 4. , 6. , 8.3])
557
552
558 '''
553 '''
559
554
560 args = parse_argstring(self.R, line)
555 args = parse_argstring(self.R, line)
561
556
562 # arguments 'code' in line are prepended to
557 # arguments 'code' in line are prepended to
563 # the cell lines
558 # the cell lines
564
559
565 if cell is None:
560 if cell is None:
566 code = ''
561 code = ''
567 return_output = True
562 return_output = True
568 line_mode = True
563 line_mode = True
569 else:
564 else:
570 code = cell
565 code = cell
571 return_output = False
566 return_output = False
572 line_mode = False
567 line_mode = False
573
568
574 code = ' '.join(args.code) + code
569 code = ' '.join(args.code) + code
575
570
576 # if there is no local namespace then default to an empty dict
571 # if there is no local namespace then default to an empty dict
577 if local_ns is None:
572 if local_ns is None:
578 local_ns = {}
573 local_ns = {}
579
574
580 if args.input:
575 if args.input:
581 for input in ','.join(args.input).split(','):
576 for input in ','.join(args.input).split(','):
582 try:
577 try:
583 val = local_ns[input]
578 val = local_ns[input]
584 except KeyError:
579 except KeyError:
585 try:
580 try:
586 val = self.shell.user_ns[input]
581 val = self.shell.user_ns[input]
587 except KeyError:
582 except KeyError:
588 raise NameError("name '%s' is not defined" % input)
583 raise NameError("name '%s' is not defined" % input)
589 self.r.assign(input, self.pyconverter(val))
584 self.r.assign(input, self.pyconverter(val))
590
585
591 if getattr(args, 'units') is not None:
586 if getattr(args, 'units') is not None:
592 if args.units != "px" and getattr(args, 'res') is None:
587 if args.units != "px" and getattr(args, 'res') is None:
593 args.res = 72
588 args.res = 72
594 args.units = '"%s"' % args.units
589 args.units = '"%s"' % args.units
595
590
596 png_argdict = dict([(n, getattr(args, n)) for n in ['units', 'res', 'height', 'width', 'bg', 'pointsize']])
591 png_argdict = dict([(n, getattr(args, n)) for n in ['units', 'res', 'height', 'width', 'bg', 'pointsize']])
597 png_args = ','.join(['%s=%s' % (o,v) for o, v in png_argdict.items() if v is not None])
592 png_args = ','.join(['%s=%s' % (o,v) for o, v in png_argdict.items() if v is not None])
598 # execute the R code in a temporary directory
593 # execute the R code in a temporary directory
599
594
600 tmpd = tempfile.mkdtemp()
595 tmpd = tempfile.mkdtemp()
601 self.r('png("%s/Rplots%%03d.png",%s)' % (tmpd.replace('\\', '/'), png_args))
596 self.r('png("%s/Rplots%%03d.png",%s)' % (tmpd.replace('\\', '/'), png_args))
602
597
603 text_output = ''
598 text_output = ''
604 try:
599 try:
605 if line_mode:
600 if line_mode:
606 for line in code.split(';'):
601 for line in code.split(';'):
607 text_result, result, visible = self.eval(line)
602 text_result, result, visible = self.eval(line)
608 text_output += text_result
603 text_output += text_result
609 if text_result:
604 if text_result:
610 # the last line printed something to the console so we won't return it
605 # the last line printed something to the console so we won't return it
611 return_output = False
606 return_output = False
612 else:
607 else:
613 text_result, result, visible = self.eval(code)
608 text_result, result, visible = self.eval(code)
614 text_output += text_result
609 text_output += text_result
615 if visible:
610 if visible:
616 old_writeconsole = ri.get_writeconsole()
611 old_writeconsole = ri.get_writeconsole()
617 ri.set_writeconsole(self.write_console)
612 ri.set_writeconsole(self.write_console)
618 ro.r.show(result)
613 ro.r.show(result)
619 text_output += self.flush()
614 text_output += self.flush()
620 ri.set_writeconsole(old_writeconsole)
615 ri.set_writeconsole(old_writeconsole)
621
616
622 except RInterpreterError as e:
617 except RInterpreterError as e:
623 print(e.stdout)
618 print(e.stdout)
624 if not e.stdout.endswith(e.err):
619 if not e.stdout.endswith(e.err):
625 print(e.err)
620 print(e.err)
626 rmtree(tmpd)
621 rmtree(tmpd)
627 return
622 return
628 finally:
623 finally:
629 self.r('dev.off()')
624 self.r('dev.off()')
630
625
631 # read out all the saved .png files
626 # read out all the saved .png files
632
627
633 images = [open(imgfile, 'rb').read() for imgfile in glob("%s/Rplots*png" % tmpd)]
628 images = [open(imgfile, 'rb').read() for imgfile in glob("%s/Rplots*png" % tmpd)]
634
629
635 # now publish the images
630 # now publish the images
636 # mimicking IPython/zmq/pylab/backend_inline.py
631 # mimicking IPython/zmq/pylab/backend_inline.py
637 fmt = 'png'
632 fmt = 'png'
638 mimetypes = { 'png' : 'image/png', 'svg' : 'image/svg+xml' }
633 mimetypes = { 'png' : 'image/png', 'svg' : 'image/svg+xml' }
639 mime = mimetypes[fmt]
634 mime = mimetypes[fmt]
640
635
641 # publish the printed R objects, if any
636 # publish the printed R objects, if any
642
637
643 display_data = []
638 display_data = []
644 if text_output:
639 if text_output:
645 display_data.append(('RMagic.R', {'text/plain':text_output}))
640 display_data.append(('RMagic.R', {'text/plain':text_output}))
646
641
647 # flush text streams before sending figures, helps a little with output
642 # flush text streams before sending figures, helps a little with output
648 for image in images:
643 for image in images:
649 # synchronization in the console (though it's a bandaid, not a real sln)
644 # synchronization in the console (though it's a bandaid, not a real sln)
650 sys.stdout.flush(); sys.stderr.flush()
645 sys.stdout.flush(); sys.stderr.flush()
651 display_data.append(('RMagic.R', {mime: image}))
646 display_data.append(('RMagic.R', {mime: image}))
652
647
653 # kill the temporary directory
648 # kill the temporary directory
654 rmtree(tmpd)
649 rmtree(tmpd)
655
650
656 # try to turn every output into a numpy array
651 # try to turn every output into a numpy array
657 # this means that output are assumed to be castable
652 # this means that output are assumed to be castable
658 # as numpy arrays
653 # as numpy arrays
659
654
660 if args.output:
655 if args.output:
661 for output in ','.join(args.output).split(','):
656 for output in ','.join(args.output).split(','):
662 self.shell.push({output:self.Rconverter(self.r(output), dataframe=False)})
657 self.shell.push({output:self.Rconverter(self.r(output), dataframe=False)})
663
658
664 if args.dataframe:
659 if args.dataframe:
665 for output in ','.join(args.dataframe).split(','):
660 for output in ','.join(args.dataframe).split(','):
666 self.shell.push({output:self.Rconverter(self.r(output), dataframe=True)})
661 self.shell.push({output:self.Rconverter(self.r(output), dataframe=True)})
667
662
668 for tag, disp_d in display_data:
663 for tag, disp_d in display_data:
669 publish_display_data(data=disp_d, source=tag)
664 publish_display_data(data=disp_d, source=tag)
670
665
671 # this will keep a reference to the display_data
666 # this will keep a reference to the display_data
672 # which might be useful to other objects who happen to use
667 # which might be useful to other objects who happen to use
673 # this method
668 # this method
674
669
675 if self.cache_display_data:
670 if self.cache_display_data:
676 self.display_cache = display_data
671 self.display_cache = display_data
677
672
678 # if in line mode and return_output, return the result as an ndarray
673 # if in line mode and return_output, return the result as an ndarray
679 if return_output and not args.noreturn:
674 if return_output and not args.noreturn:
680 if result != ri.NULL:
675 if result != ri.NULL:
681 return self.Rconverter(result, dataframe=False)
676 return self.Rconverter(result, dataframe=False)
682
677
683 __doc__ = __doc__.format(
678 __doc__ = __doc__.format(
684 R_DOC = dedent(RMagics.R.__doc__),
679 R_DOC = dedent(RMagics.R.__doc__),
685 RPUSH_DOC = dedent(RMagics.Rpush.__doc__),
680 RPUSH_DOC = dedent(RMagics.Rpush.__doc__),
686 RPULL_DOC = dedent(RMagics.Rpull.__doc__),
681 RPULL_DOC = dedent(RMagics.Rpull.__doc__),
687 RGET_DOC = dedent(RMagics.Rget.__doc__)
682 RGET_DOC = dedent(RMagics.Rget.__doc__)
688 )
683 )
689
684
690
685
691 def load_ipython_extension(ip):
686 def load_ipython_extension(ip):
692 """Load the extension in IPython."""
687 """Load the extension in IPython."""
693 warnings.warn("The rmagic extension in IPython is deprecated in favour of "
688 warnings.warn("The rmagic extension in IPython is deprecated in favour of "
694 "rpy2.ipython. If available, that will be loaded instead.\n"
689 "rpy2.ipython. If available, that will be loaded instead.\n"
695 "http://rpy.sourceforge.net/")
690 "http://rpy.sourceforge.net/")
696 try:
691 try:
697 import rpy2.ipython
692 import rpy2.ipython
698 except ImportError:
693 except ImportError:
699 pass # Fall back to our own implementation for now
694 pass # Fall back to our own implementation for now
700 else:
695 else:
701 return rpy2.ipython.load_ipython_extension(ip)
696 return rpy2.ipython.load_ipython_extension(ip)
702
697
703 ip.register_magics(RMagics)
698 ip.register_magics(RMagics)
704 # Initialising rpy2 interferes with readline. Since, at this point, we've
699 # Initialising rpy2 interferes with readline. Since, at this point, we've
705 # probably just loaded rpy2, we reset the delimiters. See issue gh-2759.
700 # probably just loaded rpy2, we reset the delimiters. See issue gh-2759.
706 if ip.has_readline:
701 if ip.has_readline:
707 ip.readline.set_completer_delims(ip.readline_delims)
702 ip.readline.set_completer_delims(ip.readline_delims)
@@ -1,243 +1,241 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2 """
2 """
3 %store magic for lightweight persistence.
3 %store magic for lightweight persistence.
4
4
5 Stores variables, aliases and macros in IPython's database.
5 Stores variables, aliases and macros in IPython's database.
6
6
7 To automatically restore stored variables at startup, add this to your
7 To automatically restore stored variables at startup, add this to your
8 :file:`ipython_config.py` file::
8 :file:`ipython_config.py` file::
9
9
10 c.StoreMagics.autorestore = True
10 c.StoreMagics.autorestore = True
11 """
11 """
12 from __future__ import print_function
12 from __future__ import print_function
13 #-----------------------------------------------------------------------------
13 #-----------------------------------------------------------------------------
14 # Copyright (c) 2012, The IPython Development Team.
14 # Copyright (c) 2012, The IPython Development Team.
15 #
15 #
16 # Distributed under the terms of the Modified BSD License.
16 # Distributed under the terms of the Modified BSD License.
17 #
17 #
18 # The full license is in the file COPYING.txt, distributed with this software.
18 # The full license is in the file COPYING.txt, distributed with this software.
19 #-----------------------------------------------------------------------------
19 #-----------------------------------------------------------------------------
20
20
21 #-----------------------------------------------------------------------------
21 #-----------------------------------------------------------------------------
22 # Imports
22 # Imports
23 #-----------------------------------------------------------------------------
23 #-----------------------------------------------------------------------------
24
24
25 # Stdlib
25 # Stdlib
26 import inspect, os, sys, textwrap
26 import inspect, os, sys, textwrap
27
27
28 # Our own
28 # Our own
29 from IPython.core.error import UsageError
29 from IPython.core.error import UsageError
30 from IPython.core.magic import Magics, magics_class, line_magic
30 from IPython.core.magic import Magics, magics_class, line_magic
31 from IPython.testing.skipdoctest import skip_doctest
32 from IPython.utils.traitlets import Bool
31 from IPython.utils.traitlets import Bool
33 from IPython.utils.py3compat import string_types
32 from IPython.utils.py3compat import string_types
34
33
35 #-----------------------------------------------------------------------------
34 #-----------------------------------------------------------------------------
36 # Functions and classes
35 # Functions and classes
37 #-----------------------------------------------------------------------------
36 #-----------------------------------------------------------------------------
38
37
39 def restore_aliases(ip):
38 def restore_aliases(ip):
40 staliases = ip.db.get('stored_aliases', {})
39 staliases = ip.db.get('stored_aliases', {})
41 for k,v in staliases.items():
40 for k,v in staliases.items():
42 #print "restore alias",k,v # dbg
41 #print "restore alias",k,v # dbg
43 #self.alias_table[k] = v
42 #self.alias_table[k] = v
44 ip.alias_manager.define_alias(k,v)
43 ip.alias_manager.define_alias(k,v)
45
44
46
45
47 def refresh_variables(ip):
46 def refresh_variables(ip):
48 db = ip.db
47 db = ip.db
49 for key in db.keys('autorestore/*'):
48 for key in db.keys('autorestore/*'):
50 # strip autorestore
49 # strip autorestore
51 justkey = os.path.basename(key)
50 justkey = os.path.basename(key)
52 try:
51 try:
53 obj = db[key]
52 obj = db[key]
54 except KeyError:
53 except KeyError:
55 print("Unable to restore variable '%s', ignoring (use %%store -d to forget!)" % justkey)
54 print("Unable to restore variable '%s', ignoring (use %%store -d to forget!)" % justkey)
56 print("The error was:", sys.exc_info()[0])
55 print("The error was:", sys.exc_info()[0])
57 else:
56 else:
58 #print "restored",justkey,"=",obj #dbg
57 #print "restored",justkey,"=",obj #dbg
59 ip.user_ns[justkey] = obj
58 ip.user_ns[justkey] = obj
60
59
61
60
62 def restore_dhist(ip):
61 def restore_dhist(ip):
63 ip.user_ns['_dh'] = ip.db.get('dhist',[])
62 ip.user_ns['_dh'] = ip.db.get('dhist',[])
64
63
65
64
66 def restore_data(ip):
65 def restore_data(ip):
67 refresh_variables(ip)
66 refresh_variables(ip)
68 restore_aliases(ip)
67 restore_aliases(ip)
69 restore_dhist(ip)
68 restore_dhist(ip)
70
69
71
70
72 @magics_class
71 @magics_class
73 class StoreMagics(Magics):
72 class StoreMagics(Magics):
74 """Lightweight persistence for python variables.
73 """Lightweight persistence for python variables.
75
74
76 Provides the %store magic."""
75 Provides the %store magic."""
77
76
78 autorestore = Bool(False, config=True, help=
77 autorestore = Bool(False, config=True, help=
79 """If True, any %store-d variables will be automatically restored
78 """If True, any %store-d variables will be automatically restored
80 when IPython starts.
79 when IPython starts.
81 """
80 """
82 )
81 )
83
82
84 def __init__(self, shell):
83 def __init__(self, shell):
85 super(StoreMagics, self).__init__(shell=shell)
84 super(StoreMagics, self).__init__(shell=shell)
86 self.shell.configurables.append(self)
85 self.shell.configurables.append(self)
87 if self.autorestore:
86 if self.autorestore:
88 restore_data(self.shell)
87 restore_data(self.shell)
89
88
90 @skip_doctest
91 @line_magic
89 @line_magic
92 def store(self, parameter_s=''):
90 def store(self, parameter_s=''):
93 """Lightweight persistence for python variables.
91 """Lightweight persistence for python variables.
94
92
95 Example::
93 Example::
96
94
97 In [1]: l = ['hello',10,'world']
95 In [1]: l = ['hello',10,'world']
98 In [2]: %store l
96 In [2]: %store l
99 In [3]: exit
97 In [3]: exit
100
98
101 (IPython session is closed and started again...)
99 (IPython session is closed and started again...)
102
100
103 ville@badger:~$ ipython
101 ville@badger:~$ ipython
104 In [1]: l
102 In [1]: l
105 NameError: name 'l' is not defined
103 NameError: name 'l' is not defined
106 In [2]: %store -r
104 In [2]: %store -r
107 In [3]: l
105 In [3]: l
108 Out[3]: ['hello', 10, 'world']
106 Out[3]: ['hello', 10, 'world']
109
107
110 Usage:
108 Usage:
111
109
112 * ``%store`` - Show list of all variables and their current
110 * ``%store`` - Show list of all variables and their current
113 values
111 values
114 * ``%store spam`` - Store the *current* value of the variable spam
112 * ``%store spam`` - Store the *current* value of the variable spam
115 to disk
113 to disk
116 * ``%store -d spam`` - Remove the variable and its value from storage
114 * ``%store -d spam`` - Remove the variable and its value from storage
117 * ``%store -z`` - Remove all variables from storage
115 * ``%store -z`` - Remove all variables from storage
118 * ``%store -r`` - Refresh all variables from store (overwrite
116 * ``%store -r`` - Refresh all variables from store (overwrite
119 current vals)
117 current vals)
120 * ``%store -r spam bar`` - Refresh specified variables from store
118 * ``%store -r spam bar`` - Refresh specified variables from store
121 (delete current val)
119 (delete current val)
122 * ``%store foo >a.txt`` - Store value of foo to new file a.txt
120 * ``%store foo >a.txt`` - Store value of foo to new file a.txt
123 * ``%store foo >>a.txt`` - Append value of foo to file a.txt
121 * ``%store foo >>a.txt`` - Append value of foo to file a.txt
124
122
125 It should be noted that if you change the value of a variable, you
123 It should be noted that if you change the value of a variable, you
126 need to %store it again if you want to persist the new value.
124 need to %store it again if you want to persist the new value.
127
125
128 Note also that the variables will need to be pickleable; most basic
126 Note also that the variables will need to be pickleable; most basic
129 python types can be safely %store'd.
127 python types can be safely %store'd.
130
128
131 Also aliases can be %store'd across sessions.
129 Also aliases can be %store'd across sessions.
132 """
130 """
133
131
134 opts,argsl = self.parse_options(parameter_s,'drz',mode='string')
132 opts,argsl = self.parse_options(parameter_s,'drz',mode='string')
135 args = argsl.split(None,1)
133 args = argsl.split(None,1)
136 ip = self.shell
134 ip = self.shell
137 db = ip.db
135 db = ip.db
138 # delete
136 # delete
139 if 'd' in opts:
137 if 'd' in opts:
140 try:
138 try:
141 todel = args[0]
139 todel = args[0]
142 except IndexError:
140 except IndexError:
143 raise UsageError('You must provide the variable to forget')
141 raise UsageError('You must provide the variable to forget')
144 else:
142 else:
145 try:
143 try:
146 del db['autorestore/' + todel]
144 del db['autorestore/' + todel]
147 except:
145 except:
148 raise UsageError("Can't delete variable '%s'" % todel)
146 raise UsageError("Can't delete variable '%s'" % todel)
149 # reset
147 # reset
150 elif 'z' in opts:
148 elif 'z' in opts:
151 for k in db.keys('autorestore/*'):
149 for k in db.keys('autorestore/*'):
152 del db[k]
150 del db[k]
153
151
154 elif 'r' in opts:
152 elif 'r' in opts:
155 if args:
153 if args:
156 for arg in args:
154 for arg in args:
157 try:
155 try:
158 obj = db['autorestore/' + arg]
156 obj = db['autorestore/' + arg]
159 except KeyError:
157 except KeyError:
160 print("no stored variable %s" % arg)
158 print("no stored variable %s" % arg)
161 else:
159 else:
162 ip.user_ns[arg] = obj
160 ip.user_ns[arg] = obj
163 else:
161 else:
164 restore_data(ip)
162 restore_data(ip)
165
163
166 # run without arguments -> list variables & values
164 # run without arguments -> list variables & values
167 elif not args:
165 elif not args:
168 vars = db.keys('autorestore/*')
166 vars = db.keys('autorestore/*')
169 vars.sort()
167 vars.sort()
170 if vars:
168 if vars:
171 size = max(map(len, vars))
169 size = max(map(len, vars))
172 else:
170 else:
173 size = 0
171 size = 0
174
172
175 print('Stored variables and their in-db values:')
173 print('Stored variables and their in-db values:')
176 fmt = '%-'+str(size)+'s -> %s'
174 fmt = '%-'+str(size)+'s -> %s'
177 get = db.get
175 get = db.get
178 for var in vars:
176 for var in vars:
179 justkey = os.path.basename(var)
177 justkey = os.path.basename(var)
180 # print 30 first characters from every var
178 # print 30 first characters from every var
181 print(fmt % (justkey, repr(get(var, '<unavailable>'))[:50]))
179 print(fmt % (justkey, repr(get(var, '<unavailable>'))[:50]))
182
180
183 # default action - store the variable
181 # default action - store the variable
184 else:
182 else:
185 # %store foo >file.txt or >>file.txt
183 # %store foo >file.txt or >>file.txt
186 if len(args) > 1 and args[1].startswith('>'):
184 if len(args) > 1 and args[1].startswith('>'):
187 fnam = os.path.expanduser(args[1].lstrip('>').lstrip())
185 fnam = os.path.expanduser(args[1].lstrip('>').lstrip())
188 if args[1].startswith('>>'):
186 if args[1].startswith('>>'):
189 fil = open(fnam, 'a')
187 fil = open(fnam, 'a')
190 else:
188 else:
191 fil = open(fnam, 'w')
189 fil = open(fnam, 'w')
192 obj = ip.ev(args[0])
190 obj = ip.ev(args[0])
193 print("Writing '%s' (%s) to file '%s'." % (args[0],
191 print("Writing '%s' (%s) to file '%s'." % (args[0],
194 obj.__class__.__name__, fnam))
192 obj.__class__.__name__, fnam))
195
193
196
194
197 if not isinstance (obj, string_types):
195 if not isinstance (obj, string_types):
198 from pprint import pprint
196 from pprint import pprint
199 pprint(obj, fil)
197 pprint(obj, fil)
200 else:
198 else:
201 fil.write(obj)
199 fil.write(obj)
202 if not obj.endswith('\n'):
200 if not obj.endswith('\n'):
203 fil.write('\n')
201 fil.write('\n')
204
202
205 fil.close()
203 fil.close()
206 return
204 return
207
205
208 # %store foo
206 # %store foo
209 try:
207 try:
210 obj = ip.user_ns[args[0]]
208 obj = ip.user_ns[args[0]]
211 except KeyError:
209 except KeyError:
212 # it might be an alias
210 # it might be an alias
213 name = args[0]
211 name = args[0]
214 try:
212 try:
215 cmd = ip.alias_manager.retrieve_alias(name)
213 cmd = ip.alias_manager.retrieve_alias(name)
216 except ValueError:
214 except ValueError:
217 raise UsageError("Unknown variable '%s'" % name)
215 raise UsageError("Unknown variable '%s'" % name)
218
216
219 staliases = db.get('stored_aliases',{})
217 staliases = db.get('stored_aliases',{})
220 staliases[name] = cmd
218 staliases[name] = cmd
221 db['stored_aliases'] = staliases
219 db['stored_aliases'] = staliases
222 print("Alias stored: %s (%s)" % (name, cmd))
220 print("Alias stored: %s (%s)" % (name, cmd))
223 return
221 return
224
222
225 else:
223 else:
226 modname = getattr(inspect.getmodule(obj), '__name__', '')
224 modname = getattr(inspect.getmodule(obj), '__name__', '')
227 if modname == '__main__':
225 if modname == '__main__':
228 print(textwrap.dedent("""\
226 print(textwrap.dedent("""\
229 Warning:%s is %s
227 Warning:%s is %s
230 Proper storage of interactively declared classes (or instances
228 Proper storage of interactively declared classes (or instances
231 of those classes) is not possible! Only instances
229 of those classes) is not possible! Only instances
232 of classes in real modules on file system can be %%store'd.
230 of classes in real modules on file system can be %%store'd.
233 """ % (args[0], obj) ))
231 """ % (args[0], obj) ))
234 return
232 return
235 #pickled = pickle.dumps(obj)
233 #pickled = pickle.dumps(obj)
236 db[ 'autorestore/' + args[0] ] = obj
234 db[ 'autorestore/' + args[0] ] = obj
237 print("Stored '%s' (%s)" % (args[0], obj.__class__.__name__))
235 print("Stored '%s' (%s)" % (args[0], obj.__class__.__name__))
238
236
239
237
240 def load_ipython_extension(ip):
238 def load_ipython_extension(ip):
241 """Load the extension in IPython."""
239 """Load the extension in IPython."""
242 ip.register_magics(StoreMagics)
240 ip.register_magics(StoreMagics)
243
241
@@ -1,111 +1,108 b''
1 """Link and DirectionalLink classes.
1 """Link and DirectionalLink classes.
2
2
3 Propagate changes between widgets on the javascript side
3 Propagate changes between widgets on the javascript side
4 """
4 """
5
5
6 # Copyright (c) IPython Development Team.
6 # Copyright (c) IPython Development Team.
7 # Distributed under the terms of the Modified BSD License.
7 # Distributed under the terms of the Modified BSD License.
8
8
9 from .widget import Widget
9 from .widget import Widget
10 from IPython.testing.skipdoctest import skip_doctest
11 from IPython.utils.traitlets import Unicode, Tuple, List,Instance, TraitError
10 from IPython.utils.traitlets import Unicode, Tuple, List,Instance, TraitError
12
11
13 class WidgetTraitTuple(Tuple):
12 class WidgetTraitTuple(Tuple):
14 """Traitlet for validating a single (Widget, 'trait_name') pair"""
13 """Traitlet for validating a single (Widget, 'trait_name') pair"""
15
14
16 def __init__(self, **kwargs):
15 def __init__(self, **kwargs):
17 super(WidgetTraitTuple, self).__init__(Instance(Widget), Unicode, **kwargs)
16 super(WidgetTraitTuple, self).__init__(Instance(Widget), Unicode, **kwargs)
18
17
19 def validate_elements(self, obj, value):
18 def validate_elements(self, obj, value):
20 value = super(WidgetTraitTuple, self).validate_elements(obj, value)
19 value = super(WidgetTraitTuple, self).validate_elements(obj, value)
21 widget, trait_name = value
20 widget, trait_name = value
22 trait = widget.traits().get(trait_name)
21 trait = widget.traits().get(trait_name)
23 trait_repr = "%s.%s" % (widget.__class__.__name__, trait_name)
22 trait_repr = "%s.%s" % (widget.__class__.__name__, trait_name)
24 # Can't raise TraitError because the parent will swallow the message
23 # Can't raise TraitError because the parent will swallow the message
25 # and throw it away in a new, less informative TraitError
24 # and throw it away in a new, less informative TraitError
26 if trait is None:
25 if trait is None:
27 raise TypeError("No such trait: %s" % trait_repr)
26 raise TypeError("No such trait: %s" % trait_repr)
28 elif not trait.get_metadata('sync'):
27 elif not trait.get_metadata('sync'):
29 raise TypeError("%s cannot be synced" % trait_repr)
28 raise TypeError("%s cannot be synced" % trait_repr)
30
29
31 return value
30 return value
32
31
33
32
34 class Link(Widget):
33 class Link(Widget):
35 """Link Widget
34 """Link Widget
36
35
37 one trait:
36 one trait:
38 widgets, a list of (widget, 'trait_name') tuples which should be linked in the frontend.
37 widgets, a list of (widget, 'trait_name') tuples which should be linked in the frontend.
39 """
38 """
40 _model_name = Unicode('LinkModel', sync=True)
39 _model_name = Unicode('LinkModel', sync=True)
41 widgets = List(WidgetTraitTuple, sync=True)
40 widgets = List(WidgetTraitTuple, sync=True)
42
41
43 def __init__(self, widgets, **kwargs):
42 def __init__(self, widgets, **kwargs):
44 if len(widgets) < 2:
43 if len(widgets) < 2:
45 raise TypeError("Require at least two widgets to link")
44 raise TypeError("Require at least two widgets to link")
46 kwargs['widgets'] = widgets
45 kwargs['widgets'] = widgets
47 super(Link, self).__init__(**kwargs)
46 super(Link, self).__init__(**kwargs)
48
47
49 # for compatibility with traitlet links
48 # for compatibility with traitlet links
50 def unlink(self):
49 def unlink(self):
51 self.close()
50 self.close()
52
51
53
52
54 @skip_doctest
55 def jslink(*args):
53 def jslink(*args):
56 """Link traits from different widgets together on the frontend so they remain in sync.
54 """Link traits from different widgets together on the frontend so they remain in sync.
57
55
58 Parameters
56 Parameters
59 ----------
57 ----------
60 *args : two or more (Widget, 'trait_name') tuples that should be kept in sync.
58 *args : two or more (Widget, 'trait_name') tuples that should be kept in sync.
61
59
62 Examples
60 Examples
63 --------
61 --------
64
62
65 >>> c = link((widget1, 'value'), (widget2, 'value'), (widget3, 'value'))
63 >>> c = link((widget1, 'value'), (widget2, 'value'), (widget3, 'value'))
66 """
64 """
67 return Link(widgets=args)
65 return Link(widgets=args)
68
66
69
67
70 class DirectionalLink(Widget):
68 class DirectionalLink(Widget):
71 """A directional link
69 """A directional link
72
70
73 source: a (Widget, 'trait_name') tuple for the source trait
71 source: a (Widget, 'trait_name') tuple for the source trait
74 targets: one or more (Widget, 'trait_name') tuples that should be updated
72 targets: one or more (Widget, 'trait_name') tuples that should be updated
75 when the source trait changes.
73 when the source trait changes.
76 """
74 """
77 _model_name = Unicode('DirectionalLinkModel', sync=True)
75 _model_name = Unicode('DirectionalLinkModel', sync=True)
78 targets = List(WidgetTraitTuple, sync=True)
76 targets = List(WidgetTraitTuple, sync=True)
79 source = WidgetTraitTuple(sync=True)
77 source = WidgetTraitTuple(sync=True)
80
78
81 # Does not quite behave like other widgets but reproduces
79 # Does not quite behave like other widgets but reproduces
82 # the behavior of IPython.utils.traitlets.directional_link
80 # the behavior of IPython.utils.traitlets.directional_link
83 def __init__(self, source, targets, **kwargs):
81 def __init__(self, source, targets, **kwargs):
84 if len(targets) < 1:
82 if len(targets) < 1:
85 raise TypeError("Require at least two widgets to link")
83 raise TypeError("Require at least two widgets to link")
86
84
87 kwargs['source'] = source
85 kwargs['source'] = source
88 kwargs['targets'] = targets
86 kwargs['targets'] = targets
89 super(DirectionalLink, self).__init__(**kwargs)
87 super(DirectionalLink, self).__init__(**kwargs)
90
88
91 # for compatibility with traitlet links
89 # for compatibility with traitlet links
92 def unlink(self):
90 def unlink(self):
93 self.close()
91 self.close()
94
92
95 @skip_doctest
96 def jsdlink(source, *targets):
93 def jsdlink(source, *targets):
97 """Link the trait of a source widget with traits of target widgets in the frontend.
94 """Link the trait of a source widget with traits of target widgets in the frontend.
98
95
99 Parameters
96 Parameters
100 ----------
97 ----------
101 source : a (Widget, 'trait_name') tuple for the source trait
98 source : a (Widget, 'trait_name') tuple for the source trait
102 *targets : one or more (Widget, 'trait_name') tuples that should be updated
99 *targets : one or more (Widget, 'trait_name') tuples that should be updated
103 when the source trait changes.
100 when the source trait changes.
104
101
105 Examples
102 Examples
106 --------
103 --------
107
104
108 >>> c = dlink((src_widget, 'value'), (tgt_widget1, 'value'), (tgt_widget2, 'value'))
105 >>> c = dlink((src_widget, 'value'), (tgt_widget1, 'value'), (tgt_widget2, 'value'))
109 """
106 """
110 return DirectionalLink(source=source, targets=targets)
107 return DirectionalLink(source=source, targets=targets)
111
108
@@ -1,78 +1,76 b''
1 """Output class.
1 """Output class.
2
2
3 Represents a widget that can be used to display output within the widget area.
3 Represents a widget that can be used to display output within the widget area.
4 """
4 """
5
5
6 # Copyright (c) IPython Development Team.
6 # Copyright (c) IPython Development Team.
7 # Distributed under the terms of the Modified BSD License.
7 # Distributed under the terms of the Modified BSD License.
8
8
9 from .widget import DOMWidget
9 from .widget import DOMWidget
10 import sys
10 import sys
11 from IPython.utils.traitlets import Unicode, List
11 from IPython.utils.traitlets import Unicode, List
12 from IPython.display import clear_output
12 from IPython.display import clear_output
13 from IPython.testing.skipdoctest import skip_doctest
14 from IPython.kernel.zmq.session import Message
13 from IPython.kernel.zmq.session import Message
15
14
16 @skip_doctest
17 class Output(DOMWidget):
15 class Output(DOMWidget):
18 """Widget used as a context manager to display output.
16 """Widget used as a context manager to display output.
19
17
20 This widget can capture and display stdout, stderr, and rich output. To use
18 This widget can capture and display stdout, stderr, and rich output. To use
21 it, create an instance of it and display it. Then use it as a context
19 it, create an instance of it and display it. Then use it as a context
22 manager. Any output produced while in it's context will be captured and
20 manager. Any output produced while in it's context will be captured and
23 displayed in it instead of the standard output area.
21 displayed in it instead of the standard output area.
24
22
25 Example
23 Example
26 from IPython.html import widgets
24 from IPython.html import widgets
27 from IPython.display import display
25 from IPython.display import display
28 out = widgets.Output()
26 out = widgets.Output()
29 display(out)
27 display(out)
30
28
31 print('prints to output area')
29 print('prints to output area')
32
30
33 with out:
31 with out:
34 print('prints to output widget')"""
32 print('prints to output widget')"""
35 _view_name = Unicode('OutputView', sync=True)
33 _view_name = Unicode('OutputView', sync=True)
36
34
37 def clear_output(self, *pargs, **kwargs):
35 def clear_output(self, *pargs, **kwargs):
38 with self:
36 with self:
39 clear_output(*pargs, **kwargs)
37 clear_output(*pargs, **kwargs)
40
38
41 def __enter__(self):
39 def __enter__(self):
42 """Called upon entering output widget context manager."""
40 """Called upon entering output widget context manager."""
43 self._flush()
41 self._flush()
44 kernel = get_ipython().kernel
42 kernel = get_ipython().kernel
45 session = kernel.session
43 session = kernel.session
46 send = session.send
44 send = session.send
47 self._original_send = send
45 self._original_send = send
48 self._session = session
46 self._session = session
49
47
50 def send_hook(stream, msg_or_type, content=None, parent=None, ident=None,
48 def send_hook(stream, msg_or_type, content=None, parent=None, ident=None,
51 buffers=None, track=False, header=None, metadata=None):
49 buffers=None, track=False, header=None, metadata=None):
52
50
53 # Handle both prebuild messages and unbuilt messages.
51 # Handle both prebuild messages and unbuilt messages.
54 if isinstance(msg_or_type, (Message, dict)):
52 if isinstance(msg_or_type, (Message, dict)):
55 msg_type = msg_or_type['msg_type']
53 msg_type = msg_or_type['msg_type']
56 msg = dict(msg_or_type)
54 msg = dict(msg_or_type)
57 else:
55 else:
58 msg_type = msg_or_type
56 msg_type = msg_or_type
59 msg = session.msg(msg_type, content=content, parent=parent,
57 msg = session.msg(msg_type, content=content, parent=parent,
60 header=header, metadata=metadata)
58 header=header, metadata=metadata)
61
59
62 # If this is a message type that we want to forward, forward it.
60 # If this is a message type that we want to forward, forward it.
63 if stream is kernel.iopub_socket and msg_type in ['clear_output', 'stream', 'display_data']:
61 if stream is kernel.iopub_socket and msg_type in ['clear_output', 'stream', 'display_data']:
64 self.send(msg)
62 self.send(msg)
65 else:
63 else:
66 send(stream, msg, ident=ident, buffers=buffers, track=track)
64 send(stream, msg, ident=ident, buffers=buffers, track=track)
67
65
68 session.send = send_hook
66 session.send = send_hook
69
67
70 def __exit__(self, exception_type, exception_value, traceback):
68 def __exit__(self, exception_type, exception_value, traceback):
71 """Called upon exiting output widget context manager."""
69 """Called upon exiting output widget context manager."""
72 self._flush()
70 self._flush()
73 self._session.send = self._original_send
71 self._session.send = self._original_send
74
72
75 def _flush(self):
73 def _flush(self):
76 """Flush stdout and stderr buffers."""
74 """Flush stdout and stderr buffers."""
77 sys.stdout.flush()
75 sys.stdout.flush()
78 sys.stderr.flush()
76 sys.stderr.flush()
@@ -1,512 +1,510 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2 """
2 """
3 Defines a variety of Pygments lexers for highlighting IPython code.
3 Defines a variety of Pygments lexers for highlighting IPython code.
4
4
5 This includes:
5 This includes:
6
6
7 IPythonLexer, IPython3Lexer
7 IPythonLexer, IPython3Lexer
8 Lexers for pure IPython (python + magic/shell commands)
8 Lexers for pure IPython (python + magic/shell commands)
9
9
10 IPythonPartialTracebackLexer, IPythonTracebackLexer
10 IPythonPartialTracebackLexer, IPythonTracebackLexer
11 Supports 2.x and 3.x via keyword `python3`. The partial traceback
11 Supports 2.x and 3.x via keyword `python3`. The partial traceback
12 lexer reads everything but the Python code appearing in a traceback.
12 lexer reads everything but the Python code appearing in a traceback.
13 The full lexer combines the partial lexer with an IPython lexer.
13 The full lexer combines the partial lexer with an IPython lexer.
14
14
15 IPythonConsoleLexer
15 IPythonConsoleLexer
16 A lexer for IPython console sessions, with support for tracebacks.
16 A lexer for IPython console sessions, with support for tracebacks.
17
17
18 IPyLexer
18 IPyLexer
19 A friendly lexer which examines the first line of text and from it,
19 A friendly lexer which examines the first line of text and from it,
20 decides whether to use an IPython lexer or an IPython console lexer.
20 decides whether to use an IPython lexer or an IPython console lexer.
21 This is probably the only lexer that needs to be explicitly added
21 This is probably the only lexer that needs to be explicitly added
22 to Pygments.
22 to Pygments.
23
23
24 """
24 """
25 #-----------------------------------------------------------------------------
25 #-----------------------------------------------------------------------------
26 # Copyright (c) 2013, the IPython Development Team.
26 # Copyright (c) 2013, the IPython Development Team.
27 #
27 #
28 # Distributed under the terms of the Modified BSD License.
28 # Distributed under the terms of the Modified BSD License.
29 #
29 #
30 # The full license is in the file COPYING.txt, distributed with this software.
30 # The full license is in the file COPYING.txt, distributed with this software.
31 #-----------------------------------------------------------------------------
31 #-----------------------------------------------------------------------------
32
32
33 # Standard library
33 # Standard library
34 import re
34 import re
35
35
36 # Third party
36 # Third party
37 from pygments.lexers import BashLexer, PythonLexer, Python3Lexer
37 from pygments.lexers import BashLexer, PythonLexer, Python3Lexer
38 from pygments.lexer import (
38 from pygments.lexer import (
39 Lexer, DelegatingLexer, RegexLexer, do_insertions, bygroups, using,
39 Lexer, DelegatingLexer, RegexLexer, do_insertions, bygroups, using,
40 )
40 )
41 from pygments.token import (
41 from pygments.token import (
42 Comment, Generic, Keyword, Literal, Name, Operator, Other, Text, Error,
42 Comment, Generic, Keyword, Literal, Name, Operator, Other, Text, Error,
43 )
43 )
44 from pygments.util import get_bool_opt
44 from pygments.util import get_bool_opt
45
45
46 # Local
46 # Local
47 from IPython.testing.skipdoctest import skip_doctest
48
47
49 line_re = re.compile('.*?\n')
48 line_re = re.compile('.*?\n')
50
49
51 __all__ = ['build_ipy_lexer', 'IPython3Lexer', 'IPythonLexer',
50 __all__ = ['build_ipy_lexer', 'IPython3Lexer', 'IPythonLexer',
52 'IPythonPartialTracebackLexer', 'IPythonTracebackLexer',
51 'IPythonPartialTracebackLexer', 'IPythonTracebackLexer',
53 'IPythonConsoleLexer', 'IPyLexer']
52 'IPythonConsoleLexer', 'IPyLexer']
54
53
55 ipython_tokens = [
54 ipython_tokens = [
56 (r"(?s)(\s*)(%%)(\w+)(.*)", bygroups(Text, Operator, Keyword, Text)),
55 (r"(?s)(\s*)(%%)(\w+)(.*)", bygroups(Text, Operator, Keyword, Text)),
57 (r'(?s)(^\s*)(%%!)([^\n]*\n)(.*)', bygroups(Text, Operator, Text, using(BashLexer))),
56 (r'(?s)(^\s*)(%%!)([^\n]*\n)(.*)', bygroups(Text, Operator, Text, using(BashLexer))),
58 (r"(%%?)(\w+)(\?\??)$", bygroups(Operator, Keyword, Operator)),
57 (r"(%%?)(\w+)(\?\??)$", bygroups(Operator, Keyword, Operator)),
59 (r"\b(\?\??)(\s*)$", bygroups(Operator, Text)),
58 (r"\b(\?\??)(\s*)$", bygroups(Operator, Text)),
60 (r'(%)(sx|sc|system)(.*)(\n)', bygroups(Operator, Keyword,
59 (r'(%)(sx|sc|system)(.*)(\n)', bygroups(Operator, Keyword,
61 using(BashLexer), Text)),
60 using(BashLexer), Text)),
62 (r'(%)(\w+)(.*\n)', bygroups(Operator, Keyword, Text)),
61 (r'(%)(\w+)(.*\n)', bygroups(Operator, Keyword, Text)),
63 (r'^(!!)(.+)(\n)', bygroups(Operator, using(BashLexer), Text)),
62 (r'^(!!)(.+)(\n)', bygroups(Operator, using(BashLexer), Text)),
64 (r'(!)(?!=)(.+)(\n)', bygroups(Operator, using(BashLexer), Text)),
63 (r'(!)(?!=)(.+)(\n)', bygroups(Operator, using(BashLexer), Text)),
65 (r'^(\s*)(\?\??)(\s*%{0,2}[\w\.\*]*)', bygroups(Text, Operator, Text)),
64 (r'^(\s*)(\?\??)(\s*%{0,2}[\w\.\*]*)', bygroups(Text, Operator, Text)),
66 ]
65 ]
67
66
68 def build_ipy_lexer(python3):
67 def build_ipy_lexer(python3):
69 """Builds IPython lexers depending on the value of `python3`.
68 """Builds IPython lexers depending on the value of `python3`.
70
69
71 The lexer inherits from an appropriate Python lexer and then adds
70 The lexer inherits from an appropriate Python lexer and then adds
72 information about IPython specific keywords (i.e. magic commands,
71 information about IPython specific keywords (i.e. magic commands,
73 shell commands, etc.)
72 shell commands, etc.)
74
73
75 Parameters
74 Parameters
76 ----------
75 ----------
77 python3 : bool
76 python3 : bool
78 If `True`, then build an IPython lexer from a Python 3 lexer.
77 If `True`, then build an IPython lexer from a Python 3 lexer.
79
78
80 """
79 """
81 # It would be nice to have a single IPython lexer class which takes
80 # It would be nice to have a single IPython lexer class which takes
82 # a boolean `python3`. But since there are two Python lexer classes,
81 # a boolean `python3`. But since there are two Python lexer classes,
83 # we will also have two IPython lexer classes.
82 # we will also have two IPython lexer classes.
84 if python3:
83 if python3:
85 PyLexer = Python3Lexer
84 PyLexer = Python3Lexer
86 clsname = 'IPython3Lexer'
85 clsname = 'IPython3Lexer'
87 name = 'IPython3'
86 name = 'IPython3'
88 aliases = ['ipython3']
87 aliases = ['ipython3']
89 doc = """IPython3 Lexer"""
88 doc = """IPython3 Lexer"""
90 else:
89 else:
91 PyLexer = PythonLexer
90 PyLexer = PythonLexer
92 clsname = 'IPythonLexer'
91 clsname = 'IPythonLexer'
93 name = 'IPython'
92 name = 'IPython'
94 aliases = ['ipython2', 'ipython']
93 aliases = ['ipython2', 'ipython']
95 doc = """IPython Lexer"""
94 doc = """IPython Lexer"""
96
95
97 tokens = PyLexer.tokens.copy()
96 tokens = PyLexer.tokens.copy()
98 tokens['root'] = ipython_tokens + tokens['root']
97 tokens['root'] = ipython_tokens + tokens['root']
99
98
100 attrs = {'name': name, 'aliases': aliases, 'filenames': [],
99 attrs = {'name': name, 'aliases': aliases, 'filenames': [],
101 '__doc__': doc, 'tokens': tokens}
100 '__doc__': doc, 'tokens': tokens}
102
101
103 return type(name, (PyLexer,), attrs)
102 return type(name, (PyLexer,), attrs)
104
103
105
104
106 IPython3Lexer = build_ipy_lexer(python3=True)
105 IPython3Lexer = build_ipy_lexer(python3=True)
107 IPythonLexer = build_ipy_lexer(python3=False)
106 IPythonLexer = build_ipy_lexer(python3=False)
108
107
109
108
110 class IPythonPartialTracebackLexer(RegexLexer):
109 class IPythonPartialTracebackLexer(RegexLexer):
111 """
110 """
112 Partial lexer for IPython tracebacks.
111 Partial lexer for IPython tracebacks.
113
112
114 Handles all the non-python output. This works for both Python 2.x and 3.x.
113 Handles all the non-python output. This works for both Python 2.x and 3.x.
115
114
116 """
115 """
117 name = 'IPython Partial Traceback'
116 name = 'IPython Partial Traceback'
118
117
119 tokens = {
118 tokens = {
120 'root': [
119 'root': [
121 # Tracebacks for syntax errors have a different style.
120 # Tracebacks for syntax errors have a different style.
122 # For both types of tracebacks, we mark the first line with
121 # For both types of tracebacks, we mark the first line with
123 # Generic.Traceback. For syntax errors, we mark the filename
122 # Generic.Traceback. For syntax errors, we mark the filename
124 # as we mark the filenames for non-syntax tracebacks.
123 # as we mark the filenames for non-syntax tracebacks.
125 #
124 #
126 # These two regexps define how IPythonConsoleLexer finds a
125 # These two regexps define how IPythonConsoleLexer finds a
127 # traceback.
126 # traceback.
128 #
127 #
129 ## Non-syntax traceback
128 ## Non-syntax traceback
130 (r'^(\^C)?(-+\n)', bygroups(Error, Generic.Traceback)),
129 (r'^(\^C)?(-+\n)', bygroups(Error, Generic.Traceback)),
131 ## Syntax traceback
130 ## Syntax traceback
132 (r'^( File)(.*)(, line )(\d+\n)',
131 (r'^( File)(.*)(, line )(\d+\n)',
133 bygroups(Generic.Traceback, Name.Namespace,
132 bygroups(Generic.Traceback, Name.Namespace,
134 Generic.Traceback, Literal.Number.Integer)),
133 Generic.Traceback, Literal.Number.Integer)),
135
134
136 # (Exception Identifier)(Whitespace)(Traceback Message)
135 # (Exception Identifier)(Whitespace)(Traceback Message)
137 (r'(?u)(^[^\d\W]\w*)(\s*)(Traceback.*?\n)',
136 (r'(?u)(^[^\d\W]\w*)(\s*)(Traceback.*?\n)',
138 bygroups(Name.Exception, Generic.Whitespace, Text)),
137 bygroups(Name.Exception, Generic.Whitespace, Text)),
139 # (Module/Filename)(Text)(Callee)(Function Signature)
138 # (Module/Filename)(Text)(Callee)(Function Signature)
140 # Better options for callee and function signature?
139 # Better options for callee and function signature?
141 (r'(.*)( in )(.*)(\(.*\)\n)',
140 (r'(.*)( in )(.*)(\(.*\)\n)',
142 bygroups(Name.Namespace, Text, Name.Entity, Name.Tag)),
141 bygroups(Name.Namespace, Text, Name.Entity, Name.Tag)),
143 # Regular line: (Whitespace)(Line Number)(Python Code)
142 # Regular line: (Whitespace)(Line Number)(Python Code)
144 (r'(\s*?)(\d+)(.*?\n)',
143 (r'(\s*?)(\d+)(.*?\n)',
145 bygroups(Generic.Whitespace, Literal.Number.Integer, Other)),
144 bygroups(Generic.Whitespace, Literal.Number.Integer, Other)),
146 # Emphasized line: (Arrow)(Line Number)(Python Code)
145 # Emphasized line: (Arrow)(Line Number)(Python Code)
147 # Using Exception token so arrow color matches the Exception.
146 # Using Exception token so arrow color matches the Exception.
148 (r'(-*>?\s?)(\d+)(.*?\n)',
147 (r'(-*>?\s?)(\d+)(.*?\n)',
149 bygroups(Name.Exception, Literal.Number.Integer, Other)),
148 bygroups(Name.Exception, Literal.Number.Integer, Other)),
150 # (Exception Identifier)(Message)
149 # (Exception Identifier)(Message)
151 (r'(?u)(^[^\d\W]\w*)(:.*?\n)',
150 (r'(?u)(^[^\d\W]\w*)(:.*?\n)',
152 bygroups(Name.Exception, Text)),
151 bygroups(Name.Exception, Text)),
153 # Tag everything else as Other, will be handled later.
152 # Tag everything else as Other, will be handled later.
154 (r'.*\n', Other),
153 (r'.*\n', Other),
155 ],
154 ],
156 }
155 }
157
156
158
157
159 class IPythonTracebackLexer(DelegatingLexer):
158 class IPythonTracebackLexer(DelegatingLexer):
160 """
159 """
161 IPython traceback lexer.
160 IPython traceback lexer.
162
161
163 For doctests, the tracebacks can be snipped as much as desired with the
162 For doctests, the tracebacks can be snipped as much as desired with the
164 exception to the lines that designate a traceback. For non-syntax error
163 exception to the lines that designate a traceback. For non-syntax error
165 tracebacks, this is the line of hyphens. For syntax error tracebacks,
164 tracebacks, this is the line of hyphens. For syntax error tracebacks,
166 this is the line which lists the File and line number.
165 this is the line which lists the File and line number.
167
166
168 """
167 """
169 # The lexer inherits from DelegatingLexer. The "root" lexer is an
168 # The lexer inherits from DelegatingLexer. The "root" lexer is an
170 # appropriate IPython lexer, which depends on the value of the boolean
169 # appropriate IPython lexer, which depends on the value of the boolean
171 # `python3`. First, we parse with the partial IPython traceback lexer.
170 # `python3`. First, we parse with the partial IPython traceback lexer.
172 # Then, any code marked with the "Other" token is delegated to the root
171 # Then, any code marked with the "Other" token is delegated to the root
173 # lexer.
172 # lexer.
174 #
173 #
175 name = 'IPython Traceback'
174 name = 'IPython Traceback'
176 aliases = ['ipythontb']
175 aliases = ['ipythontb']
177
176
178 def __init__(self, **options):
177 def __init__(self, **options):
179 self.python3 = get_bool_opt(options, 'python3', False)
178 self.python3 = get_bool_opt(options, 'python3', False)
180 if self.python3:
179 if self.python3:
181 self.aliases = ['ipython3tb']
180 self.aliases = ['ipython3tb']
182 else:
181 else:
183 self.aliases = ['ipython2tb', 'ipythontb']
182 self.aliases = ['ipython2tb', 'ipythontb']
184
183
185 if self.python3:
184 if self.python3:
186 IPyLexer = IPython3Lexer
185 IPyLexer = IPython3Lexer
187 else:
186 else:
188 IPyLexer = IPythonLexer
187 IPyLexer = IPythonLexer
189
188
190 DelegatingLexer.__init__(self, IPyLexer,
189 DelegatingLexer.__init__(self, IPyLexer,
191 IPythonPartialTracebackLexer, **options)
190 IPythonPartialTracebackLexer, **options)
192
191
193 @skip_doctest
194 class IPythonConsoleLexer(Lexer):
192 class IPythonConsoleLexer(Lexer):
195 """
193 """
196 An IPython console lexer for IPython code-blocks and doctests, such as:
194 An IPython console lexer for IPython code-blocks and doctests, such as:
197
195
198 .. code-block:: rst
196 .. code-block:: rst
199
197
200 .. code-block:: ipythonconsole
198 .. code-block:: ipythonconsole
201
199
202 In [1]: a = 'foo'
200 In [1]: a = 'foo'
203
201
204 In [2]: a
202 In [2]: a
205 Out[2]: 'foo'
203 Out[2]: 'foo'
206
204
207 In [3]: print a
205 In [3]: print a
208 foo
206 foo
209
207
210 In [4]: 1 / 0
208 In [4]: 1 / 0
211
209
212
210
213 Support is also provided for IPython exceptions:
211 Support is also provided for IPython exceptions:
214
212
215 .. code-block:: rst
213 .. code-block:: rst
216
214
217 .. code-block:: ipythonconsole
215 .. code-block:: ipythonconsole
218
216
219 In [1]: raise Exception
217 In [1]: raise Exception
220
218
221 ---------------------------------------------------------------------------
219 ---------------------------------------------------------------------------
222 Exception Traceback (most recent call last)
220 Exception Traceback (most recent call last)
223 <ipython-input-1-fca2ab0ca76b> in <module>()
221 <ipython-input-1-fca2ab0ca76b> in <module>()
224 ----> 1 raise Exception
222 ----> 1 raise Exception
225
223
226 Exception:
224 Exception:
227
225
228 """
226 """
229 name = 'IPython console session'
227 name = 'IPython console session'
230 aliases = ['ipythonconsole']
228 aliases = ['ipythonconsole']
231 mimetypes = ['text/x-ipython-console']
229 mimetypes = ['text/x-ipython-console']
232
230
233 # The regexps used to determine what is input and what is output.
231 # The regexps used to determine what is input and what is output.
234 # The default prompts for IPython are:
232 # The default prompts for IPython are:
235 #
233 #
236 # c.PromptManager.in_template = 'In [\#]: '
234 # c.PromptManager.in_template = 'In [\#]: '
237 # c.PromptManager.in2_template = ' .\D.: '
235 # c.PromptManager.in2_template = ' .\D.: '
238 # c.PromptManager.out_template = 'Out[\#]: '
236 # c.PromptManager.out_template = 'Out[\#]: '
239 #
237 #
240 in1_regex = r'In \[[0-9]+\]: '
238 in1_regex = r'In \[[0-9]+\]: '
241 in2_regex = r' \.\.+\.: '
239 in2_regex = r' \.\.+\.: '
242 out_regex = r'Out\[[0-9]+\]: '
240 out_regex = r'Out\[[0-9]+\]: '
243
241
244 #: The regex to determine when a traceback starts.
242 #: The regex to determine when a traceback starts.
245 ipytb_start = re.compile(r'^(\^C)?(-+\n)|^( File)(.*)(, line )(\d+\n)')
243 ipytb_start = re.compile(r'^(\^C)?(-+\n)|^( File)(.*)(, line )(\d+\n)')
246
244
247 def __init__(self, **options):
245 def __init__(self, **options):
248 """Initialize the IPython console lexer.
246 """Initialize the IPython console lexer.
249
247
250 Parameters
248 Parameters
251 ----------
249 ----------
252 python3 : bool
250 python3 : bool
253 If `True`, then the console inputs are parsed using a Python 3
251 If `True`, then the console inputs are parsed using a Python 3
254 lexer. Otherwise, they are parsed using a Python 2 lexer.
252 lexer. Otherwise, they are parsed using a Python 2 lexer.
255 in1_regex : RegexObject
253 in1_regex : RegexObject
256 The compiled regular expression used to detect the start
254 The compiled regular expression used to detect the start
257 of inputs. Although the IPython configuration setting may have a
255 of inputs. Although the IPython configuration setting may have a
258 trailing whitespace, do not include it in the regex. If `None`,
256 trailing whitespace, do not include it in the regex. If `None`,
259 then the default input prompt is assumed.
257 then the default input prompt is assumed.
260 in2_regex : RegexObject
258 in2_regex : RegexObject
261 The compiled regular expression used to detect the continuation
259 The compiled regular expression used to detect the continuation
262 of inputs. Although the IPython configuration setting may have a
260 of inputs. Although the IPython configuration setting may have a
263 trailing whitespace, do not include it in the regex. If `None`,
261 trailing whitespace, do not include it in the regex. If `None`,
264 then the default input prompt is assumed.
262 then the default input prompt is assumed.
265 out_regex : RegexObject
263 out_regex : RegexObject
266 The compiled regular expression used to detect outputs. If `None`,
264 The compiled regular expression used to detect outputs. If `None`,
267 then the default output prompt is assumed.
265 then the default output prompt is assumed.
268
266
269 """
267 """
270 self.python3 = get_bool_opt(options, 'python3', False)
268 self.python3 = get_bool_opt(options, 'python3', False)
271 if self.python3:
269 if self.python3:
272 self.aliases = ['ipython3console']
270 self.aliases = ['ipython3console']
273 else:
271 else:
274 self.aliases = ['ipython2console', 'ipythonconsole']
272 self.aliases = ['ipython2console', 'ipythonconsole']
275
273
276 in1_regex = options.get('in1_regex', self.in1_regex)
274 in1_regex = options.get('in1_regex', self.in1_regex)
277 in2_regex = options.get('in2_regex', self.in2_regex)
275 in2_regex = options.get('in2_regex', self.in2_regex)
278 out_regex = options.get('out_regex', self.out_regex)
276 out_regex = options.get('out_regex', self.out_regex)
279
277
280 # So that we can work with input and output prompts which have been
278 # So that we can work with input and output prompts which have been
281 # rstrip'd (possibly by editors) we also need rstrip'd variants. If
279 # rstrip'd (possibly by editors) we also need rstrip'd variants. If
282 # we do not do this, then such prompts will be tagged as 'output'.
280 # we do not do this, then such prompts will be tagged as 'output'.
283 # The reason can't just use the rstrip'd variants instead is because
281 # The reason can't just use the rstrip'd variants instead is because
284 # we want any whitespace associated with the prompt to be inserted
282 # we want any whitespace associated with the prompt to be inserted
285 # with the token. This allows formatted code to be modified so as hide
283 # with the token. This allows formatted code to be modified so as hide
286 # the appearance of prompts, with the whitespace included. One example
284 # the appearance of prompts, with the whitespace included. One example
287 # use of this is in copybutton.js from the standard lib Python docs.
285 # use of this is in copybutton.js from the standard lib Python docs.
288 in1_regex_rstrip = in1_regex.rstrip() + '\n'
286 in1_regex_rstrip = in1_regex.rstrip() + '\n'
289 in2_regex_rstrip = in2_regex.rstrip() + '\n'
287 in2_regex_rstrip = in2_regex.rstrip() + '\n'
290 out_regex_rstrip = out_regex.rstrip() + '\n'
288 out_regex_rstrip = out_regex.rstrip() + '\n'
291
289
292 # Compile and save them all.
290 # Compile and save them all.
293 attrs = ['in1_regex', 'in2_regex', 'out_regex',
291 attrs = ['in1_regex', 'in2_regex', 'out_regex',
294 'in1_regex_rstrip', 'in2_regex_rstrip', 'out_regex_rstrip']
292 'in1_regex_rstrip', 'in2_regex_rstrip', 'out_regex_rstrip']
295 for attr in attrs:
293 for attr in attrs:
296 self.__setattr__(attr, re.compile(locals()[attr]))
294 self.__setattr__(attr, re.compile(locals()[attr]))
297
295
298 Lexer.__init__(self, **options)
296 Lexer.__init__(self, **options)
299
297
300 if self.python3:
298 if self.python3:
301 pylexer = IPython3Lexer
299 pylexer = IPython3Lexer
302 tblexer = IPythonTracebackLexer
300 tblexer = IPythonTracebackLexer
303 else:
301 else:
304 pylexer = IPythonLexer
302 pylexer = IPythonLexer
305 tblexer = IPythonTracebackLexer
303 tblexer = IPythonTracebackLexer
306
304
307 self.pylexer = pylexer(**options)
305 self.pylexer = pylexer(**options)
308 self.tblexer = tblexer(**options)
306 self.tblexer = tblexer(**options)
309
307
310 self.reset()
308 self.reset()
311
309
312 def reset(self):
310 def reset(self):
313 self.mode = 'output'
311 self.mode = 'output'
314 self.index = 0
312 self.index = 0
315 self.buffer = u''
313 self.buffer = u''
316 self.insertions = []
314 self.insertions = []
317
315
318 def buffered_tokens(self):
316 def buffered_tokens(self):
319 """
317 """
320 Generator of unprocessed tokens after doing insertions and before
318 Generator of unprocessed tokens after doing insertions and before
321 changing to a new state.
319 changing to a new state.
322
320
323 """
321 """
324 if self.mode == 'output':
322 if self.mode == 'output':
325 tokens = [(0, Generic.Output, self.buffer)]
323 tokens = [(0, Generic.Output, self.buffer)]
326 elif self.mode == 'input':
324 elif self.mode == 'input':
327 tokens = self.pylexer.get_tokens_unprocessed(self.buffer)
325 tokens = self.pylexer.get_tokens_unprocessed(self.buffer)
328 else: # traceback
326 else: # traceback
329 tokens = self.tblexer.get_tokens_unprocessed(self.buffer)
327 tokens = self.tblexer.get_tokens_unprocessed(self.buffer)
330
328
331 for i, t, v in do_insertions(self.insertions, tokens):
329 for i, t, v in do_insertions(self.insertions, tokens):
332 # All token indexes are relative to the buffer.
330 # All token indexes are relative to the buffer.
333 yield self.index + i, t, v
331 yield self.index + i, t, v
334
332
335 # Clear it all
333 # Clear it all
336 self.index += len(self.buffer)
334 self.index += len(self.buffer)
337 self.buffer = u''
335 self.buffer = u''
338 self.insertions = []
336 self.insertions = []
339
337
340 def get_mci(self, line):
338 def get_mci(self, line):
341 """
339 """
342 Parses the line and returns a 3-tuple: (mode, code, insertion).
340 Parses the line and returns a 3-tuple: (mode, code, insertion).
343
341
344 `mode` is the next mode (or state) of the lexer, and is always equal
342 `mode` is the next mode (or state) of the lexer, and is always equal
345 to 'input', 'output', or 'tb'.
343 to 'input', 'output', or 'tb'.
346
344
347 `code` is a portion of the line that should be added to the buffer
345 `code` is a portion of the line that should be added to the buffer
348 corresponding to the next mode and eventually lexed by another lexer.
346 corresponding to the next mode and eventually lexed by another lexer.
349 For example, `code` could be Python code if `mode` were 'input'.
347 For example, `code` could be Python code if `mode` were 'input'.
350
348
351 `insertion` is a 3-tuple (index, token, text) representing an
349 `insertion` is a 3-tuple (index, token, text) representing an
352 unprocessed "token" that will be inserted into the stream of tokens
350 unprocessed "token" that will be inserted into the stream of tokens
353 that are created from the buffer once we change modes. This is usually
351 that are created from the buffer once we change modes. This is usually
354 the input or output prompt.
352 the input or output prompt.
355
353
356 In general, the next mode depends on current mode and on the contents
354 In general, the next mode depends on current mode and on the contents
357 of `line`.
355 of `line`.
358
356
359 """
357 """
360 # To reduce the number of regex match checks, we have multiple
358 # To reduce the number of regex match checks, we have multiple
361 # 'if' blocks instead of 'if-elif' blocks.
359 # 'if' blocks instead of 'if-elif' blocks.
362
360
363 # Check for possible end of input
361 # Check for possible end of input
364 in2_match = self.in2_regex.match(line)
362 in2_match = self.in2_regex.match(line)
365 in2_match_rstrip = self.in2_regex_rstrip.match(line)
363 in2_match_rstrip = self.in2_regex_rstrip.match(line)
366 if (in2_match and in2_match.group().rstrip() == line.rstrip()) or \
364 if (in2_match and in2_match.group().rstrip() == line.rstrip()) or \
367 in2_match_rstrip:
365 in2_match_rstrip:
368 end_input = True
366 end_input = True
369 else:
367 else:
370 end_input = False
368 end_input = False
371 if end_input and self.mode != 'tb':
369 if end_input and self.mode != 'tb':
372 # Only look for an end of input when not in tb mode.
370 # Only look for an end of input when not in tb mode.
373 # An ellipsis could appear within the traceback.
371 # An ellipsis could appear within the traceback.
374 mode = 'output'
372 mode = 'output'
375 code = u''
373 code = u''
376 insertion = (0, Generic.Prompt, line)
374 insertion = (0, Generic.Prompt, line)
377 return mode, code, insertion
375 return mode, code, insertion
378
376
379 # Check for output prompt
377 # Check for output prompt
380 out_match = self.out_regex.match(line)
378 out_match = self.out_regex.match(line)
381 out_match_rstrip = self.out_regex_rstrip.match(line)
379 out_match_rstrip = self.out_regex_rstrip.match(line)
382 if out_match or out_match_rstrip:
380 if out_match or out_match_rstrip:
383 mode = 'output'
381 mode = 'output'
384 if out_match:
382 if out_match:
385 idx = out_match.end()
383 idx = out_match.end()
386 else:
384 else:
387 idx = out_match_rstrip.end()
385 idx = out_match_rstrip.end()
388 code = line[idx:]
386 code = line[idx:]
389 # Use the 'heading' token for output. We cannot use Generic.Error
387 # Use the 'heading' token for output. We cannot use Generic.Error
390 # since it would conflict with exceptions.
388 # since it would conflict with exceptions.
391 insertion = (0, Generic.Heading, line[:idx])
389 insertion = (0, Generic.Heading, line[:idx])
392 return mode, code, insertion
390 return mode, code, insertion
393
391
394
392
395 # Check for input or continuation prompt (non stripped version)
393 # Check for input or continuation prompt (non stripped version)
396 in1_match = self.in1_regex.match(line)
394 in1_match = self.in1_regex.match(line)
397 if in1_match or (in2_match and self.mode != 'tb'):
395 if in1_match or (in2_match and self.mode != 'tb'):
398 # New input or when not in tb, continued input.
396 # New input or when not in tb, continued input.
399 # We do not check for continued input when in tb since it is
397 # We do not check for continued input when in tb since it is
400 # allowable to replace a long stack with an ellipsis.
398 # allowable to replace a long stack with an ellipsis.
401 mode = 'input'
399 mode = 'input'
402 if in1_match:
400 if in1_match:
403 idx = in1_match.end()
401 idx = in1_match.end()
404 else: # in2_match
402 else: # in2_match
405 idx = in2_match.end()
403 idx = in2_match.end()
406 code = line[idx:]
404 code = line[idx:]
407 insertion = (0, Generic.Prompt, line[:idx])
405 insertion = (0, Generic.Prompt, line[:idx])
408 return mode, code, insertion
406 return mode, code, insertion
409
407
410 # Check for input or continuation prompt (stripped version)
408 # Check for input or continuation prompt (stripped version)
411 in1_match_rstrip = self.in1_regex_rstrip.match(line)
409 in1_match_rstrip = self.in1_regex_rstrip.match(line)
412 if in1_match_rstrip or (in2_match_rstrip and self.mode != 'tb'):
410 if in1_match_rstrip or (in2_match_rstrip and self.mode != 'tb'):
413 # New input or when not in tb, continued input.
411 # New input or when not in tb, continued input.
414 # We do not check for continued input when in tb since it is
412 # We do not check for continued input when in tb since it is
415 # allowable to replace a long stack with an ellipsis.
413 # allowable to replace a long stack with an ellipsis.
416 mode = 'input'
414 mode = 'input'
417 if in1_match_rstrip:
415 if in1_match_rstrip:
418 idx = in1_match_rstrip.end()
416 idx = in1_match_rstrip.end()
419 else: # in2_match
417 else: # in2_match
420 idx = in2_match_rstrip.end()
418 idx = in2_match_rstrip.end()
421 code = line[idx:]
419 code = line[idx:]
422 insertion = (0, Generic.Prompt, line[:idx])
420 insertion = (0, Generic.Prompt, line[:idx])
423 return mode, code, insertion
421 return mode, code, insertion
424
422
425 # Check for traceback
423 # Check for traceback
426 if self.ipytb_start.match(line):
424 if self.ipytb_start.match(line):
427 mode = 'tb'
425 mode = 'tb'
428 code = line
426 code = line
429 insertion = None
427 insertion = None
430 return mode, code, insertion
428 return mode, code, insertion
431
429
432 # All other stuff...
430 # All other stuff...
433 if self.mode in ('input', 'output'):
431 if self.mode in ('input', 'output'):
434 # We assume all other text is output. Multiline input that
432 # We assume all other text is output. Multiline input that
435 # does not use the continuation marker cannot be detected.
433 # does not use the continuation marker cannot be detected.
436 # For example, the 3 in the following is clearly output:
434 # For example, the 3 in the following is clearly output:
437 #
435 #
438 # In [1]: print 3
436 # In [1]: print 3
439 # 3
437 # 3
440 #
438 #
441 # But the following second line is part of the input:
439 # But the following second line is part of the input:
442 #
440 #
443 # In [2]: while True:
441 # In [2]: while True:
444 # print True
442 # print True
445 #
443 #
446 # In both cases, the 2nd line will be 'output'.
444 # In both cases, the 2nd line will be 'output'.
447 #
445 #
448 mode = 'output'
446 mode = 'output'
449 else:
447 else:
450 mode = 'tb'
448 mode = 'tb'
451
449
452 code = line
450 code = line
453 insertion = None
451 insertion = None
454
452
455 return mode, code, insertion
453 return mode, code, insertion
456
454
457 def get_tokens_unprocessed(self, text):
455 def get_tokens_unprocessed(self, text):
458 self.reset()
456 self.reset()
459 for match in line_re.finditer(text):
457 for match in line_re.finditer(text):
460 line = match.group()
458 line = match.group()
461 mode, code, insertion = self.get_mci(line)
459 mode, code, insertion = self.get_mci(line)
462
460
463 if mode != self.mode:
461 if mode != self.mode:
464 # Yield buffered tokens before transitioning to new mode.
462 # Yield buffered tokens before transitioning to new mode.
465 for token in self.buffered_tokens():
463 for token in self.buffered_tokens():
466 yield token
464 yield token
467 self.mode = mode
465 self.mode = mode
468
466
469 if insertion:
467 if insertion:
470 self.insertions.append((len(self.buffer), [insertion]))
468 self.insertions.append((len(self.buffer), [insertion]))
471 self.buffer += code
469 self.buffer += code
472 else:
470 else:
473 for token in self.buffered_tokens():
471 for token in self.buffered_tokens():
474 yield token
472 yield token
475
473
476 class IPyLexer(Lexer):
474 class IPyLexer(Lexer):
477 """
475 """
478 Primary lexer for all IPython-like code.
476 Primary lexer for all IPython-like code.
479
477
480 This is a simple helper lexer. If the first line of the text begins with
478 This is a simple helper lexer. If the first line of the text begins with
481 "In \[[0-9]+\]:", then the entire text is parsed with an IPython console
479 "In \[[0-9]+\]:", then the entire text is parsed with an IPython console
482 lexer. If not, then the entire text is parsed with an IPython lexer.
480 lexer. If not, then the entire text is parsed with an IPython lexer.
483
481
484 The goal is to reduce the number of lexers that are registered
482 The goal is to reduce the number of lexers that are registered
485 with Pygments.
483 with Pygments.
486
484
487 """
485 """
488 name = 'IPy session'
486 name = 'IPy session'
489 aliases = ['ipy']
487 aliases = ['ipy']
490
488
491 def __init__(self, **options):
489 def __init__(self, **options):
492 self.python3 = get_bool_opt(options, 'python3', False)
490 self.python3 = get_bool_opt(options, 'python3', False)
493 if self.python3:
491 if self.python3:
494 self.aliases = ['ipy3']
492 self.aliases = ['ipy3']
495 else:
493 else:
496 self.aliases = ['ipy2', 'ipy']
494 self.aliases = ['ipy2', 'ipy']
497
495
498 Lexer.__init__(self, **options)
496 Lexer.__init__(self, **options)
499
497
500 self.IPythonLexer = IPythonLexer(**options)
498 self.IPythonLexer = IPythonLexer(**options)
501 self.IPythonConsoleLexer = IPythonConsoleLexer(**options)
499 self.IPythonConsoleLexer = IPythonConsoleLexer(**options)
502
500
503 def get_tokens_unprocessed(self, text):
501 def get_tokens_unprocessed(self, text):
504 # Search for the input prompt anywhere...this allows code blocks to
502 # Search for the input prompt anywhere...this allows code blocks to
505 # begin with comments as well.
503 # begin with comments as well.
506 if re.match(r'.*(In \[[0-9]+\]:)', text.strip(), re.DOTALL):
504 if re.match(r'.*(In \[[0-9]+\]:)', text.strip(), re.DOTALL):
507 lex = self.IPythonConsoleLexer
505 lex = self.IPythonConsoleLexer
508 else:
506 else:
509 lex = self.IPythonLexer
507 lex = self.IPythonLexer
510 for token in lex.get_tokens_unprocessed(text):
508 for token in lex.get_tokens_unprocessed(text):
511 yield token
509 yield token
512
510
@@ -1,116 +1,114 b''
1 """
1 """
2 Password generation for the IPython notebook.
2 Password generation for the IPython notebook.
3 """
3 """
4 #-----------------------------------------------------------------------------
4 #-----------------------------------------------------------------------------
5 # Imports
5 # Imports
6 #-----------------------------------------------------------------------------
6 #-----------------------------------------------------------------------------
7 # Stdlib
7 # Stdlib
8 import getpass
8 import getpass
9 import hashlib
9 import hashlib
10 import random
10 import random
11
11
12 # Our own
12 # Our own
13 from IPython.core.error import UsageError
13 from IPython.core.error import UsageError
14 from IPython.testing.skipdoctest import skip_doctest
15 from IPython.utils.py3compat import cast_bytes, str_to_bytes
14 from IPython.utils.py3compat import cast_bytes, str_to_bytes
16
15
17 #-----------------------------------------------------------------------------
16 #-----------------------------------------------------------------------------
18 # Globals
17 # Globals
19 #-----------------------------------------------------------------------------
18 #-----------------------------------------------------------------------------
20
19
21 # Length of the salt in nr of hex chars, which implies salt_len * 4
20 # Length of the salt in nr of hex chars, which implies salt_len * 4
22 # bits of randomness.
21 # bits of randomness.
23 salt_len = 12
22 salt_len = 12
24
23
25 #-----------------------------------------------------------------------------
24 #-----------------------------------------------------------------------------
26 # Functions
25 # Functions
27 #-----------------------------------------------------------------------------
26 #-----------------------------------------------------------------------------
28
27
29 @skip_doctest
30 def passwd(passphrase=None, algorithm='sha1'):
28 def passwd(passphrase=None, algorithm='sha1'):
31 """Generate hashed password and salt for use in notebook configuration.
29 """Generate hashed password and salt for use in notebook configuration.
32
30
33 In the notebook configuration, set `c.NotebookApp.password` to
31 In the notebook configuration, set `c.NotebookApp.password` to
34 the generated string.
32 the generated string.
35
33
36 Parameters
34 Parameters
37 ----------
35 ----------
38 passphrase : str
36 passphrase : str
39 Password to hash. If unspecified, the user is asked to input
37 Password to hash. If unspecified, the user is asked to input
40 and verify a password.
38 and verify a password.
41 algorithm : str
39 algorithm : str
42 Hashing algorithm to use (e.g, 'sha1' or any argument supported
40 Hashing algorithm to use (e.g, 'sha1' or any argument supported
43 by :func:`hashlib.new`).
41 by :func:`hashlib.new`).
44
42
45 Returns
43 Returns
46 -------
44 -------
47 hashed_passphrase : str
45 hashed_passphrase : str
48 Hashed password, in the format 'hash_algorithm:salt:passphrase_hash'.
46 Hashed password, in the format 'hash_algorithm:salt:passphrase_hash'.
49
47
50 Examples
48 Examples
51 --------
49 --------
52 >>> passwd('mypassword')
50 >>> passwd('mypassword')
53 'sha1:7cf3:b7d6da294ea9592a9480c8f52e63cd42cfb9dd12'
51 'sha1:7cf3:b7d6da294ea9592a9480c8f52e63cd42cfb9dd12'
54
52
55 """
53 """
56 if passphrase is None:
54 if passphrase is None:
57 for i in range(3):
55 for i in range(3):
58 p0 = getpass.getpass('Enter password: ')
56 p0 = getpass.getpass('Enter password: ')
59 p1 = getpass.getpass('Verify password: ')
57 p1 = getpass.getpass('Verify password: ')
60 if p0 == p1:
58 if p0 == p1:
61 passphrase = p0
59 passphrase = p0
62 break
60 break
63 else:
61 else:
64 print('Passwords do not match.')
62 print('Passwords do not match.')
65 else:
63 else:
66 raise UsageError('No matching passwords found. Giving up.')
64 raise UsageError('No matching passwords found. Giving up.')
67
65
68 h = hashlib.new(algorithm)
66 h = hashlib.new(algorithm)
69 salt = ('%0' + str(salt_len) + 'x') % random.getrandbits(4 * salt_len)
67 salt = ('%0' + str(salt_len) + 'x') % random.getrandbits(4 * salt_len)
70 h.update(cast_bytes(passphrase, 'utf-8') + str_to_bytes(salt, 'ascii'))
68 h.update(cast_bytes(passphrase, 'utf-8') + str_to_bytes(salt, 'ascii'))
71
69
72 return ':'.join((algorithm, salt, h.hexdigest()))
70 return ':'.join((algorithm, salt, h.hexdigest()))
73
71
74
72
75 def passwd_check(hashed_passphrase, passphrase):
73 def passwd_check(hashed_passphrase, passphrase):
76 """Verify that a given passphrase matches its hashed version.
74 """Verify that a given passphrase matches its hashed version.
77
75
78 Parameters
76 Parameters
79 ----------
77 ----------
80 hashed_passphrase : str
78 hashed_passphrase : str
81 Hashed password, in the format returned by `passwd`.
79 Hashed password, in the format returned by `passwd`.
82 passphrase : str
80 passphrase : str
83 Passphrase to validate.
81 Passphrase to validate.
84
82
85 Returns
83 Returns
86 -------
84 -------
87 valid : bool
85 valid : bool
88 True if the passphrase matches the hash.
86 True if the passphrase matches the hash.
89
87
90 Examples
88 Examples
91 --------
89 --------
92 >>> from IPython.lib.security import passwd_check
90 >>> from IPython.lib.security import passwd_check
93 >>> passwd_check('sha1:0e112c3ddfce:a68df677475c2b47b6e86d0467eec97ac5f4b85a',
91 >>> passwd_check('sha1:0e112c3ddfce:a68df677475c2b47b6e86d0467eec97ac5f4b85a',
94 ... 'mypassword')
92 ... 'mypassword')
95 True
93 True
96
94
97 >>> passwd_check('sha1:0e112c3ddfce:a68df677475c2b47b6e86d0467eec97ac5f4b85a',
95 >>> passwd_check('sha1:0e112c3ddfce:a68df677475c2b47b6e86d0467eec97ac5f4b85a',
98 ... 'anotherpassword')
96 ... 'anotherpassword')
99 False
97 False
100 """
98 """
101 try:
99 try:
102 algorithm, salt, pw_digest = hashed_passphrase.split(':', 2)
100 algorithm, salt, pw_digest = hashed_passphrase.split(':', 2)
103 except (ValueError, TypeError):
101 except (ValueError, TypeError):
104 return False
102 return False
105
103
106 try:
104 try:
107 h = hashlib.new(algorithm)
105 h = hashlib.new(algorithm)
108 except ValueError:
106 except ValueError:
109 return False
107 return False
110
108
111 if len(pw_digest) == 0:
109 if len(pw_digest) == 0:
112 return False
110 return False
113
111
114 h.update(cast_bytes(passphrase, 'utf-8') + cast_bytes(salt, 'ascii'))
112 h.update(cast_bytes(passphrase, 'utf-8') + cast_bytes(salt, 'ascii'))
115
113
116 return h.hexdigest() == pw_digest
114 return h.hexdigest() == pw_digest
@@ -1,642 +1,640 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2 """Subclass of InteractiveShell for terminal based frontends."""
2 """Subclass of InteractiveShell for terminal based frontends."""
3
3
4 #-----------------------------------------------------------------------------
4 #-----------------------------------------------------------------------------
5 # Copyright (C) 2001 Janko Hauser <jhauser@zscout.de>
5 # Copyright (C) 2001 Janko Hauser <jhauser@zscout.de>
6 # Copyright (C) 2001-2007 Fernando Perez. <fperez@colorado.edu>
6 # Copyright (C) 2001-2007 Fernando Perez. <fperez@colorado.edu>
7 # Copyright (C) 2008-2011 The IPython Development Team
7 # Copyright (C) 2008-2011 The IPython Development Team
8 #
8 #
9 # Distributed under the terms of the BSD License. The full license is in
9 # Distributed under the terms of the BSD License. The full license is in
10 # the file COPYING, distributed as part of this software.
10 # the file COPYING, distributed as part of this software.
11 #-----------------------------------------------------------------------------
11 #-----------------------------------------------------------------------------
12
12
13 #-----------------------------------------------------------------------------
13 #-----------------------------------------------------------------------------
14 # Imports
14 # Imports
15 #-----------------------------------------------------------------------------
15 #-----------------------------------------------------------------------------
16 from __future__ import print_function
16 from __future__ import print_function
17
17
18 import bdb
18 import bdb
19 import os
19 import os
20 import sys
20 import sys
21
21
22 from IPython.core.error import TryNext, UsageError
22 from IPython.core.error import TryNext, UsageError
23 from IPython.core.usage import interactive_usage
23 from IPython.core.usage import interactive_usage
24 from IPython.core.inputsplitter import IPythonInputSplitter
24 from IPython.core.inputsplitter import IPythonInputSplitter
25 from IPython.core.interactiveshell import InteractiveShell, InteractiveShellABC
25 from IPython.core.interactiveshell import InteractiveShell, InteractiveShellABC
26 from IPython.core.magic import Magics, magics_class, line_magic
26 from IPython.core.magic import Magics, magics_class, line_magic
27 from IPython.lib.clipboard import ClipboardEmpty
27 from IPython.lib.clipboard import ClipboardEmpty
28 from IPython.testing.skipdoctest import skip_doctest
29 from IPython.utils.encoding import get_stream_enc
28 from IPython.utils.encoding import get_stream_enc
30 from IPython.utils import py3compat
29 from IPython.utils import py3compat
31 from IPython.utils.terminal import toggle_set_term_title, set_term_title
30 from IPython.utils.terminal import toggle_set_term_title, set_term_title
32 from IPython.utils.process import abbrev_cwd
31 from IPython.utils.process import abbrev_cwd
33 from IPython.utils.warn import warn, error
32 from IPython.utils.warn import warn, error
34 from IPython.utils.text import num_ini_spaces, SList, strip_email_quotes
33 from IPython.utils.text import num_ini_spaces, SList, strip_email_quotes
35 from IPython.utils.traitlets import Integer, CBool, Unicode
34 from IPython.utils.traitlets import Integer, CBool, Unicode
36
35
37 #-----------------------------------------------------------------------------
36 #-----------------------------------------------------------------------------
38 # Utilities
37 # Utilities
39 #-----------------------------------------------------------------------------
38 #-----------------------------------------------------------------------------
40
39
41 def get_default_editor():
40 def get_default_editor():
42 try:
41 try:
43 ed = os.environ['EDITOR']
42 ed = os.environ['EDITOR']
44 if not py3compat.PY3:
43 if not py3compat.PY3:
45 ed = ed.decode()
44 ed = ed.decode()
46 return ed
45 return ed
47 except KeyError:
46 except KeyError:
48 pass
47 pass
49 except UnicodeError:
48 except UnicodeError:
50 warn("$EDITOR environment variable is not pure ASCII. Using platform "
49 warn("$EDITOR environment variable is not pure ASCII. Using platform "
51 "default editor.")
50 "default editor.")
52
51
53 if os.name == 'posix':
52 if os.name == 'posix':
54 return 'vi' # the only one guaranteed to be there!
53 return 'vi' # the only one guaranteed to be there!
55 else:
54 else:
56 return 'notepad' # same in Windows!
55 return 'notepad' # same in Windows!
57
56
58 def get_pasted_lines(sentinel, l_input=py3compat.input, quiet=False):
57 def get_pasted_lines(sentinel, l_input=py3compat.input, quiet=False):
59 """ Yield pasted lines until the user enters the given sentinel value.
58 """ Yield pasted lines until the user enters the given sentinel value.
60 """
59 """
61 if not quiet:
60 if not quiet:
62 print("Pasting code; enter '%s' alone on the line to stop or use Ctrl-D." \
61 print("Pasting code; enter '%s' alone on the line to stop or use Ctrl-D." \
63 % sentinel)
62 % sentinel)
64 prompt = ":"
63 prompt = ":"
65 else:
64 else:
66 prompt = ""
65 prompt = ""
67 while True:
66 while True:
68 try:
67 try:
69 l = py3compat.str_to_unicode(l_input(prompt))
68 l = py3compat.str_to_unicode(l_input(prompt))
70 if l == sentinel:
69 if l == sentinel:
71 return
70 return
72 else:
71 else:
73 yield l
72 yield l
74 except EOFError:
73 except EOFError:
75 print('<EOF>')
74 print('<EOF>')
76 return
75 return
77
76
78
77
79 #------------------------------------------------------------------------
78 #------------------------------------------------------------------------
80 # Terminal-specific magics
79 # Terminal-specific magics
81 #------------------------------------------------------------------------
80 #------------------------------------------------------------------------
82
81
83 @magics_class
82 @magics_class
84 class TerminalMagics(Magics):
83 class TerminalMagics(Magics):
85 def __init__(self, shell):
84 def __init__(self, shell):
86 super(TerminalMagics, self).__init__(shell)
85 super(TerminalMagics, self).__init__(shell)
87 self.input_splitter = IPythonInputSplitter()
86 self.input_splitter = IPythonInputSplitter()
88
87
89 def store_or_execute(self, block, name):
88 def store_or_execute(self, block, name):
90 """ Execute a block, or store it in a variable, per the user's request.
89 """ Execute a block, or store it in a variable, per the user's request.
91 """
90 """
92 if name:
91 if name:
93 # If storing it for further editing
92 # If storing it for further editing
94 self.shell.user_ns[name] = SList(block.splitlines())
93 self.shell.user_ns[name] = SList(block.splitlines())
95 print("Block assigned to '%s'" % name)
94 print("Block assigned to '%s'" % name)
96 else:
95 else:
97 b = self.preclean_input(block)
96 b = self.preclean_input(block)
98 self.shell.user_ns['pasted_block'] = b
97 self.shell.user_ns['pasted_block'] = b
99 self.shell.using_paste_magics = True
98 self.shell.using_paste_magics = True
100 try:
99 try:
101 self.shell.run_cell(b)
100 self.shell.run_cell(b)
102 finally:
101 finally:
103 self.shell.using_paste_magics = False
102 self.shell.using_paste_magics = False
104
103
105 def preclean_input(self, block):
104 def preclean_input(self, block):
106 lines = block.splitlines()
105 lines = block.splitlines()
107 while lines and not lines[0].strip():
106 while lines and not lines[0].strip():
108 lines = lines[1:]
107 lines = lines[1:]
109 return strip_email_quotes('\n'.join(lines))
108 return strip_email_quotes('\n'.join(lines))
110
109
111 def rerun_pasted(self, name='pasted_block'):
110 def rerun_pasted(self, name='pasted_block'):
112 """ Rerun a previously pasted command.
111 """ Rerun a previously pasted command.
113 """
112 """
114 b = self.shell.user_ns.get(name)
113 b = self.shell.user_ns.get(name)
115
114
116 # Sanity checks
115 # Sanity checks
117 if b is None:
116 if b is None:
118 raise UsageError('No previous pasted block available')
117 raise UsageError('No previous pasted block available')
119 if not isinstance(b, py3compat.string_types):
118 if not isinstance(b, py3compat.string_types):
120 raise UsageError(
119 raise UsageError(
121 "Variable 'pasted_block' is not a string, can't execute")
120 "Variable 'pasted_block' is not a string, can't execute")
122
121
123 print("Re-executing '%s...' (%d chars)"% (b.split('\n',1)[0], len(b)))
122 print("Re-executing '%s...' (%d chars)"% (b.split('\n',1)[0], len(b)))
124 self.shell.run_cell(b)
123 self.shell.run_cell(b)
125
124
126 @line_magic
125 @line_magic
127 def autoindent(self, parameter_s = ''):
126 def autoindent(self, parameter_s = ''):
128 """Toggle autoindent on/off (if available)."""
127 """Toggle autoindent on/off (if available)."""
129
128
130 self.shell.set_autoindent()
129 self.shell.set_autoindent()
131 print("Automatic indentation is:",['OFF','ON'][self.shell.autoindent])
130 print("Automatic indentation is:",['OFF','ON'][self.shell.autoindent])
132
131
133 @skip_doctest
134 @line_magic
132 @line_magic
135 def cpaste(self, parameter_s=''):
133 def cpaste(self, parameter_s=''):
136 """Paste & execute a pre-formatted code block from clipboard.
134 """Paste & execute a pre-formatted code block from clipboard.
137
135
138 You must terminate the block with '--' (two minus-signs) or Ctrl-D
136 You must terminate the block with '--' (two minus-signs) or Ctrl-D
139 alone on the line. You can also provide your own sentinel with '%paste
137 alone on the line. You can also provide your own sentinel with '%paste
140 -s %%' ('%%' is the new sentinel for this operation).
138 -s %%' ('%%' is the new sentinel for this operation).
141
139
142 The block is dedented prior to execution to enable execution of method
140 The block is dedented prior to execution to enable execution of method
143 definitions. '>' and '+' characters at the beginning of a line are
141 definitions. '>' and '+' characters at the beginning of a line are
144 ignored, to allow pasting directly from e-mails, diff files and
142 ignored, to allow pasting directly from e-mails, diff files and
145 doctests (the '...' continuation prompt is also stripped). The
143 doctests (the '...' continuation prompt is also stripped). The
146 executed block is also assigned to variable named 'pasted_block' for
144 executed block is also assigned to variable named 'pasted_block' for
147 later editing with '%edit pasted_block'.
145 later editing with '%edit pasted_block'.
148
146
149 You can also pass a variable name as an argument, e.g. '%cpaste foo'.
147 You can also pass a variable name as an argument, e.g. '%cpaste foo'.
150 This assigns the pasted block to variable 'foo' as string, without
148 This assigns the pasted block to variable 'foo' as string, without
151 dedenting or executing it (preceding >>> and + is still stripped)
149 dedenting or executing it (preceding >>> and + is still stripped)
152
150
153 '%cpaste -r' re-executes the block previously entered by cpaste.
151 '%cpaste -r' re-executes the block previously entered by cpaste.
154 '%cpaste -q' suppresses any additional output messages.
152 '%cpaste -q' suppresses any additional output messages.
155
153
156 Do not be alarmed by garbled output on Windows (it's a readline bug).
154 Do not be alarmed by garbled output on Windows (it's a readline bug).
157 Just press enter and type -- (and press enter again) and the block
155 Just press enter and type -- (and press enter again) and the block
158 will be what was just pasted.
156 will be what was just pasted.
159
157
160 IPython statements (magics, shell escapes) are not supported (yet).
158 IPython statements (magics, shell escapes) are not supported (yet).
161
159
162 See also
160 See also
163 --------
161 --------
164 paste: automatically pull code from clipboard.
162 paste: automatically pull code from clipboard.
165
163
166 Examples
164 Examples
167 --------
165 --------
168 ::
166 ::
169
167
170 In [8]: %cpaste
168 In [8]: %cpaste
171 Pasting code; enter '--' alone on the line to stop.
169 Pasting code; enter '--' alone on the line to stop.
172 :>>> a = ["world!", "Hello"]
170 :>>> a = ["world!", "Hello"]
173 :>>> print " ".join(sorted(a))
171 :>>> print " ".join(sorted(a))
174 :--
172 :--
175 Hello world!
173 Hello world!
176 """
174 """
177 opts, name = self.parse_options(parameter_s, 'rqs:', mode='string')
175 opts, name = self.parse_options(parameter_s, 'rqs:', mode='string')
178 if 'r' in opts:
176 if 'r' in opts:
179 self.rerun_pasted()
177 self.rerun_pasted()
180 return
178 return
181
179
182 quiet = ('q' in opts)
180 quiet = ('q' in opts)
183
181
184 sentinel = opts.get('s', u'--')
182 sentinel = opts.get('s', u'--')
185 block = '\n'.join(get_pasted_lines(sentinel, quiet=quiet))
183 block = '\n'.join(get_pasted_lines(sentinel, quiet=quiet))
186 self.store_or_execute(block, name)
184 self.store_or_execute(block, name)
187
185
188 @line_magic
186 @line_magic
189 def paste(self, parameter_s=''):
187 def paste(self, parameter_s=''):
190 """Paste & execute a pre-formatted code block from clipboard.
188 """Paste & execute a pre-formatted code block from clipboard.
191
189
192 The text is pulled directly from the clipboard without user
190 The text is pulled directly from the clipboard without user
193 intervention and printed back on the screen before execution (unless
191 intervention and printed back on the screen before execution (unless
194 the -q flag is given to force quiet mode).
192 the -q flag is given to force quiet mode).
195
193
196 The block is dedented prior to execution to enable execution of method
194 The block is dedented prior to execution to enable execution of method
197 definitions. '>' and '+' characters at the beginning of a line are
195 definitions. '>' and '+' characters at the beginning of a line are
198 ignored, to allow pasting directly from e-mails, diff files and
196 ignored, to allow pasting directly from e-mails, diff files and
199 doctests (the '...' continuation prompt is also stripped). The
197 doctests (the '...' continuation prompt is also stripped). The
200 executed block is also assigned to variable named 'pasted_block' for
198 executed block is also assigned to variable named 'pasted_block' for
201 later editing with '%edit pasted_block'.
199 later editing with '%edit pasted_block'.
202
200
203 You can also pass a variable name as an argument, e.g. '%paste foo'.
201 You can also pass a variable name as an argument, e.g. '%paste foo'.
204 This assigns the pasted block to variable 'foo' as string, without
202 This assigns the pasted block to variable 'foo' as string, without
205 executing it (preceding >>> and + is still stripped).
203 executing it (preceding >>> and + is still stripped).
206
204
207 Options:
205 Options:
208
206
209 -r: re-executes the block previously entered by cpaste.
207 -r: re-executes the block previously entered by cpaste.
210
208
211 -q: quiet mode: do not echo the pasted text back to the terminal.
209 -q: quiet mode: do not echo the pasted text back to the terminal.
212
210
213 IPython statements (magics, shell escapes) are not supported (yet).
211 IPython statements (magics, shell escapes) are not supported (yet).
214
212
215 See also
213 See also
216 --------
214 --------
217 cpaste: manually paste code into terminal until you mark its end.
215 cpaste: manually paste code into terminal until you mark its end.
218 """
216 """
219 opts, name = self.parse_options(parameter_s, 'rq', mode='string')
217 opts, name = self.parse_options(parameter_s, 'rq', mode='string')
220 if 'r' in opts:
218 if 'r' in opts:
221 self.rerun_pasted()
219 self.rerun_pasted()
222 return
220 return
223 try:
221 try:
224 block = self.shell.hooks.clipboard_get()
222 block = self.shell.hooks.clipboard_get()
225 except TryNext as clipboard_exc:
223 except TryNext as clipboard_exc:
226 message = getattr(clipboard_exc, 'args')
224 message = getattr(clipboard_exc, 'args')
227 if message:
225 if message:
228 error(message[0])
226 error(message[0])
229 else:
227 else:
230 error('Could not get text from the clipboard.')
228 error('Could not get text from the clipboard.')
231 return
229 return
232 except ClipboardEmpty:
230 except ClipboardEmpty:
233 raise UsageError("The clipboard appears to be empty")
231 raise UsageError("The clipboard appears to be empty")
234
232
235 # By default, echo back to terminal unless quiet mode is requested
233 # By default, echo back to terminal unless quiet mode is requested
236 if 'q' not in opts:
234 if 'q' not in opts:
237 write = self.shell.write
235 write = self.shell.write
238 write(self.shell.pycolorize(block))
236 write(self.shell.pycolorize(block))
239 if not block.endswith('\n'):
237 if not block.endswith('\n'):
240 write('\n')
238 write('\n')
241 write("## -- End pasted text --\n")
239 write("## -- End pasted text --\n")
242
240
243 self.store_or_execute(block, name)
241 self.store_or_execute(block, name)
244
242
245 # Class-level: add a '%cls' magic only on Windows
243 # Class-level: add a '%cls' magic only on Windows
246 if sys.platform == 'win32':
244 if sys.platform == 'win32':
247 @line_magic
245 @line_magic
248 def cls(self, s):
246 def cls(self, s):
249 """Clear screen.
247 """Clear screen.
250 """
248 """
251 os.system("cls")
249 os.system("cls")
252
250
253 #-----------------------------------------------------------------------------
251 #-----------------------------------------------------------------------------
254 # Main class
252 # Main class
255 #-----------------------------------------------------------------------------
253 #-----------------------------------------------------------------------------
256
254
257 class TerminalInteractiveShell(InteractiveShell):
255 class TerminalInteractiveShell(InteractiveShell):
258
256
259 autoedit_syntax = CBool(False, config=True,
257 autoedit_syntax = CBool(False, config=True,
260 help="auto editing of files with syntax errors.")
258 help="auto editing of files with syntax errors.")
261 confirm_exit = CBool(True, config=True,
259 confirm_exit = CBool(True, config=True,
262 help="""
260 help="""
263 Set to confirm when you try to exit IPython with an EOF (Control-D
261 Set to confirm when you try to exit IPython with an EOF (Control-D
264 in Unix, Control-Z/Enter in Windows). By typing 'exit' or 'quit',
262 in Unix, Control-Z/Enter in Windows). By typing 'exit' or 'quit',
265 you can force a direct exit without any confirmation.""",
263 you can force a direct exit without any confirmation.""",
266 )
264 )
267 # This display_banner only controls whether or not self.show_banner()
265 # This display_banner only controls whether or not self.show_banner()
268 # is called when mainloop/interact are called. The default is False
266 # is called when mainloop/interact are called. The default is False
269 # because for the terminal based application, the banner behavior
267 # because for the terminal based application, the banner behavior
270 # is controlled by the application.
268 # is controlled by the application.
271 display_banner = CBool(False) # This isn't configurable!
269 display_banner = CBool(False) # This isn't configurable!
272 embedded = CBool(False)
270 embedded = CBool(False)
273 embedded_active = CBool(False)
271 embedded_active = CBool(False)
274 editor = Unicode(get_default_editor(), config=True,
272 editor = Unicode(get_default_editor(), config=True,
275 help="Set the editor used by IPython (default to $EDITOR/vi/notepad)."
273 help="Set the editor used by IPython (default to $EDITOR/vi/notepad)."
276 )
274 )
277 pager = Unicode('less', config=True,
275 pager = Unicode('less', config=True,
278 help="The shell program to be used for paging.")
276 help="The shell program to be used for paging.")
279
277
280 screen_length = Integer(0, config=True,
278 screen_length = Integer(0, config=True,
281 help=
279 help=
282 """Number of lines of your screen, used to control printing of very
280 """Number of lines of your screen, used to control printing of very
283 long strings. Strings longer than this number of lines will be sent
281 long strings. Strings longer than this number of lines will be sent
284 through a pager instead of directly printed. The default value for
282 through a pager instead of directly printed. The default value for
285 this is 0, which means IPython will auto-detect your screen size every
283 this is 0, which means IPython will auto-detect your screen size every
286 time it needs to print certain potentially long strings (this doesn't
284 time it needs to print certain potentially long strings (this doesn't
287 change the behavior of the 'print' keyword, it's only triggered
285 change the behavior of the 'print' keyword, it's only triggered
288 internally). If for some reason this isn't working well (it needs
286 internally). If for some reason this isn't working well (it needs
289 curses support), specify it yourself. Otherwise don't change the
287 curses support), specify it yourself. Otherwise don't change the
290 default.""",
288 default.""",
291 )
289 )
292 term_title = CBool(False, config=True,
290 term_title = CBool(False, config=True,
293 help="Enable auto setting the terminal title."
291 help="Enable auto setting the terminal title."
294 )
292 )
295 usage = Unicode(interactive_usage)
293 usage = Unicode(interactive_usage)
296
294
297 # This `using_paste_magics` is used to detect whether the code is being
295 # This `using_paste_magics` is used to detect whether the code is being
298 # executed via paste magics functions
296 # executed via paste magics functions
299 using_paste_magics = CBool(False)
297 using_paste_magics = CBool(False)
300
298
301 # In the terminal, GUI control is done via PyOS_InputHook
299 # In the terminal, GUI control is done via PyOS_InputHook
302 @staticmethod
300 @staticmethod
303 def enable_gui(gui=None, app=None):
301 def enable_gui(gui=None, app=None):
304 """Switch amongst GUI input hooks by name.
302 """Switch amongst GUI input hooks by name.
305 """
303 """
306 # Deferred import
304 # Deferred import
307 from IPython.lib.inputhook import enable_gui as real_enable_gui
305 from IPython.lib.inputhook import enable_gui as real_enable_gui
308 try:
306 try:
309 return real_enable_gui(gui, app)
307 return real_enable_gui(gui, app)
310 except ValueError as e:
308 except ValueError as e:
311 raise UsageError("%s" % e)
309 raise UsageError("%s" % e)
312
310
313 system = InteractiveShell.system_raw
311 system = InteractiveShell.system_raw
314
312
315 #-------------------------------------------------------------------------
313 #-------------------------------------------------------------------------
316 # Overrides of init stages
314 # Overrides of init stages
317 #-------------------------------------------------------------------------
315 #-------------------------------------------------------------------------
318
316
319 def init_display_formatter(self):
317 def init_display_formatter(self):
320 super(TerminalInteractiveShell, self).init_display_formatter()
318 super(TerminalInteractiveShell, self).init_display_formatter()
321 # terminal only supports plaintext
319 # terminal only supports plaintext
322 self.display_formatter.active_types = ['text/plain']
320 self.display_formatter.active_types = ['text/plain']
323
321
324 #-------------------------------------------------------------------------
322 #-------------------------------------------------------------------------
325 # Things related to the terminal
323 # Things related to the terminal
326 #-------------------------------------------------------------------------
324 #-------------------------------------------------------------------------
327
325
328 @property
326 @property
329 def usable_screen_length(self):
327 def usable_screen_length(self):
330 if self.screen_length == 0:
328 if self.screen_length == 0:
331 return 0
329 return 0
332 else:
330 else:
333 num_lines_bot = self.separate_in.count('\n')+1
331 num_lines_bot = self.separate_in.count('\n')+1
334 return self.screen_length - num_lines_bot
332 return self.screen_length - num_lines_bot
335
333
336 def _term_title_changed(self, name, new_value):
334 def _term_title_changed(self, name, new_value):
337 self.init_term_title()
335 self.init_term_title()
338
336
339 def init_term_title(self):
337 def init_term_title(self):
340 # Enable or disable the terminal title.
338 # Enable or disable the terminal title.
341 if self.term_title:
339 if self.term_title:
342 toggle_set_term_title(True)
340 toggle_set_term_title(True)
343 set_term_title('IPython: ' + abbrev_cwd())
341 set_term_title('IPython: ' + abbrev_cwd())
344 else:
342 else:
345 toggle_set_term_title(False)
343 toggle_set_term_title(False)
346
344
347 #-------------------------------------------------------------------------
345 #-------------------------------------------------------------------------
348 # Things related to aliases
346 # Things related to aliases
349 #-------------------------------------------------------------------------
347 #-------------------------------------------------------------------------
350
348
351 def init_alias(self):
349 def init_alias(self):
352 # The parent class defines aliases that can be safely used with any
350 # The parent class defines aliases that can be safely used with any
353 # frontend.
351 # frontend.
354 super(TerminalInteractiveShell, self).init_alias()
352 super(TerminalInteractiveShell, self).init_alias()
355
353
356 # Now define aliases that only make sense on the terminal, because they
354 # Now define aliases that only make sense on the terminal, because they
357 # need direct access to the console in a way that we can't emulate in
355 # need direct access to the console in a way that we can't emulate in
358 # GUI or web frontend
356 # GUI or web frontend
359 if os.name == 'posix':
357 if os.name == 'posix':
360 aliases = [('clear', 'clear'), ('more', 'more'), ('less', 'less'),
358 aliases = [('clear', 'clear'), ('more', 'more'), ('less', 'less'),
361 ('man', 'man')]
359 ('man', 'man')]
362 else :
360 else :
363 aliases = []
361 aliases = []
364
362
365 for name, cmd in aliases:
363 for name, cmd in aliases:
366 self.alias_manager.soft_define_alias(name, cmd)
364 self.alias_manager.soft_define_alias(name, cmd)
367
365
368 #-------------------------------------------------------------------------
366 #-------------------------------------------------------------------------
369 # Mainloop and code execution logic
367 # Mainloop and code execution logic
370 #-------------------------------------------------------------------------
368 #-------------------------------------------------------------------------
371
369
372 def mainloop(self, display_banner=None):
370 def mainloop(self, display_banner=None):
373 """Start the mainloop.
371 """Start the mainloop.
374
372
375 If an optional banner argument is given, it will override the
373 If an optional banner argument is given, it will override the
376 internally created default banner.
374 internally created default banner.
377 """
375 """
378
376
379 with self.builtin_trap, self.display_trap:
377 with self.builtin_trap, self.display_trap:
380
378
381 while 1:
379 while 1:
382 try:
380 try:
383 self.interact(display_banner=display_banner)
381 self.interact(display_banner=display_banner)
384 #self.interact_with_readline()
382 #self.interact_with_readline()
385 # XXX for testing of a readline-decoupled repl loop, call
383 # XXX for testing of a readline-decoupled repl loop, call
386 # interact_with_readline above
384 # interact_with_readline above
387 break
385 break
388 except KeyboardInterrupt:
386 except KeyboardInterrupt:
389 # this should not be necessary, but KeyboardInterrupt
387 # this should not be necessary, but KeyboardInterrupt
390 # handling seems rather unpredictable...
388 # handling seems rather unpredictable...
391 self.write("\nKeyboardInterrupt in interact()\n")
389 self.write("\nKeyboardInterrupt in interact()\n")
392
390
393 def _replace_rlhist_multiline(self, source_raw, hlen_before_cell):
391 def _replace_rlhist_multiline(self, source_raw, hlen_before_cell):
394 """Store multiple lines as a single entry in history"""
392 """Store multiple lines as a single entry in history"""
395
393
396 # do nothing without readline or disabled multiline
394 # do nothing without readline or disabled multiline
397 if not self.has_readline or not self.multiline_history:
395 if not self.has_readline or not self.multiline_history:
398 return hlen_before_cell
396 return hlen_before_cell
399
397
400 # windows rl has no remove_history_item
398 # windows rl has no remove_history_item
401 if not hasattr(self.readline, "remove_history_item"):
399 if not hasattr(self.readline, "remove_history_item"):
402 return hlen_before_cell
400 return hlen_before_cell
403
401
404 # skip empty cells
402 # skip empty cells
405 if not source_raw.rstrip():
403 if not source_raw.rstrip():
406 return hlen_before_cell
404 return hlen_before_cell
407
405
408 # nothing changed do nothing, e.g. when rl removes consecutive dups
406 # nothing changed do nothing, e.g. when rl removes consecutive dups
409 hlen = self.readline.get_current_history_length()
407 hlen = self.readline.get_current_history_length()
410 if hlen == hlen_before_cell:
408 if hlen == hlen_before_cell:
411 return hlen_before_cell
409 return hlen_before_cell
412
410
413 for i in range(hlen - hlen_before_cell):
411 for i in range(hlen - hlen_before_cell):
414 self.readline.remove_history_item(hlen - i - 1)
412 self.readline.remove_history_item(hlen - i - 1)
415 stdin_encoding = get_stream_enc(sys.stdin, 'utf-8')
413 stdin_encoding = get_stream_enc(sys.stdin, 'utf-8')
416 self.readline.add_history(py3compat.unicode_to_str(source_raw.rstrip(),
414 self.readline.add_history(py3compat.unicode_to_str(source_raw.rstrip(),
417 stdin_encoding))
415 stdin_encoding))
418 return self.readline.get_current_history_length()
416 return self.readline.get_current_history_length()
419
417
420 def interact(self, display_banner=None):
418 def interact(self, display_banner=None):
421 """Closely emulate the interactive Python console."""
419 """Closely emulate the interactive Python console."""
422
420
423 # batch run -> do not interact
421 # batch run -> do not interact
424 if self.exit_now:
422 if self.exit_now:
425 return
423 return
426
424
427 if display_banner is None:
425 if display_banner is None:
428 display_banner = self.display_banner
426 display_banner = self.display_banner
429
427
430 if isinstance(display_banner, py3compat.string_types):
428 if isinstance(display_banner, py3compat.string_types):
431 self.show_banner(display_banner)
429 self.show_banner(display_banner)
432 elif display_banner:
430 elif display_banner:
433 self.show_banner()
431 self.show_banner()
434
432
435 more = False
433 more = False
436
434
437 if self.has_readline:
435 if self.has_readline:
438 self.readline_startup_hook(self.pre_readline)
436 self.readline_startup_hook(self.pre_readline)
439 hlen_b4_cell = self.readline.get_current_history_length()
437 hlen_b4_cell = self.readline.get_current_history_length()
440 else:
438 else:
441 hlen_b4_cell = 0
439 hlen_b4_cell = 0
442 # exit_now is set by a call to %Exit or %Quit, through the
440 # exit_now is set by a call to %Exit or %Quit, through the
443 # ask_exit callback.
441 # ask_exit callback.
444
442
445 while not self.exit_now:
443 while not self.exit_now:
446 self.hooks.pre_prompt_hook()
444 self.hooks.pre_prompt_hook()
447 if more:
445 if more:
448 try:
446 try:
449 prompt = self.prompt_manager.render('in2')
447 prompt = self.prompt_manager.render('in2')
450 except:
448 except:
451 self.showtraceback()
449 self.showtraceback()
452 if self.autoindent:
450 if self.autoindent:
453 self.rl_do_indent = True
451 self.rl_do_indent = True
454
452
455 else:
453 else:
456 try:
454 try:
457 prompt = self.separate_in + self.prompt_manager.render('in')
455 prompt = self.separate_in + self.prompt_manager.render('in')
458 except:
456 except:
459 self.showtraceback()
457 self.showtraceback()
460 try:
458 try:
461 line = self.raw_input(prompt)
459 line = self.raw_input(prompt)
462 if self.exit_now:
460 if self.exit_now:
463 # quick exit on sys.std[in|out] close
461 # quick exit on sys.std[in|out] close
464 break
462 break
465 if self.autoindent:
463 if self.autoindent:
466 self.rl_do_indent = False
464 self.rl_do_indent = False
467
465
468 except KeyboardInterrupt:
466 except KeyboardInterrupt:
469 #double-guard against keyboardinterrupts during kbdint handling
467 #double-guard against keyboardinterrupts during kbdint handling
470 try:
468 try:
471 self.write('\n' + self.get_exception_only())
469 self.write('\n' + self.get_exception_only())
472 source_raw = self.input_splitter.raw_reset()
470 source_raw = self.input_splitter.raw_reset()
473 hlen_b4_cell = \
471 hlen_b4_cell = \
474 self._replace_rlhist_multiline(source_raw, hlen_b4_cell)
472 self._replace_rlhist_multiline(source_raw, hlen_b4_cell)
475 more = False
473 more = False
476 except KeyboardInterrupt:
474 except KeyboardInterrupt:
477 pass
475 pass
478 except EOFError:
476 except EOFError:
479 if self.autoindent:
477 if self.autoindent:
480 self.rl_do_indent = False
478 self.rl_do_indent = False
481 if self.has_readline:
479 if self.has_readline:
482 self.readline_startup_hook(None)
480 self.readline_startup_hook(None)
483 self.write('\n')
481 self.write('\n')
484 self.exit()
482 self.exit()
485 except bdb.BdbQuit:
483 except bdb.BdbQuit:
486 warn('The Python debugger has exited with a BdbQuit exception.\n'
484 warn('The Python debugger has exited with a BdbQuit exception.\n'
487 'Because of how pdb handles the stack, it is impossible\n'
485 'Because of how pdb handles the stack, it is impossible\n'
488 'for IPython to properly format this particular exception.\n'
486 'for IPython to properly format this particular exception.\n'
489 'IPython will resume normal operation.')
487 'IPython will resume normal operation.')
490 except:
488 except:
491 # exceptions here are VERY RARE, but they can be triggered
489 # exceptions here are VERY RARE, but they can be triggered
492 # asynchronously by signal handlers, for example.
490 # asynchronously by signal handlers, for example.
493 self.showtraceback()
491 self.showtraceback()
494 else:
492 else:
495 try:
493 try:
496 self.input_splitter.push(line)
494 self.input_splitter.push(line)
497 more = self.input_splitter.push_accepts_more()
495 more = self.input_splitter.push_accepts_more()
498 except SyntaxError:
496 except SyntaxError:
499 # Run the code directly - run_cell takes care of displaying
497 # Run the code directly - run_cell takes care of displaying
500 # the exception.
498 # the exception.
501 more = False
499 more = False
502 if (self.SyntaxTB.last_syntax_error and
500 if (self.SyntaxTB.last_syntax_error and
503 self.autoedit_syntax):
501 self.autoedit_syntax):
504 self.edit_syntax_error()
502 self.edit_syntax_error()
505 if not more:
503 if not more:
506 source_raw = self.input_splitter.raw_reset()
504 source_raw = self.input_splitter.raw_reset()
507 self.run_cell(source_raw, store_history=True)
505 self.run_cell(source_raw, store_history=True)
508 hlen_b4_cell = \
506 hlen_b4_cell = \
509 self._replace_rlhist_multiline(source_raw, hlen_b4_cell)
507 self._replace_rlhist_multiline(source_raw, hlen_b4_cell)
510
508
511 # Turn off the exit flag, so the mainloop can be restarted if desired
509 # Turn off the exit flag, so the mainloop can be restarted if desired
512 self.exit_now = False
510 self.exit_now = False
513
511
514 def raw_input(self, prompt=''):
512 def raw_input(self, prompt=''):
515 """Write a prompt and read a line.
513 """Write a prompt and read a line.
516
514
517 The returned line does not include the trailing newline.
515 The returned line does not include the trailing newline.
518 When the user enters the EOF key sequence, EOFError is raised.
516 When the user enters the EOF key sequence, EOFError is raised.
519
517
520 Parameters
518 Parameters
521 ----------
519 ----------
522
520
523 prompt : str, optional
521 prompt : str, optional
524 A string to be printed to prompt the user.
522 A string to be printed to prompt the user.
525 """
523 """
526 # raw_input expects str, but we pass it unicode sometimes
524 # raw_input expects str, but we pass it unicode sometimes
527 prompt = py3compat.cast_bytes_py2(prompt)
525 prompt = py3compat.cast_bytes_py2(prompt)
528
526
529 try:
527 try:
530 line = py3compat.str_to_unicode(self.raw_input_original(prompt))
528 line = py3compat.str_to_unicode(self.raw_input_original(prompt))
531 except ValueError:
529 except ValueError:
532 warn("\n********\nYou or a %run:ed script called sys.stdin.close()"
530 warn("\n********\nYou or a %run:ed script called sys.stdin.close()"
533 " or sys.stdout.close()!\nExiting IPython!\n")
531 " or sys.stdout.close()!\nExiting IPython!\n")
534 self.ask_exit()
532 self.ask_exit()
535 return ""
533 return ""
536
534
537 # Try to be reasonably smart about not re-indenting pasted input more
535 # Try to be reasonably smart about not re-indenting pasted input more
538 # than necessary. We do this by trimming out the auto-indent initial
536 # than necessary. We do this by trimming out the auto-indent initial
539 # spaces, if the user's actual input started itself with whitespace.
537 # spaces, if the user's actual input started itself with whitespace.
540 if self.autoindent:
538 if self.autoindent:
541 if num_ini_spaces(line) > self.indent_current_nsp:
539 if num_ini_spaces(line) > self.indent_current_nsp:
542 line = line[self.indent_current_nsp:]
540 line = line[self.indent_current_nsp:]
543 self.indent_current_nsp = 0
541 self.indent_current_nsp = 0
544
542
545 return line
543 return line
546
544
547 #-------------------------------------------------------------------------
545 #-------------------------------------------------------------------------
548 # Methods to support auto-editing of SyntaxErrors.
546 # Methods to support auto-editing of SyntaxErrors.
549 #-------------------------------------------------------------------------
547 #-------------------------------------------------------------------------
550
548
551 def edit_syntax_error(self):
549 def edit_syntax_error(self):
552 """The bottom half of the syntax error handler called in the main loop.
550 """The bottom half of the syntax error handler called in the main loop.
553
551
554 Loop until syntax error is fixed or user cancels.
552 Loop until syntax error is fixed or user cancels.
555 """
553 """
556
554
557 while self.SyntaxTB.last_syntax_error:
555 while self.SyntaxTB.last_syntax_error:
558 # copy and clear last_syntax_error
556 # copy and clear last_syntax_error
559 err = self.SyntaxTB.clear_err_state()
557 err = self.SyntaxTB.clear_err_state()
560 if not self._should_recompile(err):
558 if not self._should_recompile(err):
561 return
559 return
562 try:
560 try:
563 # may set last_syntax_error again if a SyntaxError is raised
561 # may set last_syntax_error again if a SyntaxError is raised
564 self.safe_execfile(err.filename,self.user_ns)
562 self.safe_execfile(err.filename,self.user_ns)
565 except:
563 except:
566 self.showtraceback()
564 self.showtraceback()
567 else:
565 else:
568 try:
566 try:
569 f = open(err.filename)
567 f = open(err.filename)
570 try:
568 try:
571 # This should be inside a display_trap block and I
569 # This should be inside a display_trap block and I
572 # think it is.
570 # think it is.
573 sys.displayhook(f.read())
571 sys.displayhook(f.read())
574 finally:
572 finally:
575 f.close()
573 f.close()
576 except:
574 except:
577 self.showtraceback()
575 self.showtraceback()
578
576
579 def _should_recompile(self,e):
577 def _should_recompile(self,e):
580 """Utility routine for edit_syntax_error"""
578 """Utility routine for edit_syntax_error"""
581
579
582 if e.filename in ('<ipython console>','<input>','<string>',
580 if e.filename in ('<ipython console>','<input>','<string>',
583 '<console>','<BackgroundJob compilation>',
581 '<console>','<BackgroundJob compilation>',
584 None):
582 None):
585
583
586 return False
584 return False
587 try:
585 try:
588 if (self.autoedit_syntax and
586 if (self.autoedit_syntax and
589 not self.ask_yes_no('Return to editor to correct syntax error? '
587 not self.ask_yes_no('Return to editor to correct syntax error? '
590 '[Y/n] ','y')):
588 '[Y/n] ','y')):
591 return False
589 return False
592 except EOFError:
590 except EOFError:
593 return False
591 return False
594
592
595 def int0(x):
593 def int0(x):
596 try:
594 try:
597 return int(x)
595 return int(x)
598 except TypeError:
596 except TypeError:
599 return 0
597 return 0
600 # always pass integer line and offset values to editor hook
598 # always pass integer line and offset values to editor hook
601 try:
599 try:
602 self.hooks.fix_error_editor(e.filename,
600 self.hooks.fix_error_editor(e.filename,
603 int0(e.lineno),int0(e.offset),e.msg)
601 int0(e.lineno),int0(e.offset),e.msg)
604 except TryNext:
602 except TryNext:
605 warn('Could not open editor')
603 warn('Could not open editor')
606 return False
604 return False
607 return True
605 return True
608
606
609 #-------------------------------------------------------------------------
607 #-------------------------------------------------------------------------
610 # Things related to exiting
608 # Things related to exiting
611 #-------------------------------------------------------------------------
609 #-------------------------------------------------------------------------
612
610
613 def ask_exit(self):
611 def ask_exit(self):
614 """ Ask the shell to exit. Can be overiden and used as a callback. """
612 """ Ask the shell to exit. Can be overiden and used as a callback. """
615 self.exit_now = True
613 self.exit_now = True
616
614
617 def exit(self):
615 def exit(self):
618 """Handle interactive exit.
616 """Handle interactive exit.
619
617
620 This method calls the ask_exit callback."""
618 This method calls the ask_exit callback."""
621 if self.confirm_exit:
619 if self.confirm_exit:
622 if self.ask_yes_no('Do you really want to exit ([y]/n)?','y'):
620 if self.ask_yes_no('Do you really want to exit ([y]/n)?','y'):
623 self.ask_exit()
621 self.ask_exit()
624 else:
622 else:
625 self.ask_exit()
623 self.ask_exit()
626
624
627 #-------------------------------------------------------------------------
625 #-------------------------------------------------------------------------
628 # Things related to magics
626 # Things related to magics
629 #-------------------------------------------------------------------------
627 #-------------------------------------------------------------------------
630
628
631 def init_magics(self):
629 def init_magics(self):
632 super(TerminalInteractiveShell, self).init_magics()
630 super(TerminalInteractiveShell, self).init_magics()
633 self.register_magics(TerminalMagics)
631 self.register_magics(TerminalMagics)
634
632
635 def showindentationerror(self):
633 def showindentationerror(self):
636 super(TerminalInteractiveShell, self).showindentationerror()
634 super(TerminalInteractiveShell, self).showindentationerror()
637 if not self.using_paste_magics:
635 if not self.using_paste_magics:
638 print("If you want to paste code into IPython, try the "
636 print("If you want to paste code into IPython, try the "
639 "%paste and %cpaste magic functions.")
637 "%paste and %cpaste magic functions.")
640
638
641
639
642 InteractiveShellABC.register(TerminalInteractiveShell)
640 InteractiveShellABC.register(TerminalInteractiveShell)
@@ -1,169 +1,165 b''
1 """Tests for the decorators we've created for IPython.
1 """Tests for the decorators we've created for IPython.
2 """
2 """
3 from __future__ import print_function
3 from __future__ import print_function
4
4
5 # Module imports
5 # Module imports
6 # Std lib
6 # Std lib
7 import inspect
7 import inspect
8 import sys
8 import sys
9
9
10 # Third party
10 # Third party
11 import nose.tools as nt
11 import nose.tools as nt
12
12
13 # Our own
13 # Our own
14 from IPython.testing import decorators as dec
14 from IPython.testing import decorators as dec
15 from IPython.testing.skipdoctest import skip_doctest
16
15
17 #-----------------------------------------------------------------------------
16 #-----------------------------------------------------------------------------
18 # Utilities
17 # Utilities
19
18
20 # Note: copied from OInspect, kept here so the testing stuff doesn't create
19 # Note: copied from OInspect, kept here so the testing stuff doesn't create
21 # circular dependencies and is easier to reuse.
20 # circular dependencies and is easier to reuse.
22 def getargspec(obj):
21 def getargspec(obj):
23 """Get the names and default values of a function's arguments.
22 """Get the names and default values of a function's arguments.
24
23
25 A tuple of four things is returned: (args, varargs, varkw, defaults).
24 A tuple of four things is returned: (args, varargs, varkw, defaults).
26 'args' is a list of the argument names (it may contain nested lists).
25 'args' is a list of the argument names (it may contain nested lists).
27 'varargs' and 'varkw' are the names of the * and ** arguments or None.
26 'varargs' and 'varkw' are the names of the * and ** arguments or None.
28 'defaults' is an n-tuple of the default values of the last n arguments.
27 'defaults' is an n-tuple of the default values of the last n arguments.
29
28
30 Modified version of inspect.getargspec from the Python Standard
29 Modified version of inspect.getargspec from the Python Standard
31 Library."""
30 Library."""
32
31
33 if inspect.isfunction(obj):
32 if inspect.isfunction(obj):
34 func_obj = obj
33 func_obj = obj
35 elif inspect.ismethod(obj):
34 elif inspect.ismethod(obj):
36 func_obj = obj.__func__
35 func_obj = obj.__func__
37 else:
36 else:
38 raise TypeError('arg is not a Python function')
37 raise TypeError('arg is not a Python function')
39 args, varargs, varkw = inspect.getargs(func_obj.__code__)
38 args, varargs, varkw = inspect.getargs(func_obj.__code__)
40 return args, varargs, varkw, func_obj.__defaults__
39 return args, varargs, varkw, func_obj.__defaults__
41
40
42 #-----------------------------------------------------------------------------
41 #-----------------------------------------------------------------------------
43 # Testing functions
42 # Testing functions
44
43
45 @dec.as_unittest
44 @dec.as_unittest
46 def trivial():
45 def trivial():
47 """A trivial test"""
46 """A trivial test"""
48 pass
47 pass
49
48
50
49
51 @dec.skip
50 @dec.skip
52 def test_deliberately_broken():
51 def test_deliberately_broken():
53 """A deliberately broken test - we want to skip this one."""
52 """A deliberately broken test - we want to skip this one."""
54 1/0
53 1/0
55
54
56 @dec.skip('Testing the skip decorator')
55 @dec.skip('Testing the skip decorator')
57 def test_deliberately_broken2():
56 def test_deliberately_broken2():
58 """Another deliberately broken test - we want to skip this one."""
57 """Another deliberately broken test - we want to skip this one."""
59 1/0
58 1/0
60
59
61
60
62 # Verify that we can correctly skip the doctest for a function at will, but
61 # Verify that we can correctly skip the doctest for a function at will, but
63 # that the docstring itself is NOT destroyed by the decorator.
62 # that the docstring itself is NOT destroyed by the decorator.
64 @skip_doctest
65 def doctest_bad(x,y=1,**k):
63 def doctest_bad(x,y=1,**k):
66 """A function whose doctest we need to skip.
64 """A function whose doctest we need to skip.
67
65
68 >>> 1+1
66 >>> 1+1
69 3
67 3
70 """
68 """
71 print('x:',x)
69 print('x:',x)
72 print('y:',y)
70 print('y:',y)
73 print('k:',k)
71 print('k:',k)
74
72
75
73
76 def call_doctest_bad():
74 def call_doctest_bad():
77 """Check that we can still call the decorated functions.
75 """Check that we can still call the decorated functions.
78
76
79 >>> doctest_bad(3,y=4)
77 >>> doctest_bad(3,y=4)
80 x: 3
78 x: 3
81 y: 4
79 y: 4
82 k: {}
80 k: {}
83 """
81 """
84 pass
82 pass
85
83
86
84
87 def test_skip_dt_decorator():
85 def test_skip_dt_decorator():
88 """Doctest-skipping decorator should preserve the docstring.
86 """Doctest-skipping decorator should preserve the docstring.
89 """
87 """
90 # Careful: 'check' must be a *verbatim* copy of the doctest_bad docstring!
88 # Careful: 'check' must be a *verbatim* copy of the doctest_bad docstring!
91 check = """A function whose doctest we need to skip.
89 check = """A function whose doctest we need to skip.
92
90
93 >>> 1+1
91 >>> 1+1
94 3
92 3
95 """
93 """
96 # Fetch the docstring from doctest_bad after decoration.
94 # Fetch the docstring from doctest_bad after decoration.
97 val = doctest_bad.__doc__
95 val = doctest_bad.__doc__
98
96
99 nt.assert_equal(check,val,"doctest_bad docstrings don't match")
97 nt.assert_equal(check,val,"doctest_bad docstrings don't match")
100
98
101
99
102 # Doctest skipping should work for class methods too
100 # Doctest skipping should work for class methods too
103 class FooClass(object):
101 class FooClass(object):
104 """FooClass
102 """FooClass
105
103
106 Example:
104 Example:
107
105
108 >>> 1+1
106 >>> 1+1
109 2
107 2
110 """
108 """
111
109
112 @skip_doctest
113 def __init__(self,x):
110 def __init__(self,x):
114 """Make a FooClass.
111 """Make a FooClass.
115
112
116 Example:
113 Example:
117
114
118 >>> f = FooClass(3)
115 >>> f = FooClass(3)
119 junk
116 junk
120 """
117 """
121 print('Making a FooClass.')
118 print('Making a FooClass.')
122 self.x = x
119 self.x = x
123
120
124 @skip_doctest
125 def bar(self,y):
121 def bar(self,y):
126 """Example:
122 """Example:
127
123
128 >>> ff = FooClass(3)
124 >>> ff = FooClass(3)
129 >>> ff.bar(0)
125 >>> ff.bar(0)
130 boom!
126 boom!
131 >>> 1/0
127 >>> 1/0
132 bam!
128 bam!
133 """
129 """
134 return 1/y
130 return 1/y
135
131
136 def baz(self,y):
132 def baz(self,y):
137 """Example:
133 """Example:
138
134
139 >>> ff2 = FooClass(3)
135 >>> ff2 = FooClass(3)
140 Making a FooClass.
136 Making a FooClass.
141 >>> ff2.baz(3)
137 >>> ff2.baz(3)
142 True
138 True
143 """
139 """
144 return self.x==y
140 return self.x==y
145
141
146
142
147 def test_skip_dt_decorator2():
143 def test_skip_dt_decorator2():
148 """Doctest-skipping decorator should preserve function signature.
144 """Doctest-skipping decorator should preserve function signature.
149 """
145 """
150 # Hardcoded correct answer
146 # Hardcoded correct answer
151 dtargs = (['x', 'y'], None, 'k', (1,))
147 dtargs = (['x', 'y'], None, 'k', (1,))
152 # Introspect out the value
148 # Introspect out the value
153 dtargsr = getargspec(doctest_bad)
149 dtargsr = getargspec(doctest_bad)
154 assert dtargsr==dtargs, \
150 assert dtargsr==dtargs, \
155 "Incorrectly reconstructed args for doctest_bad: %s" % (dtargsr,)
151 "Incorrectly reconstructed args for doctest_bad: %s" % (dtargsr,)
156
152
157
153
158 @dec.skip_linux
154 @dec.skip_linux
159 def test_linux():
155 def test_linux():
160 nt.assert_false(sys.platform.startswith('linux'),"This test can't run under linux")
156 nt.assert_false(sys.platform.startswith('linux'),"This test can't run under linux")
161
157
162 @dec.skip_win32
158 @dec.skip_win32
163 def test_win32():
159 def test_win32():
164 nt.assert_not_equal(sys.platform,'win32',"This test can't run under windows")
160 nt.assert_not_equal(sys.platform,'win32',"This test can't run under windows")
165
161
166 @dec.skip_osx
162 @dec.skip_osx
167 def test_osx():
163 def test_osx():
168 nt.assert_not_equal(sys.platform,'darwin',"This test can't run under osx")
164 nt.assert_not_equal(sys.platform,'darwin',"This test can't run under osx")
169
165
@@ -1,448 +1,446 b''
1 # encoding: utf-8
1 # encoding: utf-8
2 """
2 """
3 Utilities for path handling.
3 Utilities for path handling.
4 """
4 """
5
5
6 # Copyright (c) IPython Development Team.
6 # Copyright (c) IPython Development Team.
7 # Distributed under the terms of the Modified BSD License.
7 # Distributed under the terms of the Modified BSD License.
8
8
9 import os
9 import os
10 import sys
10 import sys
11 import errno
11 import errno
12 import shutil
12 import shutil
13 import random
13 import random
14 import tempfile
14 import tempfile
15 import glob
15 import glob
16 from warnings import warn
16 from warnings import warn
17 from hashlib import md5
17 from hashlib import md5
18
18
19 from IPython.testing.skipdoctest import skip_doctest
20 from IPython.utils.process import system
19 from IPython.utils.process import system
21 from IPython.utils import py3compat
20 from IPython.utils import py3compat
22 from IPython.utils.decorators import undoc
21 from IPython.utils.decorators import undoc
23
22
24 #-----------------------------------------------------------------------------
23 #-----------------------------------------------------------------------------
25 # Code
24 # Code
26 #-----------------------------------------------------------------------------
25 #-----------------------------------------------------------------------------
27
26
28 fs_encoding = sys.getfilesystemencoding()
27 fs_encoding = sys.getfilesystemencoding()
29
28
30 def _writable_dir(path):
29 def _writable_dir(path):
31 """Whether `path` is a directory, to which the user has write access."""
30 """Whether `path` is a directory, to which the user has write access."""
32 return os.path.isdir(path) and os.access(path, os.W_OK)
31 return os.path.isdir(path) and os.access(path, os.W_OK)
33
32
34 if sys.platform == 'win32':
33 if sys.platform == 'win32':
35 @skip_doctest
36 def _get_long_path_name(path):
34 def _get_long_path_name(path):
37 """Get a long path name (expand ~) on Windows using ctypes.
35 """Get a long path name (expand ~) on Windows using ctypes.
38
36
39 Examples
37 Examples
40 --------
38 --------
41
39
42 >>> get_long_path_name('c:\\docume~1')
40 >>> get_long_path_name('c:\\docume~1')
43 u'c:\\\\Documents and Settings'
41 u'c:\\\\Documents and Settings'
44
42
45 """
43 """
46 try:
44 try:
47 import ctypes
45 import ctypes
48 except ImportError:
46 except ImportError:
49 raise ImportError('you need to have ctypes installed for this to work')
47 raise ImportError('you need to have ctypes installed for this to work')
50 _GetLongPathName = ctypes.windll.kernel32.GetLongPathNameW
48 _GetLongPathName = ctypes.windll.kernel32.GetLongPathNameW
51 _GetLongPathName.argtypes = [ctypes.c_wchar_p, ctypes.c_wchar_p,
49 _GetLongPathName.argtypes = [ctypes.c_wchar_p, ctypes.c_wchar_p,
52 ctypes.c_uint ]
50 ctypes.c_uint ]
53
51
54 buf = ctypes.create_unicode_buffer(260)
52 buf = ctypes.create_unicode_buffer(260)
55 rv = _GetLongPathName(path, buf, 260)
53 rv = _GetLongPathName(path, buf, 260)
56 if rv == 0 or rv > 260:
54 if rv == 0 or rv > 260:
57 return path
55 return path
58 else:
56 else:
59 return buf.value
57 return buf.value
60 else:
58 else:
61 def _get_long_path_name(path):
59 def _get_long_path_name(path):
62 """Dummy no-op."""
60 """Dummy no-op."""
63 return path
61 return path
64
62
65
63
66
64
67 def get_long_path_name(path):
65 def get_long_path_name(path):
68 """Expand a path into its long form.
66 """Expand a path into its long form.
69
67
70 On Windows this expands any ~ in the paths. On other platforms, it is
68 On Windows this expands any ~ in the paths. On other platforms, it is
71 a null operation.
69 a null operation.
72 """
70 """
73 return _get_long_path_name(path)
71 return _get_long_path_name(path)
74
72
75
73
76 def unquote_filename(name, win32=(sys.platform=='win32')):
74 def unquote_filename(name, win32=(sys.platform=='win32')):
77 """ On Windows, remove leading and trailing quotes from filenames.
75 """ On Windows, remove leading and trailing quotes from filenames.
78 """
76 """
79 if win32:
77 if win32:
80 if name.startswith(("'", '"')) and name.endswith(("'", '"')):
78 if name.startswith(("'", '"')) and name.endswith(("'", '"')):
81 name = name[1:-1]
79 name = name[1:-1]
82 return name
80 return name
83
81
84 def compress_user(path):
82 def compress_user(path):
85 """Reverse of :func:`os.path.expanduser`
83 """Reverse of :func:`os.path.expanduser`
86 """
84 """
87 home = os.path.expanduser('~')
85 home = os.path.expanduser('~')
88 if path.startswith(home):
86 if path.startswith(home):
89 path = "~" + path[len(home):]
87 path = "~" + path[len(home):]
90 return path
88 return path
91
89
92 def get_py_filename(name, force_win32=None):
90 def get_py_filename(name, force_win32=None):
93 """Return a valid python filename in the current directory.
91 """Return a valid python filename in the current directory.
94
92
95 If the given name is not a file, it adds '.py' and searches again.
93 If the given name is not a file, it adds '.py' and searches again.
96 Raises IOError with an informative message if the file isn't found.
94 Raises IOError with an informative message if the file isn't found.
97
95
98 On Windows, apply Windows semantics to the filename. In particular, remove
96 On Windows, apply Windows semantics to the filename. In particular, remove
99 any quoting that has been applied to it. This option can be forced for
97 any quoting that has been applied to it. This option can be forced for
100 testing purposes.
98 testing purposes.
101 """
99 """
102
100
103 name = os.path.expanduser(name)
101 name = os.path.expanduser(name)
104 if force_win32 is None:
102 if force_win32 is None:
105 win32 = (sys.platform == 'win32')
103 win32 = (sys.platform == 'win32')
106 else:
104 else:
107 win32 = force_win32
105 win32 = force_win32
108 name = unquote_filename(name, win32=win32)
106 name = unquote_filename(name, win32=win32)
109 if not os.path.isfile(name) and not name.endswith('.py'):
107 if not os.path.isfile(name) and not name.endswith('.py'):
110 name += '.py'
108 name += '.py'
111 if os.path.isfile(name):
109 if os.path.isfile(name):
112 return name
110 return name
113 else:
111 else:
114 raise IOError('File `%r` not found.' % name)
112 raise IOError('File `%r` not found.' % name)
115
113
116
114
117 def filefind(filename, path_dirs=None):
115 def filefind(filename, path_dirs=None):
118 """Find a file by looking through a sequence of paths.
116 """Find a file by looking through a sequence of paths.
119
117
120 This iterates through a sequence of paths looking for a file and returns
118 This iterates through a sequence of paths looking for a file and returns
121 the full, absolute path of the first occurence of the file. If no set of
119 the full, absolute path of the first occurence of the file. If no set of
122 path dirs is given, the filename is tested as is, after running through
120 path dirs is given, the filename is tested as is, after running through
123 :func:`expandvars` and :func:`expanduser`. Thus a simple call::
121 :func:`expandvars` and :func:`expanduser`. Thus a simple call::
124
122
125 filefind('myfile.txt')
123 filefind('myfile.txt')
126
124
127 will find the file in the current working dir, but::
125 will find the file in the current working dir, but::
128
126
129 filefind('~/myfile.txt')
127 filefind('~/myfile.txt')
130
128
131 Will find the file in the users home directory. This function does not
129 Will find the file in the users home directory. This function does not
132 automatically try any paths, such as the cwd or the user's home directory.
130 automatically try any paths, such as the cwd or the user's home directory.
133
131
134 Parameters
132 Parameters
135 ----------
133 ----------
136 filename : str
134 filename : str
137 The filename to look for.
135 The filename to look for.
138 path_dirs : str, None or sequence of str
136 path_dirs : str, None or sequence of str
139 The sequence of paths to look for the file in. If None, the filename
137 The sequence of paths to look for the file in. If None, the filename
140 need to be absolute or be in the cwd. If a string, the string is
138 need to be absolute or be in the cwd. If a string, the string is
141 put into a sequence and the searched. If a sequence, walk through
139 put into a sequence and the searched. If a sequence, walk through
142 each element and join with ``filename``, calling :func:`expandvars`
140 each element and join with ``filename``, calling :func:`expandvars`
143 and :func:`expanduser` before testing for existence.
141 and :func:`expanduser` before testing for existence.
144
142
145 Returns
143 Returns
146 -------
144 -------
147 Raises :exc:`IOError` or returns absolute path to file.
145 Raises :exc:`IOError` or returns absolute path to file.
148 """
146 """
149
147
150 # If paths are quoted, abspath gets confused, strip them...
148 # If paths are quoted, abspath gets confused, strip them...
151 filename = filename.strip('"').strip("'")
149 filename = filename.strip('"').strip("'")
152 # If the input is an absolute path, just check it exists
150 # If the input is an absolute path, just check it exists
153 if os.path.isabs(filename) and os.path.isfile(filename):
151 if os.path.isabs(filename) and os.path.isfile(filename):
154 return filename
152 return filename
155
153
156 if path_dirs is None:
154 if path_dirs is None:
157 path_dirs = ("",)
155 path_dirs = ("",)
158 elif isinstance(path_dirs, py3compat.string_types):
156 elif isinstance(path_dirs, py3compat.string_types):
159 path_dirs = (path_dirs,)
157 path_dirs = (path_dirs,)
160
158
161 for path in path_dirs:
159 for path in path_dirs:
162 if path == '.': path = py3compat.getcwd()
160 if path == '.': path = py3compat.getcwd()
163 testname = expand_path(os.path.join(path, filename))
161 testname = expand_path(os.path.join(path, filename))
164 if os.path.isfile(testname):
162 if os.path.isfile(testname):
165 return os.path.abspath(testname)
163 return os.path.abspath(testname)
166
164
167 raise IOError("File %r does not exist in any of the search paths: %r" %
165 raise IOError("File %r does not exist in any of the search paths: %r" %
168 (filename, path_dirs) )
166 (filename, path_dirs) )
169
167
170
168
171 class HomeDirError(Exception):
169 class HomeDirError(Exception):
172 pass
170 pass
173
171
174
172
175 def get_home_dir(require_writable=False):
173 def get_home_dir(require_writable=False):
176 """Return the 'home' directory, as a unicode string.
174 """Return the 'home' directory, as a unicode string.
177
175
178 Uses os.path.expanduser('~'), and checks for writability.
176 Uses os.path.expanduser('~'), and checks for writability.
179
177
180 See stdlib docs for how this is determined.
178 See stdlib docs for how this is determined.
181 $HOME is first priority on *ALL* platforms.
179 $HOME is first priority on *ALL* platforms.
182
180
183 Parameters
181 Parameters
184 ----------
182 ----------
185
183
186 require_writable : bool [default: False]
184 require_writable : bool [default: False]
187 if True:
185 if True:
188 guarantees the return value is a writable directory, otherwise
186 guarantees the return value is a writable directory, otherwise
189 raises HomeDirError
187 raises HomeDirError
190 if False:
188 if False:
191 The path is resolved, but it is not guaranteed to exist or be writable.
189 The path is resolved, but it is not guaranteed to exist or be writable.
192 """
190 """
193
191
194 homedir = os.path.expanduser('~')
192 homedir = os.path.expanduser('~')
195 # Next line will make things work even when /home/ is a symlink to
193 # Next line will make things work even when /home/ is a symlink to
196 # /usr/home as it is on FreeBSD, for example
194 # /usr/home as it is on FreeBSD, for example
197 homedir = os.path.realpath(homedir)
195 homedir = os.path.realpath(homedir)
198
196
199 if not _writable_dir(homedir) and os.name == 'nt':
197 if not _writable_dir(homedir) and os.name == 'nt':
200 # expanduser failed, use the registry to get the 'My Documents' folder.
198 # expanduser failed, use the registry to get the 'My Documents' folder.
201 try:
199 try:
202 try:
200 try:
203 import winreg as wreg # Py 3
201 import winreg as wreg # Py 3
204 except ImportError:
202 except ImportError:
205 import _winreg as wreg # Py 2
203 import _winreg as wreg # Py 2
206 key = wreg.OpenKey(
204 key = wreg.OpenKey(
207 wreg.HKEY_CURRENT_USER,
205 wreg.HKEY_CURRENT_USER,
208 "Software\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders"
206 "Software\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders"
209 )
207 )
210 homedir = wreg.QueryValueEx(key,'Personal')[0]
208 homedir = wreg.QueryValueEx(key,'Personal')[0]
211 key.Close()
209 key.Close()
212 except:
210 except:
213 pass
211 pass
214
212
215 if (not require_writable) or _writable_dir(homedir):
213 if (not require_writable) or _writable_dir(homedir):
216 return py3compat.cast_unicode(homedir, fs_encoding)
214 return py3compat.cast_unicode(homedir, fs_encoding)
217 else:
215 else:
218 raise HomeDirError('%s is not a writable dir, '
216 raise HomeDirError('%s is not a writable dir, '
219 'set $HOME environment variable to override' % homedir)
217 'set $HOME environment variable to override' % homedir)
220
218
221 def get_xdg_dir():
219 def get_xdg_dir():
222 """Return the XDG_CONFIG_HOME, if it is defined and exists, else None.
220 """Return the XDG_CONFIG_HOME, if it is defined and exists, else None.
223
221
224 This is only for non-OS X posix (Linux,Unix,etc.) systems.
222 This is only for non-OS X posix (Linux,Unix,etc.) systems.
225 """
223 """
226
224
227 env = os.environ
225 env = os.environ
228
226
229 if os.name == 'posix' and sys.platform != 'darwin':
227 if os.name == 'posix' and sys.platform != 'darwin':
230 # Linux, Unix, AIX, etc.
228 # Linux, Unix, AIX, etc.
231 # use ~/.config if empty OR not set
229 # use ~/.config if empty OR not set
232 xdg = env.get("XDG_CONFIG_HOME", None) or os.path.join(get_home_dir(), '.config')
230 xdg = env.get("XDG_CONFIG_HOME", None) or os.path.join(get_home_dir(), '.config')
233 if xdg and _writable_dir(xdg):
231 if xdg and _writable_dir(xdg):
234 return py3compat.cast_unicode(xdg, fs_encoding)
232 return py3compat.cast_unicode(xdg, fs_encoding)
235
233
236 return None
234 return None
237
235
238
236
239 def get_xdg_cache_dir():
237 def get_xdg_cache_dir():
240 """Return the XDG_CACHE_HOME, if it is defined and exists, else None.
238 """Return the XDG_CACHE_HOME, if it is defined and exists, else None.
241
239
242 This is only for non-OS X posix (Linux,Unix,etc.) systems.
240 This is only for non-OS X posix (Linux,Unix,etc.) systems.
243 """
241 """
244
242
245 env = os.environ
243 env = os.environ
246
244
247 if os.name == 'posix' and sys.platform != 'darwin':
245 if os.name == 'posix' and sys.platform != 'darwin':
248 # Linux, Unix, AIX, etc.
246 # Linux, Unix, AIX, etc.
249 # use ~/.cache if empty OR not set
247 # use ~/.cache if empty OR not set
250 xdg = env.get("XDG_CACHE_HOME", None) or os.path.join(get_home_dir(), '.cache')
248 xdg = env.get("XDG_CACHE_HOME", None) or os.path.join(get_home_dir(), '.cache')
251 if xdg and _writable_dir(xdg):
249 if xdg and _writable_dir(xdg):
252 return py3compat.cast_unicode(xdg, fs_encoding)
250 return py3compat.cast_unicode(xdg, fs_encoding)
253
251
254 return None
252 return None
255
253
256
254
257 @undoc
255 @undoc
258 def get_ipython_dir():
256 def get_ipython_dir():
259 warn("get_ipython_dir has moved to the IPython.paths module")
257 warn("get_ipython_dir has moved to the IPython.paths module")
260 from IPython.paths import get_ipython_dir
258 from IPython.paths import get_ipython_dir
261 return get_ipython_dir()
259 return get_ipython_dir()
262
260
263 @undoc
261 @undoc
264 def get_ipython_cache_dir():
262 def get_ipython_cache_dir():
265 warn("get_ipython_cache_dir has moved to the IPython.paths module")
263 warn("get_ipython_cache_dir has moved to the IPython.paths module")
266 from IPython.paths import get_ipython_cache_dir
264 from IPython.paths import get_ipython_cache_dir
267 return get_ipython_cache_dir()
265 return get_ipython_cache_dir()
268
266
269 @undoc
267 @undoc
270 def get_ipython_package_dir():
268 def get_ipython_package_dir():
271 warn("get_ipython_package_dir has moved to the IPython.paths module")
269 warn("get_ipython_package_dir has moved to the IPython.paths module")
272 from IPython.paths import get_ipython_package_dir
270 from IPython.paths import get_ipython_package_dir
273 return get_ipython_package_dir()
271 return get_ipython_package_dir()
274
272
275 @undoc
273 @undoc
276 def get_ipython_module_path(module_str):
274 def get_ipython_module_path(module_str):
277 warn("get_ipython_module_path has moved to the IPython.paths module")
275 warn("get_ipython_module_path has moved to the IPython.paths module")
278 from IPython.paths import get_ipython_module_path
276 from IPython.paths import get_ipython_module_path
279 return get_ipython_module_path(module_str)
277 return get_ipython_module_path(module_str)
280
278
281 @undoc
279 @undoc
282 def locate_profile(profile='default'):
280 def locate_profile(profile='default'):
283 warn("locate_profile has moved to the IPython.paths module")
281 warn("locate_profile has moved to the IPython.paths module")
284 from IPython.paths import locate_profile
282 from IPython.paths import locate_profile
285 return locate_profile(profile=profile)
283 return locate_profile(profile=profile)
286
284
287 def expand_path(s):
285 def expand_path(s):
288 """Expand $VARS and ~names in a string, like a shell
286 """Expand $VARS and ~names in a string, like a shell
289
287
290 :Examples:
288 :Examples:
291
289
292 In [2]: os.environ['FOO']='test'
290 In [2]: os.environ['FOO']='test'
293
291
294 In [3]: expand_path('variable FOO is $FOO')
292 In [3]: expand_path('variable FOO is $FOO')
295 Out[3]: 'variable FOO is test'
293 Out[3]: 'variable FOO is test'
296 """
294 """
297 # This is a pretty subtle hack. When expand user is given a UNC path
295 # This is a pretty subtle hack. When expand user is given a UNC path
298 # on Windows (\\server\share$\%username%), os.path.expandvars, removes
296 # on Windows (\\server\share$\%username%), os.path.expandvars, removes
299 # the $ to get (\\server\share\%username%). I think it considered $
297 # the $ to get (\\server\share\%username%). I think it considered $
300 # alone an empty var. But, we need the $ to remains there (it indicates
298 # alone an empty var. But, we need the $ to remains there (it indicates
301 # a hidden share).
299 # a hidden share).
302 if os.name=='nt':
300 if os.name=='nt':
303 s = s.replace('$\\', 'IPYTHON_TEMP')
301 s = s.replace('$\\', 'IPYTHON_TEMP')
304 s = os.path.expandvars(os.path.expanduser(s))
302 s = os.path.expandvars(os.path.expanduser(s))
305 if os.name=='nt':
303 if os.name=='nt':
306 s = s.replace('IPYTHON_TEMP', '$\\')
304 s = s.replace('IPYTHON_TEMP', '$\\')
307 return s
305 return s
308
306
309
307
310 def unescape_glob(string):
308 def unescape_glob(string):
311 """Unescape glob pattern in `string`."""
309 """Unescape glob pattern in `string`."""
312 def unescape(s):
310 def unescape(s):
313 for pattern in '*[]!?':
311 for pattern in '*[]!?':
314 s = s.replace(r'\{0}'.format(pattern), pattern)
312 s = s.replace(r'\{0}'.format(pattern), pattern)
315 return s
313 return s
316 return '\\'.join(map(unescape, string.split('\\\\')))
314 return '\\'.join(map(unescape, string.split('\\\\')))
317
315
318
316
319 def shellglob(args):
317 def shellglob(args):
320 """
318 """
321 Do glob expansion for each element in `args` and return a flattened list.
319 Do glob expansion for each element in `args` and return a flattened list.
322
320
323 Unmatched glob pattern will remain as-is in the returned list.
321 Unmatched glob pattern will remain as-is in the returned list.
324
322
325 """
323 """
326 expanded = []
324 expanded = []
327 # Do not unescape backslash in Windows as it is interpreted as
325 # Do not unescape backslash in Windows as it is interpreted as
328 # path separator:
326 # path separator:
329 unescape = unescape_glob if sys.platform != 'win32' else lambda x: x
327 unescape = unescape_glob if sys.platform != 'win32' else lambda x: x
330 for a in args:
328 for a in args:
331 expanded.extend(glob.glob(a) or [unescape(a)])
329 expanded.extend(glob.glob(a) or [unescape(a)])
332 return expanded
330 return expanded
333
331
334
332
335 def target_outdated(target,deps):
333 def target_outdated(target,deps):
336 """Determine whether a target is out of date.
334 """Determine whether a target is out of date.
337
335
338 target_outdated(target,deps) -> 1/0
336 target_outdated(target,deps) -> 1/0
339
337
340 deps: list of filenames which MUST exist.
338 deps: list of filenames which MUST exist.
341 target: single filename which may or may not exist.
339 target: single filename which may or may not exist.
342
340
343 If target doesn't exist or is older than any file listed in deps, return
341 If target doesn't exist or is older than any file listed in deps, return
344 true, otherwise return false.
342 true, otherwise return false.
345 """
343 """
346 try:
344 try:
347 target_time = os.path.getmtime(target)
345 target_time = os.path.getmtime(target)
348 except os.error:
346 except os.error:
349 return 1
347 return 1
350 for dep in deps:
348 for dep in deps:
351 dep_time = os.path.getmtime(dep)
349 dep_time = os.path.getmtime(dep)
352 if dep_time > target_time:
350 if dep_time > target_time:
353 #print "For target",target,"Dep failed:",dep # dbg
351 #print "For target",target,"Dep failed:",dep # dbg
354 #print "times (dep,tar):",dep_time,target_time # dbg
352 #print "times (dep,tar):",dep_time,target_time # dbg
355 return 1
353 return 1
356 return 0
354 return 0
357
355
358
356
359 def target_update(target,deps,cmd):
357 def target_update(target,deps,cmd):
360 """Update a target with a given command given a list of dependencies.
358 """Update a target with a given command given a list of dependencies.
361
359
362 target_update(target,deps,cmd) -> runs cmd if target is outdated.
360 target_update(target,deps,cmd) -> runs cmd if target is outdated.
363
361
364 This is just a wrapper around target_outdated() which calls the given
362 This is just a wrapper around target_outdated() which calls the given
365 command if target is outdated."""
363 command if target is outdated."""
366
364
367 if target_outdated(target,deps):
365 if target_outdated(target,deps):
368 system(cmd)
366 system(cmd)
369
367
370 @undoc
368 @undoc
371 def filehash(path):
369 def filehash(path):
372 """Make an MD5 hash of a file, ignoring any differences in line
370 """Make an MD5 hash of a file, ignoring any differences in line
373 ending characters."""
371 ending characters."""
374 warn("filehash() is deprecated")
372 warn("filehash() is deprecated")
375 with open(path, "rU") as f:
373 with open(path, "rU") as f:
376 return md5(py3compat.str_to_bytes(f.read())).hexdigest()
374 return md5(py3compat.str_to_bytes(f.read())).hexdigest()
377
375
378 ENOLINK = 1998
376 ENOLINK = 1998
379
377
380 def link(src, dst):
378 def link(src, dst):
381 """Hard links ``src`` to ``dst``, returning 0 or errno.
379 """Hard links ``src`` to ``dst``, returning 0 or errno.
382
380
383 Note that the special errno ``ENOLINK`` will be returned if ``os.link`` isn't
381 Note that the special errno ``ENOLINK`` will be returned if ``os.link`` isn't
384 supported by the operating system.
382 supported by the operating system.
385 """
383 """
386
384
387 if not hasattr(os, "link"):
385 if not hasattr(os, "link"):
388 return ENOLINK
386 return ENOLINK
389 link_errno = 0
387 link_errno = 0
390 try:
388 try:
391 os.link(src, dst)
389 os.link(src, dst)
392 except OSError as e:
390 except OSError as e:
393 link_errno = e.errno
391 link_errno = e.errno
394 return link_errno
392 return link_errno
395
393
396
394
397 def link_or_copy(src, dst):
395 def link_or_copy(src, dst):
398 """Attempts to hardlink ``src`` to ``dst``, copying if the link fails.
396 """Attempts to hardlink ``src`` to ``dst``, copying if the link fails.
399
397
400 Attempts to maintain the semantics of ``shutil.copy``.
398 Attempts to maintain the semantics of ``shutil.copy``.
401
399
402 Because ``os.link`` does not overwrite files, a unique temporary file
400 Because ``os.link`` does not overwrite files, a unique temporary file
403 will be used if the target already exists, then that file will be moved
401 will be used if the target already exists, then that file will be moved
404 into place.
402 into place.
405 """
403 """
406
404
407 if os.path.isdir(dst):
405 if os.path.isdir(dst):
408 dst = os.path.join(dst, os.path.basename(src))
406 dst = os.path.join(dst, os.path.basename(src))
409
407
410 link_errno = link(src, dst)
408 link_errno = link(src, dst)
411 if link_errno == errno.EEXIST:
409 if link_errno == errno.EEXIST:
412 if os.stat(src).st_ino == os.stat(dst).st_ino:
410 if os.stat(src).st_ino == os.stat(dst).st_ino:
413 # dst is already a hard link to the correct file, so we don't need
411 # dst is already a hard link to the correct file, so we don't need
414 # to do anything else. If we try to link and rename the file
412 # to do anything else. If we try to link and rename the file
415 # anyway, we get duplicate files - see http://bugs.python.org/issue21876
413 # anyway, we get duplicate files - see http://bugs.python.org/issue21876
416 return
414 return
417
415
418 new_dst = dst + "-temp-%04X" %(random.randint(1, 16**4), )
416 new_dst = dst + "-temp-%04X" %(random.randint(1, 16**4), )
419 try:
417 try:
420 link_or_copy(src, new_dst)
418 link_or_copy(src, new_dst)
421 except:
419 except:
422 try:
420 try:
423 os.remove(new_dst)
421 os.remove(new_dst)
424 except OSError:
422 except OSError:
425 pass
423 pass
426 raise
424 raise
427 os.rename(new_dst, dst)
425 os.rename(new_dst, dst)
428 elif link_errno != 0:
426 elif link_errno != 0:
429 # Either link isn't supported, or the filesystem doesn't support
427 # Either link isn't supported, or the filesystem doesn't support
430 # linking, or 'src' and 'dst' are on different filesystems.
428 # linking, or 'src' and 'dst' are on different filesystems.
431 shutil.copy(src, dst)
429 shutil.copy(src, dst)
432
430
433 def ensure_dir_exists(path, mode=0o755):
431 def ensure_dir_exists(path, mode=0o755):
434 """ensure that a directory exists
432 """ensure that a directory exists
435
433
436 If it doesn't exist, try to create it and protect against a race condition
434 If it doesn't exist, try to create it and protect against a race condition
437 if another process is doing the same.
435 if another process is doing the same.
438
436
439 The default permissions are 755, which differ from os.makedirs default of 777.
437 The default permissions are 755, which differ from os.makedirs default of 777.
440 """
438 """
441 if not os.path.exists(path):
439 if not os.path.exists(path):
442 try:
440 try:
443 os.makedirs(path, mode=mode)
441 os.makedirs(path, mode=mode)
444 except OSError as e:
442 except OSError as e:
445 if e.errno != errno.EEXIST:
443 if e.errno != errno.EEXIST:
446 raise
444 raise
447 elif not os.path.isdir(path):
445 elif not os.path.isdir(path):
448 raise IOError("%r exists but is not a directory" % path)
446 raise IOError("%r exists but is not a directory" % path)
@@ -1,765 +1,764 b''
1 # encoding: utf-8
1 # encoding: utf-8
2 """
2 """
3 Utilities for working with strings and text.
3 Utilities for working with strings and text.
4
4
5 Inheritance diagram:
5 Inheritance diagram:
6
6
7 .. inheritance-diagram:: IPython.utils.text
7 .. inheritance-diagram:: IPython.utils.text
8 :parts: 3
8 :parts: 3
9 """
9 """
10
10
11 import os
11 import os
12 import re
12 import re
13 import sys
13 import sys
14 import textwrap
14 import textwrap
15 from string import Formatter
15 from string import Formatter
16
16
17 from IPython.testing.skipdoctest import skip_doctest_py3, skip_doctest
17 from IPython.testing.skipdoctest import skip_doctest_py3, skip_doctest
18 from IPython.utils import py3compat
18 from IPython.utils import py3compat
19
19
20 # datetime.strftime date format for ipython
20 # datetime.strftime date format for ipython
21 if sys.platform == 'win32':
21 if sys.platform == 'win32':
22 date_format = "%B %d, %Y"
22 date_format = "%B %d, %Y"
23 else:
23 else:
24 date_format = "%B %-d, %Y"
24 date_format = "%B %-d, %Y"
25
25
26 class LSString(str):
26 class LSString(str):
27 """String derivative with a special access attributes.
27 """String derivative with a special access attributes.
28
28
29 These are normal strings, but with the special attributes:
29 These are normal strings, but with the special attributes:
30
30
31 .l (or .list) : value as list (split on newlines).
31 .l (or .list) : value as list (split on newlines).
32 .n (or .nlstr): original value (the string itself).
32 .n (or .nlstr): original value (the string itself).
33 .s (or .spstr): value as whitespace-separated string.
33 .s (or .spstr): value as whitespace-separated string.
34 .p (or .paths): list of path objects (requires path.py package)
34 .p (or .paths): list of path objects (requires path.py package)
35
35
36 Any values which require transformations are computed only once and
36 Any values which require transformations are computed only once and
37 cached.
37 cached.
38
38
39 Such strings are very useful to efficiently interact with the shell, which
39 Such strings are very useful to efficiently interact with the shell, which
40 typically only understands whitespace-separated options for commands."""
40 typically only understands whitespace-separated options for commands."""
41
41
42 def get_list(self):
42 def get_list(self):
43 try:
43 try:
44 return self.__list
44 return self.__list
45 except AttributeError:
45 except AttributeError:
46 self.__list = self.split('\n')
46 self.__list = self.split('\n')
47 return self.__list
47 return self.__list
48
48
49 l = list = property(get_list)
49 l = list = property(get_list)
50
50
51 def get_spstr(self):
51 def get_spstr(self):
52 try:
52 try:
53 return self.__spstr
53 return self.__spstr
54 except AttributeError:
54 except AttributeError:
55 self.__spstr = self.replace('\n',' ')
55 self.__spstr = self.replace('\n',' ')
56 return self.__spstr
56 return self.__spstr
57
57
58 s = spstr = property(get_spstr)
58 s = spstr = property(get_spstr)
59
59
60 def get_nlstr(self):
60 def get_nlstr(self):
61 return self
61 return self
62
62
63 n = nlstr = property(get_nlstr)
63 n = nlstr = property(get_nlstr)
64
64
65 def get_paths(self):
65 def get_paths(self):
66 from path import path
66 from path import path
67 try:
67 try:
68 return self.__paths
68 return self.__paths
69 except AttributeError:
69 except AttributeError:
70 self.__paths = [path(p) for p in self.split('\n') if os.path.exists(p)]
70 self.__paths = [path(p) for p in self.split('\n') if os.path.exists(p)]
71 return self.__paths
71 return self.__paths
72
72
73 p = paths = property(get_paths)
73 p = paths = property(get_paths)
74
74
75 # FIXME: We need to reimplement type specific displayhook and then add this
75 # FIXME: We need to reimplement type specific displayhook and then add this
76 # back as a custom printer. This should also be moved outside utils into the
76 # back as a custom printer. This should also be moved outside utils into the
77 # core.
77 # core.
78
78
79 # def print_lsstring(arg):
79 # def print_lsstring(arg):
80 # """ Prettier (non-repr-like) and more informative printer for LSString """
80 # """ Prettier (non-repr-like) and more informative printer for LSString """
81 # print "LSString (.p, .n, .l, .s available). Value:"
81 # print "LSString (.p, .n, .l, .s available). Value:"
82 # print arg
82 # print arg
83 #
83 #
84 #
84 #
85 # print_lsstring = result_display.when_type(LSString)(print_lsstring)
85 # print_lsstring = result_display.when_type(LSString)(print_lsstring)
86
86
87
87
88 class SList(list):
88 class SList(list):
89 """List derivative with a special access attributes.
89 """List derivative with a special access attributes.
90
90
91 These are normal lists, but with the special attributes:
91 These are normal lists, but with the special attributes:
92
92
93 * .l (or .list) : value as list (the list itself).
93 * .l (or .list) : value as list (the list itself).
94 * .n (or .nlstr): value as a string, joined on newlines.
94 * .n (or .nlstr): value as a string, joined on newlines.
95 * .s (or .spstr): value as a string, joined on spaces.
95 * .s (or .spstr): value as a string, joined on spaces.
96 * .p (or .paths): list of path objects (requires path.py package)
96 * .p (or .paths): list of path objects (requires path.py package)
97
97
98 Any values which require transformations are computed only once and
98 Any values which require transformations are computed only once and
99 cached."""
99 cached."""
100
100
101 def get_list(self):
101 def get_list(self):
102 return self
102 return self
103
103
104 l = list = property(get_list)
104 l = list = property(get_list)
105
105
106 def get_spstr(self):
106 def get_spstr(self):
107 try:
107 try:
108 return self.__spstr
108 return self.__spstr
109 except AttributeError:
109 except AttributeError:
110 self.__spstr = ' '.join(self)
110 self.__spstr = ' '.join(self)
111 return self.__spstr
111 return self.__spstr
112
112
113 s = spstr = property(get_spstr)
113 s = spstr = property(get_spstr)
114
114
115 def get_nlstr(self):
115 def get_nlstr(self):
116 try:
116 try:
117 return self.__nlstr
117 return self.__nlstr
118 except AttributeError:
118 except AttributeError:
119 self.__nlstr = '\n'.join(self)
119 self.__nlstr = '\n'.join(self)
120 return self.__nlstr
120 return self.__nlstr
121
121
122 n = nlstr = property(get_nlstr)
122 n = nlstr = property(get_nlstr)
123
123
124 def get_paths(self):
124 def get_paths(self):
125 from path import path
125 from path import path
126 try:
126 try:
127 return self.__paths
127 return self.__paths
128 except AttributeError:
128 except AttributeError:
129 self.__paths = [path(p) for p in self if os.path.exists(p)]
129 self.__paths = [path(p) for p in self if os.path.exists(p)]
130 return self.__paths
130 return self.__paths
131
131
132 p = paths = property(get_paths)
132 p = paths = property(get_paths)
133
133
134 def grep(self, pattern, prune = False, field = None):
134 def grep(self, pattern, prune = False, field = None):
135 """ Return all strings matching 'pattern' (a regex or callable)
135 """ Return all strings matching 'pattern' (a regex or callable)
136
136
137 This is case-insensitive. If prune is true, return all items
137 This is case-insensitive. If prune is true, return all items
138 NOT matching the pattern.
138 NOT matching the pattern.
139
139
140 If field is specified, the match must occur in the specified
140 If field is specified, the match must occur in the specified
141 whitespace-separated field.
141 whitespace-separated field.
142
142
143 Examples::
143 Examples::
144
144
145 a.grep( lambda x: x.startswith('C') )
145 a.grep( lambda x: x.startswith('C') )
146 a.grep('Cha.*log', prune=1)
146 a.grep('Cha.*log', prune=1)
147 a.grep('chm', field=-1)
147 a.grep('chm', field=-1)
148 """
148 """
149
149
150 def match_target(s):
150 def match_target(s):
151 if field is None:
151 if field is None:
152 return s
152 return s
153 parts = s.split()
153 parts = s.split()
154 try:
154 try:
155 tgt = parts[field]
155 tgt = parts[field]
156 return tgt
156 return tgt
157 except IndexError:
157 except IndexError:
158 return ""
158 return ""
159
159
160 if isinstance(pattern, py3compat.string_types):
160 if isinstance(pattern, py3compat.string_types):
161 pred = lambda x : re.search(pattern, x, re.IGNORECASE)
161 pred = lambda x : re.search(pattern, x, re.IGNORECASE)
162 else:
162 else:
163 pred = pattern
163 pred = pattern
164 if not prune:
164 if not prune:
165 return SList([el for el in self if pred(match_target(el))])
165 return SList([el for el in self if pred(match_target(el))])
166 else:
166 else:
167 return SList([el for el in self if not pred(match_target(el))])
167 return SList([el for el in self if not pred(match_target(el))])
168
168
169 def fields(self, *fields):
169 def fields(self, *fields):
170 """ Collect whitespace-separated fields from string list
170 """ Collect whitespace-separated fields from string list
171
171
172 Allows quick awk-like usage of string lists.
172 Allows quick awk-like usage of string lists.
173
173
174 Example data (in var a, created by 'a = !ls -l')::
174 Example data (in var a, created by 'a = !ls -l')::
175
175
176 -rwxrwxrwx 1 ville None 18 Dec 14 2006 ChangeLog
176 -rwxrwxrwx 1 ville None 18 Dec 14 2006 ChangeLog
177 drwxrwxrwx+ 6 ville None 0 Oct 24 18:05 IPython
177 drwxrwxrwx+ 6 ville None 0 Oct 24 18:05 IPython
178
178
179 * ``a.fields(0)`` is ``['-rwxrwxrwx', 'drwxrwxrwx+']``
179 * ``a.fields(0)`` is ``['-rwxrwxrwx', 'drwxrwxrwx+']``
180 * ``a.fields(1,0)`` is ``['1 -rwxrwxrwx', '6 drwxrwxrwx+']``
180 * ``a.fields(1,0)`` is ``['1 -rwxrwxrwx', '6 drwxrwxrwx+']``
181 (note the joining by space).
181 (note the joining by space).
182 * ``a.fields(-1)`` is ``['ChangeLog', 'IPython']``
182 * ``a.fields(-1)`` is ``['ChangeLog', 'IPython']``
183
183
184 IndexErrors are ignored.
184 IndexErrors are ignored.
185
185
186 Without args, fields() just split()'s the strings.
186 Without args, fields() just split()'s the strings.
187 """
187 """
188 if len(fields) == 0:
188 if len(fields) == 0:
189 return [el.split() for el in self]
189 return [el.split() for el in self]
190
190
191 res = SList()
191 res = SList()
192 for el in [f.split() for f in self]:
192 for el in [f.split() for f in self]:
193 lineparts = []
193 lineparts = []
194
194
195 for fd in fields:
195 for fd in fields:
196 try:
196 try:
197 lineparts.append(el[fd])
197 lineparts.append(el[fd])
198 except IndexError:
198 except IndexError:
199 pass
199 pass
200 if lineparts:
200 if lineparts:
201 res.append(" ".join(lineparts))
201 res.append(" ".join(lineparts))
202
202
203 return res
203 return res
204
204
205 def sort(self,field= None, nums = False):
205 def sort(self,field= None, nums = False):
206 """ sort by specified fields (see fields())
206 """ sort by specified fields (see fields())
207
207
208 Example::
208 Example::
209
209
210 a.sort(1, nums = True)
210 a.sort(1, nums = True)
211
211
212 Sorts a by second field, in numerical order (so that 21 > 3)
212 Sorts a by second field, in numerical order (so that 21 > 3)
213
213
214 """
214 """
215
215
216 #decorate, sort, undecorate
216 #decorate, sort, undecorate
217 if field is not None:
217 if field is not None:
218 dsu = [[SList([line]).fields(field), line] for line in self]
218 dsu = [[SList([line]).fields(field), line] for line in self]
219 else:
219 else:
220 dsu = [[line, line] for line in self]
220 dsu = [[line, line] for line in self]
221 if nums:
221 if nums:
222 for i in range(len(dsu)):
222 for i in range(len(dsu)):
223 numstr = "".join([ch for ch in dsu[i][0] if ch.isdigit()])
223 numstr = "".join([ch for ch in dsu[i][0] if ch.isdigit()])
224 try:
224 try:
225 n = int(numstr)
225 n = int(numstr)
226 except ValueError:
226 except ValueError:
227 n = 0;
227 n = 0;
228 dsu[i][0] = n
228 dsu[i][0] = n
229
229
230
230
231 dsu.sort()
231 dsu.sort()
232 return SList([t[1] for t in dsu])
232 return SList([t[1] for t in dsu])
233
233
234
234
235 # FIXME: We need to reimplement type specific displayhook and then add this
235 # FIXME: We need to reimplement type specific displayhook and then add this
236 # back as a custom printer. This should also be moved outside utils into the
236 # back as a custom printer. This should also be moved outside utils into the
237 # core.
237 # core.
238
238
239 # def print_slist(arg):
239 # def print_slist(arg):
240 # """ Prettier (non-repr-like) and more informative printer for SList """
240 # """ Prettier (non-repr-like) and more informative printer for SList """
241 # print "SList (.p, .n, .l, .s, .grep(), .fields(), sort() available):"
241 # print "SList (.p, .n, .l, .s, .grep(), .fields(), sort() available):"
242 # if hasattr(arg, 'hideonce') and arg.hideonce:
242 # if hasattr(arg, 'hideonce') and arg.hideonce:
243 # arg.hideonce = False
243 # arg.hideonce = False
244 # return
244 # return
245 #
245 #
246 # nlprint(arg) # This was a nested list printer, now removed.
246 # nlprint(arg) # This was a nested list printer, now removed.
247 #
247 #
248 # print_slist = result_display.when_type(SList)(print_slist)
248 # print_slist = result_display.when_type(SList)(print_slist)
249
249
250
250
251 def indent(instr,nspaces=4, ntabs=0, flatten=False):
251 def indent(instr,nspaces=4, ntabs=0, flatten=False):
252 """Indent a string a given number of spaces or tabstops.
252 """Indent a string a given number of spaces or tabstops.
253
253
254 indent(str,nspaces=4,ntabs=0) -> indent str by ntabs+nspaces.
254 indent(str,nspaces=4,ntabs=0) -> indent str by ntabs+nspaces.
255
255
256 Parameters
256 Parameters
257 ----------
257 ----------
258
258
259 instr : basestring
259 instr : basestring
260 The string to be indented.
260 The string to be indented.
261 nspaces : int (default: 4)
261 nspaces : int (default: 4)
262 The number of spaces to be indented.
262 The number of spaces to be indented.
263 ntabs : int (default: 0)
263 ntabs : int (default: 0)
264 The number of tabs to be indented.
264 The number of tabs to be indented.
265 flatten : bool (default: False)
265 flatten : bool (default: False)
266 Whether to scrub existing indentation. If True, all lines will be
266 Whether to scrub existing indentation. If True, all lines will be
267 aligned to the same indentation. If False, existing indentation will
267 aligned to the same indentation. If False, existing indentation will
268 be strictly increased.
268 be strictly increased.
269
269
270 Returns
270 Returns
271 -------
271 -------
272
272
273 str|unicode : string indented by ntabs and nspaces.
273 str|unicode : string indented by ntabs and nspaces.
274
274
275 """
275 """
276 if instr is None:
276 if instr is None:
277 return
277 return
278 ind = '\t'*ntabs+' '*nspaces
278 ind = '\t'*ntabs+' '*nspaces
279 if flatten:
279 if flatten:
280 pat = re.compile(r'^\s*', re.MULTILINE)
280 pat = re.compile(r'^\s*', re.MULTILINE)
281 else:
281 else:
282 pat = re.compile(r'^', re.MULTILINE)
282 pat = re.compile(r'^', re.MULTILINE)
283 outstr = re.sub(pat, ind, instr)
283 outstr = re.sub(pat, ind, instr)
284 if outstr.endswith(os.linesep+ind):
284 if outstr.endswith(os.linesep+ind):
285 return outstr[:-len(ind)]
285 return outstr[:-len(ind)]
286 else:
286 else:
287 return outstr
287 return outstr
288
288
289
289
290 def list_strings(arg):
290 def list_strings(arg):
291 """Always return a list of strings, given a string or list of strings
291 """Always return a list of strings, given a string or list of strings
292 as input.
292 as input.
293
293
294 Examples
294 Examples
295 --------
295 --------
296 ::
296 ::
297
297
298 In [7]: list_strings('A single string')
298 In [7]: list_strings('A single string')
299 Out[7]: ['A single string']
299 Out[7]: ['A single string']
300
300
301 In [8]: list_strings(['A single string in a list'])
301 In [8]: list_strings(['A single string in a list'])
302 Out[8]: ['A single string in a list']
302 Out[8]: ['A single string in a list']
303
303
304 In [9]: list_strings(['A','list','of','strings'])
304 In [9]: list_strings(['A','list','of','strings'])
305 Out[9]: ['A', 'list', 'of', 'strings']
305 Out[9]: ['A', 'list', 'of', 'strings']
306 """
306 """
307
307
308 if isinstance(arg, py3compat.string_types): return [arg]
308 if isinstance(arg, py3compat.string_types): return [arg]
309 else: return arg
309 else: return arg
310
310
311
311
312 def marquee(txt='',width=78,mark='*'):
312 def marquee(txt='',width=78,mark='*'):
313 """Return the input string centered in a 'marquee'.
313 """Return the input string centered in a 'marquee'.
314
314
315 Examples
315 Examples
316 --------
316 --------
317 ::
317 ::
318
318
319 In [16]: marquee('A test',40)
319 In [16]: marquee('A test',40)
320 Out[16]: '**************** A test ****************'
320 Out[16]: '**************** A test ****************'
321
321
322 In [17]: marquee('A test',40,'-')
322 In [17]: marquee('A test',40,'-')
323 Out[17]: '---------------- A test ----------------'
323 Out[17]: '---------------- A test ----------------'
324
324
325 In [18]: marquee('A test',40,' ')
325 In [18]: marquee('A test',40,' ')
326 Out[18]: ' A test '
326 Out[18]: ' A test '
327
327
328 """
328 """
329 if not txt:
329 if not txt:
330 return (mark*width)[:width]
330 return (mark*width)[:width]
331 nmark = (width-len(txt)-2)//len(mark)//2
331 nmark = (width-len(txt)-2)//len(mark)//2
332 if nmark < 0: nmark =0
332 if nmark < 0: nmark =0
333 marks = mark*nmark
333 marks = mark*nmark
334 return '%s %s %s' % (marks,txt,marks)
334 return '%s %s %s' % (marks,txt,marks)
335
335
336
336
337 ini_spaces_re = re.compile(r'^(\s+)')
337 ini_spaces_re = re.compile(r'^(\s+)')
338
338
339 def num_ini_spaces(strng):
339 def num_ini_spaces(strng):
340 """Return the number of initial spaces in a string"""
340 """Return the number of initial spaces in a string"""
341
341
342 ini_spaces = ini_spaces_re.match(strng)
342 ini_spaces = ini_spaces_re.match(strng)
343 if ini_spaces:
343 if ini_spaces:
344 return ini_spaces.end()
344 return ini_spaces.end()
345 else:
345 else:
346 return 0
346 return 0
347
347
348
348
349 def format_screen(strng):
349 def format_screen(strng):
350 """Format a string for screen printing.
350 """Format a string for screen printing.
351
351
352 This removes some latex-type format codes."""
352 This removes some latex-type format codes."""
353 # Paragraph continue
353 # Paragraph continue
354 par_re = re.compile(r'\\$',re.MULTILINE)
354 par_re = re.compile(r'\\$',re.MULTILINE)
355 strng = par_re.sub('',strng)
355 strng = par_re.sub('',strng)
356 return strng
356 return strng
357
357
358
358
359 def dedent(text):
359 def dedent(text):
360 """Equivalent of textwrap.dedent that ignores unindented first line.
360 """Equivalent of textwrap.dedent that ignores unindented first line.
361
361
362 This means it will still dedent strings like:
362 This means it will still dedent strings like:
363 '''foo
363 '''foo
364 is a bar
364 is a bar
365 '''
365 '''
366
366
367 For use in wrap_paragraphs.
367 For use in wrap_paragraphs.
368 """
368 """
369
369
370 if text.startswith('\n'):
370 if text.startswith('\n'):
371 # text starts with blank line, don't ignore the first line
371 # text starts with blank line, don't ignore the first line
372 return textwrap.dedent(text)
372 return textwrap.dedent(text)
373
373
374 # split first line
374 # split first line
375 splits = text.split('\n',1)
375 splits = text.split('\n',1)
376 if len(splits) == 1:
376 if len(splits) == 1:
377 # only one line
377 # only one line
378 return textwrap.dedent(text)
378 return textwrap.dedent(text)
379
379
380 first, rest = splits
380 first, rest = splits
381 # dedent everything but the first line
381 # dedent everything but the first line
382 rest = textwrap.dedent(rest)
382 rest = textwrap.dedent(rest)
383 return '\n'.join([first, rest])
383 return '\n'.join([first, rest])
384
384
385
385
386 def wrap_paragraphs(text, ncols=80):
386 def wrap_paragraphs(text, ncols=80):
387 """Wrap multiple paragraphs to fit a specified width.
387 """Wrap multiple paragraphs to fit a specified width.
388
388
389 This is equivalent to textwrap.wrap, but with support for multiple
389 This is equivalent to textwrap.wrap, but with support for multiple
390 paragraphs, as separated by empty lines.
390 paragraphs, as separated by empty lines.
391
391
392 Returns
392 Returns
393 -------
393 -------
394
394
395 list of complete paragraphs, wrapped to fill `ncols` columns.
395 list of complete paragraphs, wrapped to fill `ncols` columns.
396 """
396 """
397 paragraph_re = re.compile(r'\n(\s*\n)+', re.MULTILINE)
397 paragraph_re = re.compile(r'\n(\s*\n)+', re.MULTILINE)
398 text = dedent(text).strip()
398 text = dedent(text).strip()
399 paragraphs = paragraph_re.split(text)[::2] # every other entry is space
399 paragraphs = paragraph_re.split(text)[::2] # every other entry is space
400 out_ps = []
400 out_ps = []
401 indent_re = re.compile(r'\n\s+', re.MULTILINE)
401 indent_re = re.compile(r'\n\s+', re.MULTILINE)
402 for p in paragraphs:
402 for p in paragraphs:
403 # presume indentation that survives dedent is meaningful formatting,
403 # presume indentation that survives dedent is meaningful formatting,
404 # so don't fill unless text is flush.
404 # so don't fill unless text is flush.
405 if indent_re.search(p) is None:
405 if indent_re.search(p) is None:
406 # wrap paragraph
406 # wrap paragraph
407 p = textwrap.fill(p, ncols)
407 p = textwrap.fill(p, ncols)
408 out_ps.append(p)
408 out_ps.append(p)
409 return out_ps
409 return out_ps
410
410
411
411
412 def long_substr(data):
412 def long_substr(data):
413 """Return the longest common substring in a list of strings.
413 """Return the longest common substring in a list of strings.
414
414
415 Credit: http://stackoverflow.com/questions/2892931/longest-common-substring-from-more-than-two-strings-python
415 Credit: http://stackoverflow.com/questions/2892931/longest-common-substring-from-more-than-two-strings-python
416 """
416 """
417 substr = ''
417 substr = ''
418 if len(data) > 1 and len(data[0]) > 0:
418 if len(data) > 1 and len(data[0]) > 0:
419 for i in range(len(data[0])):
419 for i in range(len(data[0])):
420 for j in range(len(data[0])-i+1):
420 for j in range(len(data[0])-i+1):
421 if j > len(substr) and all(data[0][i:i+j] in x for x in data):
421 if j > len(substr) and all(data[0][i:i+j] in x for x in data):
422 substr = data[0][i:i+j]
422 substr = data[0][i:i+j]
423 elif len(data) == 1:
423 elif len(data) == 1:
424 substr = data[0]
424 substr = data[0]
425 return substr
425 return substr
426
426
427
427
428 def strip_email_quotes(text):
428 def strip_email_quotes(text):
429 """Strip leading email quotation characters ('>').
429 """Strip leading email quotation characters ('>').
430
430
431 Removes any combination of leading '>' interspersed with whitespace that
431 Removes any combination of leading '>' interspersed with whitespace that
432 appears *identically* in all lines of the input text.
432 appears *identically* in all lines of the input text.
433
433
434 Parameters
434 Parameters
435 ----------
435 ----------
436 text : str
436 text : str
437
437
438 Examples
438 Examples
439 --------
439 --------
440
440
441 Simple uses::
441 Simple uses::
442
442
443 In [2]: strip_email_quotes('> > text')
443 In [2]: strip_email_quotes('> > text')
444 Out[2]: 'text'
444 Out[2]: 'text'
445
445
446 In [3]: strip_email_quotes('> > text\\n> > more')
446 In [3]: strip_email_quotes('> > text\\n> > more')
447 Out[3]: 'text\\nmore'
447 Out[3]: 'text\\nmore'
448
448
449 Note how only the common prefix that appears in all lines is stripped::
449 Note how only the common prefix that appears in all lines is stripped::
450
450
451 In [4]: strip_email_quotes('> > text\\n> > more\\n> more...')
451 In [4]: strip_email_quotes('> > text\\n> > more\\n> more...')
452 Out[4]: '> text\\n> more\\nmore...'
452 Out[4]: '> text\\n> more\\nmore...'
453
453
454 So if any line has no quote marks ('>') , then none are stripped from any
454 So if any line has no quote marks ('>') , then none are stripped from any
455 of them ::
455 of them ::
456
456
457 In [5]: strip_email_quotes('> > text\\n> > more\\nlast different')
457 In [5]: strip_email_quotes('> > text\\n> > more\\nlast different')
458 Out[5]: '> > text\\n> > more\\nlast different'
458 Out[5]: '> > text\\n> > more\\nlast different'
459 """
459 """
460 lines = text.splitlines()
460 lines = text.splitlines()
461 matches = set()
461 matches = set()
462 for line in lines:
462 for line in lines:
463 prefix = re.match(r'^(\s*>[ >]*)', line)
463 prefix = re.match(r'^(\s*>[ >]*)', line)
464 if prefix:
464 if prefix:
465 matches.add(prefix.group(1))
465 matches.add(prefix.group(1))
466 else:
466 else:
467 break
467 break
468 else:
468 else:
469 prefix = long_substr(list(matches))
469 prefix = long_substr(list(matches))
470 if prefix:
470 if prefix:
471 strip = len(prefix)
471 strip = len(prefix)
472 text = '\n'.join([ ln[strip:] for ln in lines])
472 text = '\n'.join([ ln[strip:] for ln in lines])
473 return text
473 return text
474
474
475 def strip_ansi(source):
475 def strip_ansi(source):
476 """
476 """
477 Remove ansi escape codes from text.
477 Remove ansi escape codes from text.
478
478
479 Parameters
479 Parameters
480 ----------
480 ----------
481 source : str
481 source : str
482 Source to remove the ansi from
482 Source to remove the ansi from
483 """
483 """
484 return re.sub(r'\033\[(\d|;)+?m', '', source)
484 return re.sub(r'\033\[(\d|;)+?m', '', source)
485
485
486
486
487 class EvalFormatter(Formatter):
487 class EvalFormatter(Formatter):
488 """A String Formatter that allows evaluation of simple expressions.
488 """A String Formatter that allows evaluation of simple expressions.
489
489
490 Note that this version interprets a : as specifying a format string (as per
490 Note that this version interprets a : as specifying a format string (as per
491 standard string formatting), so if slicing is required, you must explicitly
491 standard string formatting), so if slicing is required, you must explicitly
492 create a slice.
492 create a slice.
493
493
494 This is to be used in templating cases, such as the parallel batch
494 This is to be used in templating cases, such as the parallel batch
495 script templates, where simple arithmetic on arguments is useful.
495 script templates, where simple arithmetic on arguments is useful.
496
496
497 Examples
497 Examples
498 --------
498 --------
499 ::
499 ::
500
500
501 In [1]: f = EvalFormatter()
501 In [1]: f = EvalFormatter()
502 In [2]: f.format('{n//4}', n=8)
502 In [2]: f.format('{n//4}', n=8)
503 Out[2]: '2'
503 Out[2]: '2'
504
504
505 In [3]: f.format("{greeting[slice(2,4)]}", greeting="Hello")
505 In [3]: f.format("{greeting[slice(2,4)]}", greeting="Hello")
506 Out[3]: 'll'
506 Out[3]: 'll'
507 """
507 """
508 def get_field(self, name, args, kwargs):
508 def get_field(self, name, args, kwargs):
509 v = eval(name, kwargs)
509 v = eval(name, kwargs)
510 return v, name
510 return v, name
511
511
512 #XXX: As of Python 3.4, the format string parsing no longer splits on a colon
512 #XXX: As of Python 3.4, the format string parsing no longer splits on a colon
513 # inside [], so EvalFormatter can handle slicing. Once we only support 3.4 and
513 # inside [], so EvalFormatter can handle slicing. Once we only support 3.4 and
514 # above, it should be possible to remove FullEvalFormatter.
514 # above, it should be possible to remove FullEvalFormatter.
515
515
516 @skip_doctest_py3
516 @skip_doctest_py3
517 class FullEvalFormatter(Formatter):
517 class FullEvalFormatter(Formatter):
518 """A String Formatter that allows evaluation of simple expressions.
518 """A String Formatter that allows evaluation of simple expressions.
519
519
520 Any time a format key is not found in the kwargs,
520 Any time a format key is not found in the kwargs,
521 it will be tried as an expression in the kwargs namespace.
521 it will be tried as an expression in the kwargs namespace.
522
522
523 Note that this version allows slicing using [1:2], so you cannot specify
523 Note that this version allows slicing using [1:2], so you cannot specify
524 a format string. Use :class:`EvalFormatter` to permit format strings.
524 a format string. Use :class:`EvalFormatter` to permit format strings.
525
525
526 Examples
526 Examples
527 --------
527 --------
528 ::
528 ::
529
529
530 In [1]: f = FullEvalFormatter()
530 In [1]: f = FullEvalFormatter()
531 In [2]: f.format('{n//4}', n=8)
531 In [2]: f.format('{n//4}', n=8)
532 Out[2]: u'2'
532 Out[2]: u'2'
533
533
534 In [3]: f.format('{list(range(5))[2:4]}')
534 In [3]: f.format('{list(range(5))[2:4]}')
535 Out[3]: u'[2, 3]'
535 Out[3]: u'[2, 3]'
536
536
537 In [4]: f.format('{3*2}')
537 In [4]: f.format('{3*2}')
538 Out[4]: u'6'
538 Out[4]: u'6'
539 """
539 """
540 # copied from Formatter._vformat with minor changes to allow eval
540 # copied from Formatter._vformat with minor changes to allow eval
541 # and replace the format_spec code with slicing
541 # and replace the format_spec code with slicing
542 def _vformat(self, format_string, args, kwargs, used_args, recursion_depth):
542 def _vformat(self, format_string, args, kwargs, used_args, recursion_depth):
543 if recursion_depth < 0:
543 if recursion_depth < 0:
544 raise ValueError('Max string recursion exceeded')
544 raise ValueError('Max string recursion exceeded')
545 result = []
545 result = []
546 for literal_text, field_name, format_spec, conversion in \
546 for literal_text, field_name, format_spec, conversion in \
547 self.parse(format_string):
547 self.parse(format_string):
548
548
549 # output the literal text
549 # output the literal text
550 if literal_text:
550 if literal_text:
551 result.append(literal_text)
551 result.append(literal_text)
552
552
553 # if there's a field, output it
553 # if there's a field, output it
554 if field_name is not None:
554 if field_name is not None:
555 # this is some markup, find the object and do
555 # this is some markup, find the object and do
556 # the formatting
556 # the formatting
557
557
558 if format_spec:
558 if format_spec:
559 # override format spec, to allow slicing:
559 # override format spec, to allow slicing:
560 field_name = ':'.join([field_name, format_spec])
560 field_name = ':'.join([field_name, format_spec])
561
561
562 # eval the contents of the field for the object
562 # eval the contents of the field for the object
563 # to be formatted
563 # to be formatted
564 obj = eval(field_name, kwargs)
564 obj = eval(field_name, kwargs)
565
565
566 # do any conversion on the resulting object
566 # do any conversion on the resulting object
567 obj = self.convert_field(obj, conversion)
567 obj = self.convert_field(obj, conversion)
568
568
569 # format the object and append to the result
569 # format the object and append to the result
570 result.append(self.format_field(obj, ''))
570 result.append(self.format_field(obj, ''))
571
571
572 return u''.join(py3compat.cast_unicode(s) for s in result)
572 return u''.join(py3compat.cast_unicode(s) for s in result)
573
573
574
574
575 @skip_doctest_py3
575 @skip_doctest_py3
576 class DollarFormatter(FullEvalFormatter):
576 class DollarFormatter(FullEvalFormatter):
577 """Formatter allowing Itpl style $foo replacement, for names and attribute
577 """Formatter allowing Itpl style $foo replacement, for names and attribute
578 access only. Standard {foo} replacement also works, and allows full
578 access only. Standard {foo} replacement also works, and allows full
579 evaluation of its arguments.
579 evaluation of its arguments.
580
580
581 Examples
581 Examples
582 --------
582 --------
583 ::
583 ::
584
584
585 In [1]: f = DollarFormatter()
585 In [1]: f = DollarFormatter()
586 In [2]: f.format('{n//4}', n=8)
586 In [2]: f.format('{n//4}', n=8)
587 Out[2]: u'2'
587 Out[2]: u'2'
588
588
589 In [3]: f.format('23 * 76 is $result', result=23*76)
589 In [3]: f.format('23 * 76 is $result', result=23*76)
590 Out[3]: u'23 * 76 is 1748'
590 Out[3]: u'23 * 76 is 1748'
591
591
592 In [4]: f.format('$a or {b}', a=1, b=2)
592 In [4]: f.format('$a or {b}', a=1, b=2)
593 Out[4]: u'1 or 2'
593 Out[4]: u'1 or 2'
594 """
594 """
595 _dollar_pattern = re.compile("(.*?)\$(\$?[\w\.]+)")
595 _dollar_pattern = re.compile("(.*?)\$(\$?[\w\.]+)")
596 def parse(self, fmt_string):
596 def parse(self, fmt_string):
597 for literal_txt, field_name, format_spec, conversion \
597 for literal_txt, field_name, format_spec, conversion \
598 in Formatter.parse(self, fmt_string):
598 in Formatter.parse(self, fmt_string):
599
599
600 # Find $foo patterns in the literal text.
600 # Find $foo patterns in the literal text.
601 continue_from = 0
601 continue_from = 0
602 txt = ""
602 txt = ""
603 for m in self._dollar_pattern.finditer(literal_txt):
603 for m in self._dollar_pattern.finditer(literal_txt):
604 new_txt, new_field = m.group(1,2)
604 new_txt, new_field = m.group(1,2)
605 # $$foo --> $foo
605 # $$foo --> $foo
606 if new_field.startswith("$"):
606 if new_field.startswith("$"):
607 txt += new_txt + new_field
607 txt += new_txt + new_field
608 else:
608 else:
609 yield (txt + new_txt, new_field, "", None)
609 yield (txt + new_txt, new_field, "", None)
610 txt = ""
610 txt = ""
611 continue_from = m.end()
611 continue_from = m.end()
612
612
613 # Re-yield the {foo} style pattern
613 # Re-yield the {foo} style pattern
614 yield (txt + literal_txt[continue_from:], field_name, format_spec, conversion)
614 yield (txt + literal_txt[continue_from:], field_name, format_spec, conversion)
615
615
616 #-----------------------------------------------------------------------------
616 #-----------------------------------------------------------------------------
617 # Utils to columnize a list of string
617 # Utils to columnize a list of string
618 #-----------------------------------------------------------------------------
618 #-----------------------------------------------------------------------------
619
619
620 def _chunks(l, n):
620 def _chunks(l, n):
621 """Yield successive n-sized chunks from l."""
621 """Yield successive n-sized chunks from l."""
622 for i in py3compat.xrange(0, len(l), n):
622 for i in py3compat.xrange(0, len(l), n):
623 yield l[i:i+n]
623 yield l[i:i+n]
624
624
625
625
626 def _find_optimal(rlist , separator_size=2 , displaywidth=80):
626 def _find_optimal(rlist , separator_size=2 , displaywidth=80):
627 """Calculate optimal info to columnize a list of string"""
627 """Calculate optimal info to columnize a list of string"""
628 for nrow in range(1, len(rlist)+1) :
628 for nrow in range(1, len(rlist)+1) :
629 chk = list(map(max,_chunks(rlist, nrow)))
629 chk = list(map(max,_chunks(rlist, nrow)))
630 sumlength = sum(chk)
630 sumlength = sum(chk)
631 ncols = len(chk)
631 ncols = len(chk)
632 if sumlength+separator_size*(ncols-1) <= displaywidth :
632 if sumlength+separator_size*(ncols-1) <= displaywidth :
633 break;
633 break;
634 return {'columns_numbers' : ncols,
634 return {'columns_numbers' : ncols,
635 'optimal_separator_width':(displaywidth - sumlength)/(ncols-1) if (ncols -1) else 0,
635 'optimal_separator_width':(displaywidth - sumlength)/(ncols-1) if (ncols -1) else 0,
636 'rows_numbers' : nrow,
636 'rows_numbers' : nrow,
637 'columns_width' : chk
637 'columns_width' : chk
638 }
638 }
639
639
640
640
641 def _get_or_default(mylist, i, default=None):
641 def _get_or_default(mylist, i, default=None):
642 """return list item number, or default if don't exist"""
642 """return list item number, or default if don't exist"""
643 if i >= len(mylist):
643 if i >= len(mylist):
644 return default
644 return default
645 else :
645 else :
646 return mylist[i]
646 return mylist[i]
647
647
648
648
649 @skip_doctest
650 def compute_item_matrix(items, empty=None, *args, **kwargs) :
649 def compute_item_matrix(items, empty=None, *args, **kwargs) :
651 """Returns a nested list, and info to columnize items
650 """Returns a nested list, and info to columnize items
652
651
653 Parameters
652 Parameters
654 ----------
653 ----------
655
654
656 items
655 items
657 list of strings to columize
656 list of strings to columize
658 empty : (default None)
657 empty : (default None)
659 default value to fill list if needed
658 default value to fill list if needed
660 separator_size : int (default=2)
659 separator_size : int (default=2)
661 How much caracters will be used as a separation between each columns.
660 How much caracters will be used as a separation between each columns.
662 displaywidth : int (default=80)
661 displaywidth : int (default=80)
663 The width of the area onto wich the columns should enter
662 The width of the area onto wich the columns should enter
664
663
665 Returns
664 Returns
666 -------
665 -------
667
666
668 strings_matrix
667 strings_matrix
669
668
670 nested list of string, the outer most list contains as many list as
669 nested list of string, the outer most list contains as many list as
671 rows, the innermost lists have each as many element as colums. If the
670 rows, the innermost lists have each as many element as colums. If the
672 total number of elements in `items` does not equal the product of
671 total number of elements in `items` does not equal the product of
673 rows*columns, the last element of some lists are filled with `None`.
672 rows*columns, the last element of some lists are filled with `None`.
674
673
675 dict_info
674 dict_info
676 some info to make columnize easier:
675 some info to make columnize easier:
677
676
678 columns_numbers
677 columns_numbers
679 number of columns
678 number of columns
680 rows_numbers
679 rows_numbers
681 number of rows
680 number of rows
682 columns_width
681 columns_width
683 list of with of each columns
682 list of with of each columns
684 optimal_separator_width
683 optimal_separator_width
685 best separator width between columns
684 best separator width between columns
686
685
687 Examples
686 Examples
688 --------
687 --------
689 ::
688 ::
690
689
691 In [1]: l = ['aaa','b','cc','d','eeeee','f','g','h','i','j','k','l']
690 In [1]: l = ['aaa','b','cc','d','eeeee','f','g','h','i','j','k','l']
692 ...: compute_item_matrix(l,displaywidth=12)
691 ...: compute_item_matrix(l,displaywidth=12)
693 Out[1]:
692 Out[1]:
694 ([['aaa', 'f', 'k'],
693 ([['aaa', 'f', 'k'],
695 ['b', 'g', 'l'],
694 ['b', 'g', 'l'],
696 ['cc', 'h', None],
695 ['cc', 'h', None],
697 ['d', 'i', None],
696 ['d', 'i', None],
698 ['eeeee', 'j', None]],
697 ['eeeee', 'j', None]],
699 {'columns_numbers': 3,
698 {'columns_numbers': 3,
700 'columns_width': [5, 1, 1],
699 'columns_width': [5, 1, 1],
701 'optimal_separator_width': 2,
700 'optimal_separator_width': 2,
702 'rows_numbers': 5})
701 'rows_numbers': 5})
703 """
702 """
704 info = _find_optimal(list(map(len, items)), *args, **kwargs)
703 info = _find_optimal(list(map(len, items)), *args, **kwargs)
705 nrow, ncol = info['rows_numbers'], info['columns_numbers']
704 nrow, ncol = info['rows_numbers'], info['columns_numbers']
706 return ([[ _get_or_default(items, c*nrow+i, default=empty) for c in range(ncol) ] for i in range(nrow) ], info)
705 return ([[ _get_or_default(items, c*nrow+i, default=empty) for c in range(ncol) ] for i in range(nrow) ], info)
707
706
708
707
709 def columnize(items, separator=' ', displaywidth=80):
708 def columnize(items, separator=' ', displaywidth=80):
710 """ Transform a list of strings into a single string with columns.
709 """ Transform a list of strings into a single string with columns.
711
710
712 Parameters
711 Parameters
713 ----------
712 ----------
714 items : sequence of strings
713 items : sequence of strings
715 The strings to process.
714 The strings to process.
716
715
717 separator : str, optional [default is two spaces]
716 separator : str, optional [default is two spaces]
718 The string that separates columns.
717 The string that separates columns.
719
718
720 displaywidth : int, optional [default is 80]
719 displaywidth : int, optional [default is 80]
721 Width of the display in number of characters.
720 Width of the display in number of characters.
722
721
723 Returns
722 Returns
724 -------
723 -------
725 The formatted string.
724 The formatted string.
726 """
725 """
727 if not items :
726 if not items :
728 return '\n'
727 return '\n'
729 matrix, info = compute_item_matrix(items, separator_size=len(separator), displaywidth=displaywidth)
728 matrix, info = compute_item_matrix(items, separator_size=len(separator), displaywidth=displaywidth)
730 fmatrix = [filter(None, x) for x in matrix]
729 fmatrix = [filter(None, x) for x in matrix]
731 sjoin = lambda x : separator.join([ y.ljust(w, ' ') for y, w in zip(x, info['columns_width'])])
730 sjoin = lambda x : separator.join([ y.ljust(w, ' ') for y, w in zip(x, info['columns_width'])])
732 return '\n'.join(map(sjoin, fmatrix))+'\n'
731 return '\n'.join(map(sjoin, fmatrix))+'\n'
733
732
734
733
735 def get_text_list(list_, last_sep=' and ', sep=", ", wrap_item_with=""):
734 def get_text_list(list_, last_sep=' and ', sep=", ", wrap_item_with=""):
736 """
735 """
737 Return a string with a natural enumeration of items
736 Return a string with a natural enumeration of items
738
737
739 >>> get_text_list(['a', 'b', 'c', 'd'])
738 >>> get_text_list(['a', 'b', 'c', 'd'])
740 'a, b, c and d'
739 'a, b, c and d'
741 >>> get_text_list(['a', 'b', 'c'], ' or ')
740 >>> get_text_list(['a', 'b', 'c'], ' or ')
742 'a, b or c'
741 'a, b or c'
743 >>> get_text_list(['a', 'b', 'c'], ', ')
742 >>> get_text_list(['a', 'b', 'c'], ', ')
744 'a, b, c'
743 'a, b, c'
745 >>> get_text_list(['a', 'b'], ' or ')
744 >>> get_text_list(['a', 'b'], ' or ')
746 'a or b'
745 'a or b'
747 >>> get_text_list(['a'])
746 >>> get_text_list(['a'])
748 'a'
747 'a'
749 >>> get_text_list([])
748 >>> get_text_list([])
750 ''
749 ''
751 >>> get_text_list(['a', 'b'], wrap_item_with="`")
750 >>> get_text_list(['a', 'b'], wrap_item_with="`")
752 '`a` and `b`'
751 '`a` and `b`'
753 >>> get_text_list(['a', 'b', 'c', 'd'], " = ", sep=" + ")
752 >>> get_text_list(['a', 'b', 'c', 'd'], " = ", sep=" + ")
754 'a + b + c = d'
753 'a + b + c = d'
755 """
754 """
756 if len(list_) == 0:
755 if len(list_) == 0:
757 return ''
756 return ''
758 if wrap_item_with:
757 if wrap_item_with:
759 list_ = ['%s%s%s' % (wrap_item_with, item, wrap_item_with) for
758 list_ = ['%s%s%s' % (wrap_item_with, item, wrap_item_with) for
760 item in list_]
759 item in list_]
761 if len(list_) == 1:
760 if len(list_) == 1:
762 return list_[0]
761 return list_[0]
763 return '%s%s%s' % (
762 return '%s%s%s' % (
764 sep.join(i for i in list_[:-1]),
763 sep.join(i for i in list_[:-1]),
765 last_sep, list_[-1]) No newline at end of file
764 last_sep, list_[-1])
@@ -1,486 +1,484 b''
1 """A ZMQ-based subclass of InteractiveShell.
1 """A ZMQ-based subclass of InteractiveShell.
2
2
3 This code is meant to ease the refactoring of the base InteractiveShell into
3 This code is meant to ease the refactoring of the base InteractiveShell into
4 something with a cleaner architecture for 2-process use, without actually
4 something with a cleaner architecture for 2-process use, without actually
5 breaking InteractiveShell itself. So we're doing something a bit ugly, where
5 breaking InteractiveShell itself. So we're doing something a bit ugly, where
6 we subclass and override what we want to fix. Once this is working well, we
6 we subclass and override what we want to fix. Once this is working well, we
7 can go back to the base class and refactor the code for a cleaner inheritance
7 can go back to the base class and refactor the code for a cleaner inheritance
8 implementation that doesn't rely on so much monkeypatching.
8 implementation that doesn't rely on so much monkeypatching.
9
9
10 But this lets us maintain a fully working IPython as we develop the new
10 But this lets us maintain a fully working IPython as we develop the new
11 machinery. This should thus be thought of as scaffolding.
11 machinery. This should thus be thought of as scaffolding.
12 """
12 """
13
13
14 # Copyright (c) IPython Development Team.
14 # Copyright (c) IPython Development Team.
15 # Distributed under the terms of the Modified BSD License.
15 # Distributed under the terms of the Modified BSD License.
16
16
17 from __future__ import print_function
17 from __future__ import print_function
18
18
19 import os
19 import os
20 import sys
20 import sys
21 import time
21 import time
22
22
23 from zmq.eventloop import ioloop
23 from zmq.eventloop import ioloop
24
24
25 from IPython.core.interactiveshell import (
25 from IPython.core.interactiveshell import (
26 InteractiveShell, InteractiveShellABC
26 InteractiveShell, InteractiveShellABC
27 )
27 )
28 from IPython.core import page
28 from IPython.core import page
29 from IPython.core.autocall import ZMQExitAutocall
29 from IPython.core.autocall import ZMQExitAutocall
30 from IPython.core.displaypub import DisplayPublisher
30 from IPython.core.displaypub import DisplayPublisher
31 from IPython.core.error import UsageError
31 from IPython.core.error import UsageError
32 from IPython.core.magics import MacroToEdit, CodeMagics
32 from IPython.core.magics import MacroToEdit, CodeMagics
33 from IPython.core.magic import magics_class, line_magic, Magics
33 from IPython.core.magic import magics_class, line_magic, Magics
34 from IPython.core import payloadpage
34 from IPython.core import payloadpage
35 from IPython.core.usage import default_gui_banner
35 from IPython.core.usage import default_gui_banner
36 from IPython.display import display, Javascript
36 from IPython.display import display, Javascript
37 from ipython_kernel.inprocess.socket import SocketABC
37 from ipython_kernel.inprocess.socket import SocketABC
38 from ipython_kernel import (
38 from ipython_kernel import (
39 get_connection_file, get_connection_info, connect_qtconsole
39 get_connection_file, get_connection_info, connect_qtconsole
40 )
40 )
41 from IPython.testing.skipdoctest import skip_doctest
42 from IPython.utils import openpy
41 from IPython.utils import openpy
43 from jupyter_client.jsonutil import json_clean, encode_images
42 from jupyter_client.jsonutil import json_clean, encode_images
44 from IPython.utils.process import arg_split
43 from IPython.utils.process import arg_split
45 from IPython.utils import py3compat
44 from IPython.utils import py3compat
46 from IPython.utils.py3compat import unicode_type
45 from IPython.utils.py3compat import unicode_type
47 from IPython.utils.traitlets import Instance, Type, Dict, CBool, CBytes, Any
46 from IPython.utils.traitlets import Instance, Type, Dict, CBool, CBytes, Any
48 from IPython.utils.warn import error
47 from IPython.utils.warn import error
49 from ipython_kernel.displayhook import ZMQShellDisplayHook
48 from ipython_kernel.displayhook import ZMQShellDisplayHook
50 from ipython_kernel.datapub import ZMQDataPublisher
49 from ipython_kernel.datapub import ZMQDataPublisher
51 from ipython_kernel.session import extract_header
50 from ipython_kernel.session import extract_header
52 from .session import Session
51 from .session import Session
53
52
54 #-----------------------------------------------------------------------------
53 #-----------------------------------------------------------------------------
55 # Functions and classes
54 # Functions and classes
56 #-----------------------------------------------------------------------------
55 #-----------------------------------------------------------------------------
57
56
58 class ZMQDisplayPublisher(DisplayPublisher):
57 class ZMQDisplayPublisher(DisplayPublisher):
59 """A display publisher that publishes data using a ZeroMQ PUB socket."""
58 """A display publisher that publishes data using a ZeroMQ PUB socket."""
60
59
61 session = Instance(Session, allow_none=True)
60 session = Instance(Session, allow_none=True)
62 pub_socket = Instance(SocketABC, allow_none=True)
61 pub_socket = Instance(SocketABC, allow_none=True)
63 parent_header = Dict({})
62 parent_header = Dict({})
64 topic = CBytes(b'display_data')
63 topic = CBytes(b'display_data')
65
64
66 def set_parent(self, parent):
65 def set_parent(self, parent):
67 """Set the parent for outbound messages."""
66 """Set the parent for outbound messages."""
68 self.parent_header = extract_header(parent)
67 self.parent_header = extract_header(parent)
69
68
70 def _flush_streams(self):
69 def _flush_streams(self):
71 """flush IO Streams prior to display"""
70 """flush IO Streams prior to display"""
72 sys.stdout.flush()
71 sys.stdout.flush()
73 sys.stderr.flush()
72 sys.stderr.flush()
74
73
75 def publish(self, data, metadata=None, source=None):
74 def publish(self, data, metadata=None, source=None):
76 self._flush_streams()
75 self._flush_streams()
77 if metadata is None:
76 if metadata is None:
78 metadata = {}
77 metadata = {}
79 self._validate_data(data, metadata)
78 self._validate_data(data, metadata)
80 content = {}
79 content = {}
81 content['data'] = encode_images(data)
80 content['data'] = encode_images(data)
82 content['metadata'] = metadata
81 content['metadata'] = metadata
83 self.session.send(
82 self.session.send(
84 self.pub_socket, u'display_data', json_clean(content),
83 self.pub_socket, u'display_data', json_clean(content),
85 parent=self.parent_header, ident=self.topic,
84 parent=self.parent_header, ident=self.topic,
86 )
85 )
87
86
88 def clear_output(self, wait=False):
87 def clear_output(self, wait=False):
89 content = dict(wait=wait)
88 content = dict(wait=wait)
90 self._flush_streams()
89 self._flush_streams()
91 self.session.send(
90 self.session.send(
92 self.pub_socket, u'clear_output', content,
91 self.pub_socket, u'clear_output', content,
93 parent=self.parent_header, ident=self.topic,
92 parent=self.parent_header, ident=self.topic,
94 )
93 )
95
94
96 @magics_class
95 @magics_class
97 class KernelMagics(Magics):
96 class KernelMagics(Magics):
98 #------------------------------------------------------------------------
97 #------------------------------------------------------------------------
99 # Magic overrides
98 # Magic overrides
100 #------------------------------------------------------------------------
99 #------------------------------------------------------------------------
101 # Once the base class stops inheriting from magic, this code needs to be
100 # Once the base class stops inheriting from magic, this code needs to be
102 # moved into a separate machinery as well. For now, at least isolate here
101 # moved into a separate machinery as well. For now, at least isolate here
103 # the magics which this class needs to implement differently from the base
102 # the magics which this class needs to implement differently from the base
104 # class, or that are unique to it.
103 # class, or that are unique to it.
105
104
106 _find_edit_target = CodeMagics._find_edit_target
105 _find_edit_target = CodeMagics._find_edit_target
107
106
108 @skip_doctest
109 @line_magic
107 @line_magic
110 def edit(self, parameter_s='', last_call=['','']):
108 def edit(self, parameter_s='', last_call=['','']):
111 """Bring up an editor and execute the resulting code.
109 """Bring up an editor and execute the resulting code.
112
110
113 Usage:
111 Usage:
114 %edit [options] [args]
112 %edit [options] [args]
115
113
116 %edit runs an external text editor. You will need to set the command for
114 %edit runs an external text editor. You will need to set the command for
117 this editor via the ``TerminalInteractiveShell.editor`` option in your
115 this editor via the ``TerminalInteractiveShell.editor`` option in your
118 configuration file before it will work.
116 configuration file before it will work.
119
117
120 This command allows you to conveniently edit multi-line code right in
118 This command allows you to conveniently edit multi-line code right in
121 your IPython session.
119 your IPython session.
122
120
123 If called without arguments, %edit opens up an empty editor with a
121 If called without arguments, %edit opens up an empty editor with a
124 temporary file and will execute the contents of this file when you
122 temporary file and will execute the contents of this file when you
125 close it (don't forget to save it!).
123 close it (don't forget to save it!).
126
124
127 Options:
125 Options:
128
126
129 -n <number>
127 -n <number>
130 Open the editor at a specified line number. By default, the IPython
128 Open the editor at a specified line number. By default, the IPython
131 editor hook uses the unix syntax 'editor +N filename', but you can
129 editor hook uses the unix syntax 'editor +N filename', but you can
132 configure this by providing your own modified hook if your favorite
130 configure this by providing your own modified hook if your favorite
133 editor supports line-number specifications with a different syntax.
131 editor supports line-number specifications with a different syntax.
134
132
135 -p
133 -p
136 Call the editor with the same data as the previous time it was used,
134 Call the editor with the same data as the previous time it was used,
137 regardless of how long ago (in your current session) it was.
135 regardless of how long ago (in your current session) it was.
138
136
139 -r
137 -r
140 Use 'raw' input. This option only applies to input taken from the
138 Use 'raw' input. This option only applies to input taken from the
141 user's history. By default, the 'processed' history is used, so that
139 user's history. By default, the 'processed' history is used, so that
142 magics are loaded in their transformed version to valid Python. If
140 magics are loaded in their transformed version to valid Python. If
143 this option is given, the raw input as typed as the command line is
141 this option is given, the raw input as typed as the command line is
144 used instead. When you exit the editor, it will be executed by
142 used instead. When you exit the editor, it will be executed by
145 IPython's own processor.
143 IPython's own processor.
146
144
147 Arguments:
145 Arguments:
148
146
149 If arguments are given, the following possibilites exist:
147 If arguments are given, the following possibilites exist:
150
148
151 - The arguments are numbers or pairs of colon-separated numbers (like
149 - The arguments are numbers or pairs of colon-separated numbers (like
152 1 4:8 9). These are interpreted as lines of previous input to be
150 1 4:8 9). These are interpreted as lines of previous input to be
153 loaded into the editor. The syntax is the same of the %macro command.
151 loaded into the editor. The syntax is the same of the %macro command.
154
152
155 - If the argument doesn't start with a number, it is evaluated as a
153 - If the argument doesn't start with a number, it is evaluated as a
156 variable and its contents loaded into the editor. You can thus edit
154 variable and its contents loaded into the editor. You can thus edit
157 any string which contains python code (including the result of
155 any string which contains python code (including the result of
158 previous edits).
156 previous edits).
159
157
160 - If the argument is the name of an object (other than a string),
158 - If the argument is the name of an object (other than a string),
161 IPython will try to locate the file where it was defined and open the
159 IPython will try to locate the file where it was defined and open the
162 editor at the point where it is defined. You can use ``%edit function``
160 editor at the point where it is defined. You can use ``%edit function``
163 to load an editor exactly at the point where 'function' is defined,
161 to load an editor exactly at the point where 'function' is defined,
164 edit it and have the file be executed automatically.
162 edit it and have the file be executed automatically.
165
163
166 If the object is a macro (see %macro for details), this opens up your
164 If the object is a macro (see %macro for details), this opens up your
167 specified editor with a temporary file containing the macro's data.
165 specified editor with a temporary file containing the macro's data.
168 Upon exit, the macro is reloaded with the contents of the file.
166 Upon exit, the macro is reloaded with the contents of the file.
169
167
170 Note: opening at an exact line is only supported under Unix, and some
168 Note: opening at an exact line is only supported under Unix, and some
171 editors (like kedit and gedit up to Gnome 2.8) do not understand the
169 editors (like kedit and gedit up to Gnome 2.8) do not understand the
172 '+NUMBER' parameter necessary for this feature. Good editors like
170 '+NUMBER' parameter necessary for this feature. Good editors like
173 (X)Emacs, vi, jed, pico and joe all do.
171 (X)Emacs, vi, jed, pico and joe all do.
174
172
175 - If the argument is not found as a variable, IPython will look for a
173 - If the argument is not found as a variable, IPython will look for a
176 file with that name (adding .py if necessary) and load it into the
174 file with that name (adding .py if necessary) and load it into the
177 editor. It will execute its contents with execfile() when you exit,
175 editor. It will execute its contents with execfile() when you exit,
178 loading any code in the file into your interactive namespace.
176 loading any code in the file into your interactive namespace.
179
177
180 Unlike in the terminal, this is designed to use a GUI editor, and we do
178 Unlike in the terminal, this is designed to use a GUI editor, and we do
181 not know when it has closed. So the file you edit will not be
179 not know when it has closed. So the file you edit will not be
182 automatically executed or printed.
180 automatically executed or printed.
183
181
184 Note that %edit is also available through the alias %ed.
182 Note that %edit is also available through the alias %ed.
185 """
183 """
186
184
187 opts,args = self.parse_options(parameter_s,'prn:')
185 opts,args = self.parse_options(parameter_s,'prn:')
188
186
189 try:
187 try:
190 filename, lineno, _ = CodeMagics._find_edit_target(self.shell, args, opts, last_call)
188 filename, lineno, _ = CodeMagics._find_edit_target(self.shell, args, opts, last_call)
191 except MacroToEdit as e:
189 except MacroToEdit as e:
192 # TODO: Implement macro editing over 2 processes.
190 # TODO: Implement macro editing over 2 processes.
193 print("Macro editing not yet implemented in 2-process model.")
191 print("Macro editing not yet implemented in 2-process model.")
194 return
192 return
195
193
196 # Make sure we send to the client an absolute path, in case the working
194 # Make sure we send to the client an absolute path, in case the working
197 # directory of client and kernel don't match
195 # directory of client and kernel don't match
198 filename = os.path.abspath(filename)
196 filename = os.path.abspath(filename)
199
197
200 payload = {
198 payload = {
201 'source' : 'edit_magic',
199 'source' : 'edit_magic',
202 'filename' : filename,
200 'filename' : filename,
203 'line_number' : lineno
201 'line_number' : lineno
204 }
202 }
205 self.shell.payload_manager.write_payload(payload)
203 self.shell.payload_manager.write_payload(payload)
206
204
207 # A few magics that are adapted to the specifics of using pexpect and a
205 # A few magics that are adapted to the specifics of using pexpect and a
208 # remote terminal
206 # remote terminal
209
207
210 @line_magic
208 @line_magic
211 def clear(self, arg_s):
209 def clear(self, arg_s):
212 """Clear the terminal."""
210 """Clear the terminal."""
213 if os.name == 'posix':
211 if os.name == 'posix':
214 self.shell.system("clear")
212 self.shell.system("clear")
215 else:
213 else:
216 self.shell.system("cls")
214 self.shell.system("cls")
217
215
218 if os.name == 'nt':
216 if os.name == 'nt':
219 # This is the usual name in windows
217 # This is the usual name in windows
220 cls = line_magic('cls')(clear)
218 cls = line_magic('cls')(clear)
221
219
222 # Terminal pagers won't work over pexpect, but we do have our own pager
220 # Terminal pagers won't work over pexpect, but we do have our own pager
223
221
224 @line_magic
222 @line_magic
225 def less(self, arg_s):
223 def less(self, arg_s):
226 """Show a file through the pager.
224 """Show a file through the pager.
227
225
228 Files ending in .py are syntax-highlighted."""
226 Files ending in .py are syntax-highlighted."""
229 if not arg_s:
227 if not arg_s:
230 raise UsageError('Missing filename.')
228 raise UsageError('Missing filename.')
231
229
232 if arg_s.endswith('.py'):
230 if arg_s.endswith('.py'):
233 cont = self.shell.pycolorize(openpy.read_py_file(arg_s, skip_encoding_cookie=False))
231 cont = self.shell.pycolorize(openpy.read_py_file(arg_s, skip_encoding_cookie=False))
234 else:
232 else:
235 cont = open(arg_s).read()
233 cont = open(arg_s).read()
236 page.page(cont)
234 page.page(cont)
237
235
238 more = line_magic('more')(less)
236 more = line_magic('more')(less)
239
237
240 # Man calls a pager, so we also need to redefine it
238 # Man calls a pager, so we also need to redefine it
241 if os.name == 'posix':
239 if os.name == 'posix':
242 @line_magic
240 @line_magic
243 def man(self, arg_s):
241 def man(self, arg_s):
244 """Find the man page for the given command and display in pager."""
242 """Find the man page for the given command and display in pager."""
245 page.page(self.shell.getoutput('man %s | col -b' % arg_s,
243 page.page(self.shell.getoutput('man %s | col -b' % arg_s,
246 split=False))
244 split=False))
247
245
248 @line_magic
246 @line_magic
249 def connect_info(self, arg_s):
247 def connect_info(self, arg_s):
250 """Print information for connecting other clients to this kernel
248 """Print information for connecting other clients to this kernel
251
249
252 It will print the contents of this session's connection file, as well as
250 It will print the contents of this session's connection file, as well as
253 shortcuts for local clients.
251 shortcuts for local clients.
254
252
255 In the simplest case, when called from the most recently launched kernel,
253 In the simplest case, when called from the most recently launched kernel,
256 secondary clients can be connected, simply with:
254 secondary clients can be connected, simply with:
257
255
258 $> ipython <app> --existing
256 $> ipython <app> --existing
259
257
260 """
258 """
261
259
262 from IPython.core.application import BaseIPythonApplication as BaseIPApp
260 from IPython.core.application import BaseIPythonApplication as BaseIPApp
263
261
264 if BaseIPApp.initialized():
262 if BaseIPApp.initialized():
265 app = BaseIPApp.instance()
263 app = BaseIPApp.instance()
266 security_dir = app.profile_dir.security_dir
264 security_dir = app.profile_dir.security_dir
267 profile = app.profile
265 profile = app.profile
268 else:
266 else:
269 profile = 'default'
267 profile = 'default'
270 security_dir = ''
268 security_dir = ''
271
269
272 try:
270 try:
273 connection_file = get_connection_file()
271 connection_file = get_connection_file()
274 info = get_connection_info(unpack=False)
272 info = get_connection_info(unpack=False)
275 except Exception as e:
273 except Exception as e:
276 error("Could not get connection info: %r" % e)
274 error("Could not get connection info: %r" % e)
277 return
275 return
278
276
279 # add profile flag for non-default profile
277 # add profile flag for non-default profile
280 profile_flag = "--profile %s" % profile if profile != 'default' else ""
278 profile_flag = "--profile %s" % profile if profile != 'default' else ""
281
279
282 # if it's in the security dir, truncate to basename
280 # if it's in the security dir, truncate to basename
283 if security_dir == os.path.dirname(connection_file):
281 if security_dir == os.path.dirname(connection_file):
284 connection_file = os.path.basename(connection_file)
282 connection_file = os.path.basename(connection_file)
285
283
286
284
287 print (info + '\n')
285 print (info + '\n')
288 print ("Paste the above JSON into a file, and connect with:\n"
286 print ("Paste the above JSON into a file, and connect with:\n"
289 " $> ipython <app> --existing <file>\n"
287 " $> ipython <app> --existing <file>\n"
290 "or, if you are local, you can connect with just:\n"
288 "or, if you are local, you can connect with just:\n"
291 " $> ipython <app> --existing {0} {1}\n"
289 " $> ipython <app> --existing {0} {1}\n"
292 "or even just:\n"
290 "or even just:\n"
293 " $> ipython <app> --existing {1}\n"
291 " $> ipython <app> --existing {1}\n"
294 "if this is the most recent IPython session you have started.".format(
292 "if this is the most recent IPython session you have started.".format(
295 connection_file, profile_flag
293 connection_file, profile_flag
296 )
294 )
297 )
295 )
298
296
299 @line_magic
297 @line_magic
300 def qtconsole(self, arg_s):
298 def qtconsole(self, arg_s):
301 """Open a qtconsole connected to this kernel.
299 """Open a qtconsole connected to this kernel.
302
300
303 Useful for connecting a qtconsole to running notebooks, for better
301 Useful for connecting a qtconsole to running notebooks, for better
304 debugging.
302 debugging.
305 """
303 """
306
304
307 # %qtconsole should imply bind_kernel for engines:
305 # %qtconsole should imply bind_kernel for engines:
308 try:
306 try:
309 from IPython.parallel import bind_kernel
307 from IPython.parallel import bind_kernel
310 except ImportError:
308 except ImportError:
311 # technically possible, because parallel has higher pyzmq min-version
309 # technically possible, because parallel has higher pyzmq min-version
312 pass
310 pass
313 else:
311 else:
314 bind_kernel()
312 bind_kernel()
315
313
316 try:
314 try:
317 p = connect_qtconsole(argv=arg_split(arg_s, os.name=='posix'))
315 p = connect_qtconsole(argv=arg_split(arg_s, os.name=='posix'))
318 except Exception as e:
316 except Exception as e:
319 error("Could not start qtconsole: %r" % e)
317 error("Could not start qtconsole: %r" % e)
320 return
318 return
321
319
322 @line_magic
320 @line_magic
323 def autosave(self, arg_s):
321 def autosave(self, arg_s):
324 """Set the autosave interval in the notebook (in seconds).
322 """Set the autosave interval in the notebook (in seconds).
325
323
326 The default value is 120, or two minutes.
324 The default value is 120, or two minutes.
327 ``%autosave 0`` will disable autosave.
325 ``%autosave 0`` will disable autosave.
328
326
329 This magic only has an effect when called from the notebook interface.
327 This magic only has an effect when called from the notebook interface.
330 It has no effect when called in a startup file.
328 It has no effect when called in a startup file.
331 """
329 """
332
330
333 try:
331 try:
334 interval = int(arg_s)
332 interval = int(arg_s)
335 except ValueError:
333 except ValueError:
336 raise UsageError("%%autosave requires an integer, got %r" % arg_s)
334 raise UsageError("%%autosave requires an integer, got %r" % arg_s)
337
335
338 # javascript wants milliseconds
336 # javascript wants milliseconds
339 milliseconds = 1000 * interval
337 milliseconds = 1000 * interval
340 display(Javascript("IPython.notebook.set_autosave_interval(%i)" % milliseconds),
338 display(Javascript("IPython.notebook.set_autosave_interval(%i)" % milliseconds),
341 include=['application/javascript']
339 include=['application/javascript']
342 )
340 )
343 if interval:
341 if interval:
344 print("Autosaving every %i seconds" % interval)
342 print("Autosaving every %i seconds" % interval)
345 else:
343 else:
346 print("Autosave disabled")
344 print("Autosave disabled")
347
345
348
346
349 class ZMQInteractiveShell(InteractiveShell):
347 class ZMQInteractiveShell(InteractiveShell):
350 """A subclass of InteractiveShell for ZMQ."""
348 """A subclass of InteractiveShell for ZMQ."""
351
349
352 displayhook_class = Type(ZMQShellDisplayHook)
350 displayhook_class = Type(ZMQShellDisplayHook)
353 display_pub_class = Type(ZMQDisplayPublisher)
351 display_pub_class = Type(ZMQDisplayPublisher)
354 data_pub_class = Type(ZMQDataPublisher)
352 data_pub_class = Type(ZMQDataPublisher)
355 kernel = Any()
353 kernel = Any()
356 parent_header = Any()
354 parent_header = Any()
357
355
358 def _banner1_default(self):
356 def _banner1_default(self):
359 return default_gui_banner
357 return default_gui_banner
360
358
361 # Override the traitlet in the parent class, because there's no point using
359 # Override the traitlet in the parent class, because there's no point using
362 # readline for the kernel. Can be removed when the readline code is moved
360 # readline for the kernel. Can be removed when the readline code is moved
363 # to the terminal frontend.
361 # to the terminal frontend.
364 colors_force = CBool(True)
362 colors_force = CBool(True)
365 readline_use = CBool(False)
363 readline_use = CBool(False)
366 # autoindent has no meaning in a zmqshell, and attempting to enable it
364 # autoindent has no meaning in a zmqshell, and attempting to enable it
367 # will print a warning in the absence of readline.
365 # will print a warning in the absence of readline.
368 autoindent = CBool(False)
366 autoindent = CBool(False)
369
367
370 exiter = Instance(ZMQExitAutocall)
368 exiter = Instance(ZMQExitAutocall)
371 def _exiter_default(self):
369 def _exiter_default(self):
372 return ZMQExitAutocall(self)
370 return ZMQExitAutocall(self)
373
371
374 def _exit_now_changed(self, name, old, new):
372 def _exit_now_changed(self, name, old, new):
375 """stop eventloop when exit_now fires"""
373 """stop eventloop when exit_now fires"""
376 if new:
374 if new:
377 loop = ioloop.IOLoop.instance()
375 loop = ioloop.IOLoop.instance()
378 loop.add_timeout(time.time()+0.1, loop.stop)
376 loop.add_timeout(time.time()+0.1, loop.stop)
379
377
380 keepkernel_on_exit = None
378 keepkernel_on_exit = None
381
379
382 # Over ZeroMQ, GUI control isn't done with PyOS_InputHook as there is no
380 # Over ZeroMQ, GUI control isn't done with PyOS_InputHook as there is no
383 # interactive input being read; we provide event loop support in ipkernel
381 # interactive input being read; we provide event loop support in ipkernel
384 @staticmethod
382 @staticmethod
385 def enable_gui(gui):
383 def enable_gui(gui):
386 from .eventloops import enable_gui as real_enable_gui
384 from .eventloops import enable_gui as real_enable_gui
387 try:
385 try:
388 real_enable_gui(gui)
386 real_enable_gui(gui)
389 except ValueError as e:
387 except ValueError as e:
390 raise UsageError("%s" % e)
388 raise UsageError("%s" % e)
391
389
392 def init_environment(self):
390 def init_environment(self):
393 """Configure the user's environment."""
391 """Configure the user's environment."""
394 env = os.environ
392 env = os.environ
395 # These two ensure 'ls' produces nice coloring on BSD-derived systems
393 # These two ensure 'ls' produces nice coloring on BSD-derived systems
396 env['TERM'] = 'xterm-color'
394 env['TERM'] = 'xterm-color'
397 env['CLICOLOR'] = '1'
395 env['CLICOLOR'] = '1'
398 # Since normal pagers don't work at all (over pexpect we don't have
396 # Since normal pagers don't work at all (over pexpect we don't have
399 # single-key control of the subprocess), try to disable paging in
397 # single-key control of the subprocess), try to disable paging in
400 # subprocesses as much as possible.
398 # subprocesses as much as possible.
401 env['PAGER'] = 'cat'
399 env['PAGER'] = 'cat'
402 env['GIT_PAGER'] = 'cat'
400 env['GIT_PAGER'] = 'cat'
403
401
404 def init_hooks(self):
402 def init_hooks(self):
405 super(ZMQInteractiveShell, self).init_hooks()
403 super(ZMQInteractiveShell, self).init_hooks()
406 self.set_hook('show_in_pager', page.as_hook(payloadpage.page), 99)
404 self.set_hook('show_in_pager', page.as_hook(payloadpage.page), 99)
407
405
408 def ask_exit(self):
406 def ask_exit(self):
409 """Engage the exit actions."""
407 """Engage the exit actions."""
410 self.exit_now = (not self.keepkernel_on_exit)
408 self.exit_now = (not self.keepkernel_on_exit)
411 payload = dict(
409 payload = dict(
412 source='ask_exit',
410 source='ask_exit',
413 keepkernel=self.keepkernel_on_exit,
411 keepkernel=self.keepkernel_on_exit,
414 )
412 )
415 self.payload_manager.write_payload(payload)
413 self.payload_manager.write_payload(payload)
416
414
417 def _showtraceback(self, etype, evalue, stb):
415 def _showtraceback(self, etype, evalue, stb):
418 # try to preserve ordering of tracebacks and print statements
416 # try to preserve ordering of tracebacks and print statements
419 sys.stdout.flush()
417 sys.stdout.flush()
420 sys.stderr.flush()
418 sys.stderr.flush()
421
419
422 exc_content = {
420 exc_content = {
423 u'traceback' : stb,
421 u'traceback' : stb,
424 u'ename' : unicode_type(etype.__name__),
422 u'ename' : unicode_type(etype.__name__),
425 u'evalue' : py3compat.safe_unicode(evalue),
423 u'evalue' : py3compat.safe_unicode(evalue),
426 }
424 }
427
425
428 dh = self.displayhook
426 dh = self.displayhook
429 # Send exception info over pub socket for other clients than the caller
427 # Send exception info over pub socket for other clients than the caller
430 # to pick up
428 # to pick up
431 topic = None
429 topic = None
432 if dh.topic:
430 if dh.topic:
433 topic = dh.topic.replace(b'execute_result', b'error')
431 topic = dh.topic.replace(b'execute_result', b'error')
434
432
435 exc_msg = dh.session.send(dh.pub_socket, u'error', json_clean(exc_content), dh.parent_header, ident=topic)
433 exc_msg = dh.session.send(dh.pub_socket, u'error', json_clean(exc_content), dh.parent_header, ident=topic)
436
434
437 # FIXME - Hack: store exception info in shell object. Right now, the
435 # FIXME - Hack: store exception info in shell object. Right now, the
438 # caller is reading this info after the fact, we need to fix this logic
436 # caller is reading this info after the fact, we need to fix this logic
439 # to remove this hack. Even uglier, we need to store the error status
437 # to remove this hack. Even uglier, we need to store the error status
440 # here, because in the main loop, the logic that sets it is being
438 # here, because in the main loop, the logic that sets it is being
441 # skipped because runlines swallows the exceptions.
439 # skipped because runlines swallows the exceptions.
442 exc_content[u'status'] = u'error'
440 exc_content[u'status'] = u'error'
443 self._reply_content = exc_content
441 self._reply_content = exc_content
444 # /FIXME
442 # /FIXME
445
443
446 return exc_content
444 return exc_content
447
445
448 def set_next_input(self, text, replace=False):
446 def set_next_input(self, text, replace=False):
449 """Send the specified text to the frontend to be presented at the next
447 """Send the specified text to the frontend to be presented at the next
450 input cell."""
448 input cell."""
451 payload = dict(
449 payload = dict(
452 source='set_next_input',
450 source='set_next_input',
453 text=text,
451 text=text,
454 replace=replace,
452 replace=replace,
455 )
453 )
456 self.payload_manager.write_payload(payload)
454 self.payload_manager.write_payload(payload)
457
455
458 def set_parent(self, parent):
456 def set_parent(self, parent):
459 """Set the parent header for associating output with its triggering input"""
457 """Set the parent header for associating output with its triggering input"""
460 self.parent_header = parent
458 self.parent_header = parent
461 self.displayhook.set_parent(parent)
459 self.displayhook.set_parent(parent)
462 self.display_pub.set_parent(parent)
460 self.display_pub.set_parent(parent)
463 self.data_pub.set_parent(parent)
461 self.data_pub.set_parent(parent)
464 try:
462 try:
465 sys.stdout.set_parent(parent)
463 sys.stdout.set_parent(parent)
466 except AttributeError:
464 except AttributeError:
467 pass
465 pass
468 try:
466 try:
469 sys.stderr.set_parent(parent)
467 sys.stderr.set_parent(parent)
470 except AttributeError:
468 except AttributeError:
471 pass
469 pass
472
470
473 def get_parent(self):
471 def get_parent(self):
474 return self.parent_header
472 return self.parent_header
475
473
476 #-------------------------------------------------------------------------
474 #-------------------------------------------------------------------------
477 # Things related to magics
475 # Things related to magics
478 #-------------------------------------------------------------------------
476 #-------------------------------------------------------------------------
479
477
480 def init_magics(self):
478 def init_magics(self):
481 super(ZMQInteractiveShell, self).init_magics()
479 super(ZMQInteractiveShell, self).init_magics()
482 self.register_magics(KernelMagics)
480 self.register_magics(KernelMagics)
483 self.magics_manager.register_alias('ed', 'edit')
481 self.magics_manager.register_alias('ed', 'edit')
484
482
485
483
486 InteractiveShellABC.register(ZMQInteractiveShell)
484 InteractiveShellABC.register(ZMQInteractiveShell)
@@ -1,441 +1,436 b''
1 # encoding: utf-8
1 # encoding: utf-8
2 """
2 """
3 =============
3 =============
4 parallelmagic
4 parallelmagic
5 =============
5 =============
6
6
7 Magic command interface for interactive parallel work.
7 Magic command interface for interactive parallel work.
8
8
9 Usage
9 Usage
10 =====
10 =====
11
11
12 ``%autopx``
12 ``%autopx``
13
13
14 {AUTOPX_DOC}
14 {AUTOPX_DOC}
15
15
16 ``%px``
16 ``%px``
17
17
18 {PX_DOC}
18 {PX_DOC}
19
19
20 ``%pxresult``
20 ``%pxresult``
21
21
22 {RESULT_DOC}
22 {RESULT_DOC}
23
23
24 ``%pxconfig``
24 ``%pxconfig``
25
25
26 {CONFIG_DOC}
26 {CONFIG_DOC}
27
27
28 """
28 """
29 from __future__ import print_function
29 from __future__ import print_function
30
30
31 #-----------------------------------------------------------------------------
31 #-----------------------------------------------------------------------------
32 # Copyright (C) 2008 The IPython Development Team
32 # Copyright (C) 2008 The IPython Development Team
33 #
33 #
34 # Distributed under the terms of the BSD License. The full license is in
34 # Distributed under the terms of the BSD License. The full license is in
35 # the file COPYING, distributed as part of this software.
35 # the file COPYING, distributed as part of this software.
36 #-----------------------------------------------------------------------------
36 #-----------------------------------------------------------------------------
37
37
38 #-----------------------------------------------------------------------------
38 #-----------------------------------------------------------------------------
39 # Imports
39 # Imports
40 #-----------------------------------------------------------------------------
40 #-----------------------------------------------------------------------------
41
41
42 import ast
42 import ast
43 import re
43 import re
44
44
45 from IPython.core.error import UsageError
45 from IPython.core.error import UsageError
46 from IPython.core.magic import Magics
46 from IPython.core.magic import Magics
47 from IPython.core import magic_arguments
47 from IPython.core import magic_arguments
48 from IPython.testing.skipdoctest import skip_doctest
49 from IPython.utils.text import dedent
48 from IPython.utils.text import dedent
50
49
51 #-----------------------------------------------------------------------------
50 #-----------------------------------------------------------------------------
52 # Definitions of magic functions for use with IPython
51 # Definitions of magic functions for use with IPython
53 #-----------------------------------------------------------------------------
52 #-----------------------------------------------------------------------------
54
53
55
54
56 NO_LAST_RESULT = "%pxresult recalls last %px result, which has not yet been used."
55 NO_LAST_RESULT = "%pxresult recalls last %px result, which has not yet been used."
57
56
58 def exec_args(f):
57 def exec_args(f):
59 """decorator for adding block/targets args for execution
58 """decorator for adding block/targets args for execution
60
59
61 applied to %pxconfig and %%px
60 applied to %pxconfig and %%px
62 """
61 """
63 args = [
62 args = [
64 magic_arguments.argument('-b', '--block', action="store_const",
63 magic_arguments.argument('-b', '--block', action="store_const",
65 const=True, dest='block',
64 const=True, dest='block',
66 help="use blocking (sync) execution",
65 help="use blocking (sync) execution",
67 ),
66 ),
68 magic_arguments.argument('-a', '--noblock', action="store_const",
67 magic_arguments.argument('-a', '--noblock', action="store_const",
69 const=False, dest='block',
68 const=False, dest='block',
70 help="use non-blocking (async) execution",
69 help="use non-blocking (async) execution",
71 ),
70 ),
72 magic_arguments.argument('-t', '--targets', type=str,
71 magic_arguments.argument('-t', '--targets', type=str,
73 help="specify the targets on which to execute",
72 help="specify the targets on which to execute",
74 ),
73 ),
75 magic_arguments.argument('--local', action="store_const",
74 magic_arguments.argument('--local', action="store_const",
76 const=True, dest="local",
75 const=True, dest="local",
77 help="also execute the cell in the local namespace",
76 help="also execute the cell in the local namespace",
78 ),
77 ),
79 magic_arguments.argument('--verbose', action="store_const",
78 magic_arguments.argument('--verbose', action="store_const",
80 const=True, dest="set_verbose",
79 const=True, dest="set_verbose",
81 help="print a message at each execution",
80 help="print a message at each execution",
82 ),
81 ),
83 magic_arguments.argument('--no-verbose', action="store_const",
82 magic_arguments.argument('--no-verbose', action="store_const",
84 const=False, dest="set_verbose",
83 const=False, dest="set_verbose",
85 help="don't print any messages",
84 help="don't print any messages",
86 ),
85 ),
87 ]
86 ]
88 for a in args:
87 for a in args:
89 f = a(f)
88 f = a(f)
90 return f
89 return f
91
90
92 def output_args(f):
91 def output_args(f):
93 """decorator for output-formatting args
92 """decorator for output-formatting args
94
93
95 applied to %pxresult and %%px
94 applied to %pxresult and %%px
96 """
95 """
97 args = [
96 args = [
98 magic_arguments.argument('-r', action="store_const", dest='groupby',
97 magic_arguments.argument('-r', action="store_const", dest='groupby',
99 const='order',
98 const='order',
100 help="collate outputs in order (same as group-outputs=order)"
99 help="collate outputs in order (same as group-outputs=order)"
101 ),
100 ),
102 magic_arguments.argument('-e', action="store_const", dest='groupby',
101 magic_arguments.argument('-e', action="store_const", dest='groupby',
103 const='engine',
102 const='engine',
104 help="group outputs by engine (same as group-outputs=engine)"
103 help="group outputs by engine (same as group-outputs=engine)"
105 ),
104 ),
106 magic_arguments.argument('--group-outputs', dest='groupby', type=str,
105 magic_arguments.argument('--group-outputs', dest='groupby', type=str,
107 choices=['engine', 'order', 'type'], default='type',
106 choices=['engine', 'order', 'type'], default='type',
108 help="""Group the outputs in a particular way.
107 help="""Group the outputs in a particular way.
109
108
110 Choices are:
109 Choices are:
111
110
112 **type**: group outputs of all engines by type (stdout, stderr, displaypub, etc.).
111 **type**: group outputs of all engines by type (stdout, stderr, displaypub, etc.).
113 **engine**: display all output for each engine together.
112 **engine**: display all output for each engine together.
114 **order**: like type, but individual displaypub output from each engine is collated.
113 **order**: like type, but individual displaypub output from each engine is collated.
115 For example, if multiple plots are generated by each engine, the first
114 For example, if multiple plots are generated by each engine, the first
116 figure of each engine will be displayed, then the second of each, etc.
115 figure of each engine will be displayed, then the second of each, etc.
117 """
116 """
118 ),
117 ),
119 magic_arguments.argument('-o', '--out', dest='save_name', type=str,
118 magic_arguments.argument('-o', '--out', dest='save_name', type=str,
120 help="""store the AsyncResult object for this computation
119 help="""store the AsyncResult object for this computation
121 in the global namespace under this name.
120 in the global namespace under this name.
122 """
121 """
123 ),
122 ),
124 ]
123 ]
125 for a in args:
124 for a in args:
126 f = a(f)
125 f = a(f)
127 return f
126 return f
128
127
129 class ParallelMagics(Magics):
128 class ParallelMagics(Magics):
130 """A set of magics useful when controlling a parallel IPython cluster.
129 """A set of magics useful when controlling a parallel IPython cluster.
131 """
130 """
132
131
133 # magic-related
132 # magic-related
134 magics = None
133 magics = None
135 registered = True
134 registered = True
136
135
137 # suffix for magics
136 # suffix for magics
138 suffix = ''
137 suffix = ''
139 # A flag showing if autopx is activated or not
138 # A flag showing if autopx is activated or not
140 _autopx = False
139 _autopx = False
141 # the current view used by the magics:
140 # the current view used by the magics:
142 view = None
141 view = None
143 # last result cache for %pxresult
142 # last result cache for %pxresult
144 last_result = None
143 last_result = None
145 # verbose flag
144 # verbose flag
146 verbose = False
145 verbose = False
147
146
148 def __init__(self, shell, view, suffix=''):
147 def __init__(self, shell, view, suffix=''):
149 self.view = view
148 self.view = view
150 self.suffix = suffix
149 self.suffix = suffix
151
150
152 # register magics
151 # register magics
153 self.magics = dict(cell={},line={})
152 self.magics = dict(cell={},line={})
154 line_magics = self.magics['line']
153 line_magics = self.magics['line']
155
154
156 px = 'px' + suffix
155 px = 'px' + suffix
157 if not suffix:
156 if not suffix:
158 # keep %result for legacy compatibility
157 # keep %result for legacy compatibility
159 line_magics['result'] = self.result
158 line_magics['result'] = self.result
160
159
161 line_magics['pxresult' + suffix] = self.result
160 line_magics['pxresult' + suffix] = self.result
162 line_magics[px] = self.px
161 line_magics[px] = self.px
163 line_magics['pxconfig' + suffix] = self.pxconfig
162 line_magics['pxconfig' + suffix] = self.pxconfig
164 line_magics['auto' + px] = self.autopx
163 line_magics['auto' + px] = self.autopx
165
164
166 self.magics['cell'][px] = self.cell_px
165 self.magics['cell'][px] = self.cell_px
167
166
168 super(ParallelMagics, self).__init__(shell=shell)
167 super(ParallelMagics, self).__init__(shell=shell)
169
168
170 def _eval_target_str(self, ts):
169 def _eval_target_str(self, ts):
171 if ':' in ts:
170 if ':' in ts:
172 targets = eval("self.view.client.ids[%s]" % ts)
171 targets = eval("self.view.client.ids[%s]" % ts)
173 elif 'all' in ts:
172 elif 'all' in ts:
174 targets = 'all'
173 targets = 'all'
175 else:
174 else:
176 targets = eval(ts)
175 targets = eval(ts)
177 return targets
176 return targets
178
177
179 @magic_arguments.magic_arguments()
178 @magic_arguments.magic_arguments()
180 @exec_args
179 @exec_args
181 def pxconfig(self, line):
180 def pxconfig(self, line):
182 """configure default targets/blocking for %px magics"""
181 """configure default targets/blocking for %px magics"""
183 args = magic_arguments.parse_argstring(self.pxconfig, line)
182 args = magic_arguments.parse_argstring(self.pxconfig, line)
184 if args.targets:
183 if args.targets:
185 self.view.targets = self._eval_target_str(args.targets)
184 self.view.targets = self._eval_target_str(args.targets)
186 if args.block is not None:
185 if args.block is not None:
187 self.view.block = args.block
186 self.view.block = args.block
188 if args.set_verbose is not None:
187 if args.set_verbose is not None:
189 self.verbose = args.set_verbose
188 self.verbose = args.set_verbose
190
189
191 @magic_arguments.magic_arguments()
190 @magic_arguments.magic_arguments()
192 @output_args
191 @output_args
193 @skip_doctest
194 def result(self, line=''):
192 def result(self, line=''):
195 """Print the result of the last asynchronous %px command.
193 """Print the result of the last asynchronous %px command.
196
194
197 This lets you recall the results of %px computations after
195 This lets you recall the results of %px computations after
198 asynchronous submission (block=False).
196 asynchronous submission (block=False).
199
197
200 Examples
198 Examples
201 --------
199 --------
202 ::
200 ::
203
201
204 In [23]: %px os.getpid()
202 In [23]: %px os.getpid()
205 Async parallel execution on engine(s): all
203 Async parallel execution on engine(s): all
206
204
207 In [24]: %pxresult
205 In [24]: %pxresult
208 Out[8:10]: 60920
206 Out[8:10]: 60920
209 Out[9:10]: 60921
207 Out[9:10]: 60921
210 Out[10:10]: 60922
208 Out[10:10]: 60922
211 Out[11:10]: 60923
209 Out[11:10]: 60923
212 """
210 """
213 args = magic_arguments.parse_argstring(self.result, line)
211 args = magic_arguments.parse_argstring(self.result, line)
214
212
215 if self.last_result is None:
213 if self.last_result is None:
216 raise UsageError(NO_LAST_RESULT)
214 raise UsageError(NO_LAST_RESULT)
217
215
218 self.last_result.get()
216 self.last_result.get()
219 self.last_result.display_outputs(groupby=args.groupby)
217 self.last_result.display_outputs(groupby=args.groupby)
220
218
221 @skip_doctest
222 def px(self, line=''):
219 def px(self, line=''):
223 """Executes the given python command in parallel.
220 """Executes the given python command in parallel.
224
221
225 Examples
222 Examples
226 --------
223 --------
227 ::
224 ::
228
225
229 In [24]: %px a = os.getpid()
226 In [24]: %px a = os.getpid()
230 Parallel execution on engine(s): all
227 Parallel execution on engine(s): all
231
228
232 In [25]: %px print a
229 In [25]: %px print a
233 [stdout:0] 1234
230 [stdout:0] 1234
234 [stdout:1] 1235
231 [stdout:1] 1235
235 [stdout:2] 1236
232 [stdout:2] 1236
236 [stdout:3] 1237
233 [stdout:3] 1237
237 """
234 """
238 return self.parallel_execute(line)
235 return self.parallel_execute(line)
239
236
240 def parallel_execute(self, cell, block=None, groupby='type', save_name=None):
237 def parallel_execute(self, cell, block=None, groupby='type', save_name=None):
241 """implementation used by %px and %%parallel"""
238 """implementation used by %px and %%parallel"""
242
239
243 # defaults:
240 # defaults:
244 block = self.view.block if block is None else block
241 block = self.view.block if block is None else block
245
242
246 base = "Parallel" if block else "Async parallel"
243 base = "Parallel" if block else "Async parallel"
247
244
248 targets = self.view.targets
245 targets = self.view.targets
249 if isinstance(targets, list) and len(targets) > 10:
246 if isinstance(targets, list) and len(targets) > 10:
250 str_targets = str(targets[:4])[:-1] + ', ..., ' + str(targets[-4:])[1:]
247 str_targets = str(targets[:4])[:-1] + ', ..., ' + str(targets[-4:])[1:]
251 else:
248 else:
252 str_targets = str(targets)
249 str_targets = str(targets)
253 if self.verbose:
250 if self.verbose:
254 print(base + " execution on engine(s): %s" % str_targets)
251 print(base + " execution on engine(s): %s" % str_targets)
255
252
256 result = self.view.execute(cell, silent=False, block=False)
253 result = self.view.execute(cell, silent=False, block=False)
257 self.last_result = result
254 self.last_result = result
258
255
259 if save_name:
256 if save_name:
260 self.shell.user_ns[save_name] = result
257 self.shell.user_ns[save_name] = result
261
258
262 if block:
259 if block:
263 result.get()
260 result.get()
264 result.display_outputs(groupby)
261 result.display_outputs(groupby)
265 else:
262 else:
266 # return AsyncResult only on non-blocking submission
263 # return AsyncResult only on non-blocking submission
267 return result
264 return result
268
265
269 @magic_arguments.magic_arguments()
266 @magic_arguments.magic_arguments()
270 @exec_args
267 @exec_args
271 @output_args
268 @output_args
272 @skip_doctest
273 def cell_px(self, line='', cell=None):
269 def cell_px(self, line='', cell=None):
274 """Executes the cell in parallel.
270 """Executes the cell in parallel.
275
271
276 Examples
272 Examples
277 --------
273 --------
278 ::
274 ::
279
275
280 In [24]: %%px --noblock
276 In [24]: %%px --noblock
281 ....: a = os.getpid()
277 ....: a = os.getpid()
282 Async parallel execution on engine(s): all
278 Async parallel execution on engine(s): all
283
279
284 In [25]: %%px
280 In [25]: %%px
285 ....: print a
281 ....: print a
286 [stdout:0] 1234
282 [stdout:0] 1234
287 [stdout:1] 1235
283 [stdout:1] 1235
288 [stdout:2] 1236
284 [stdout:2] 1236
289 [stdout:3] 1237
285 [stdout:3] 1237
290 """
286 """
291
287
292 args = magic_arguments.parse_argstring(self.cell_px, line)
288 args = magic_arguments.parse_argstring(self.cell_px, line)
293
289
294 if args.targets:
290 if args.targets:
295 save_targets = self.view.targets
291 save_targets = self.view.targets
296 self.view.targets = self._eval_target_str(args.targets)
292 self.view.targets = self._eval_target_str(args.targets)
297 # if running local, don't block until after local has run
293 # if running local, don't block until after local has run
298 block = False if args.local else args.block
294 block = False if args.local else args.block
299 try:
295 try:
300 ar = self.parallel_execute(cell, block=block,
296 ar = self.parallel_execute(cell, block=block,
301 groupby=args.groupby,
297 groupby=args.groupby,
302 save_name=args.save_name,
298 save_name=args.save_name,
303 )
299 )
304 finally:
300 finally:
305 if args.targets:
301 if args.targets:
306 self.view.targets = save_targets
302 self.view.targets = save_targets
307
303
308 # run locally after submitting remote
304 # run locally after submitting remote
309 block = self.view.block if args.block is None else args.block
305 block = self.view.block if args.block is None else args.block
310 if args.local:
306 if args.local:
311 self.shell.run_cell(cell)
307 self.shell.run_cell(cell)
312 # now apply blocking behavor to remote execution
308 # now apply blocking behavor to remote execution
313 if block:
309 if block:
314 ar.get()
310 ar.get()
315 ar.display_outputs(args.groupby)
311 ar.display_outputs(args.groupby)
316 if not block:
312 if not block:
317 return ar
313 return ar
318
314
319 @skip_doctest
320 def autopx(self, line=''):
315 def autopx(self, line=''):
321 """Toggles auto parallel mode.
316 """Toggles auto parallel mode.
322
317
323 Once this is called, all commands typed at the command line are send to
318 Once this is called, all commands typed at the command line are send to
324 the engines to be executed in parallel. To control which engine are
319 the engines to be executed in parallel. To control which engine are
325 used, the ``targets`` attribute of the view before
320 used, the ``targets`` attribute of the view before
326 entering ``%autopx`` mode.
321 entering ``%autopx`` mode.
327
322
328
323
329 Then you can do the following::
324 Then you can do the following::
330
325
331 In [25]: %autopx
326 In [25]: %autopx
332 %autopx to enabled
327 %autopx to enabled
333
328
334 In [26]: a = 10
329 In [26]: a = 10
335 Parallel execution on engine(s): [0,1,2,3]
330 Parallel execution on engine(s): [0,1,2,3]
336 In [27]: print a
331 In [27]: print a
337 Parallel execution on engine(s): [0,1,2,3]
332 Parallel execution on engine(s): [0,1,2,3]
338 [stdout:0] 10
333 [stdout:0] 10
339 [stdout:1] 10
334 [stdout:1] 10
340 [stdout:2] 10
335 [stdout:2] 10
341 [stdout:3] 10
336 [stdout:3] 10
342
337
343
338
344 In [27]: %autopx
339 In [27]: %autopx
345 %autopx disabled
340 %autopx disabled
346 """
341 """
347 if self._autopx:
342 if self._autopx:
348 self._disable_autopx()
343 self._disable_autopx()
349 else:
344 else:
350 self._enable_autopx()
345 self._enable_autopx()
351
346
352 def _enable_autopx(self):
347 def _enable_autopx(self):
353 """Enable %autopx mode by saving the original run_cell and installing
348 """Enable %autopx mode by saving the original run_cell and installing
354 pxrun_cell.
349 pxrun_cell.
355 """
350 """
356 # override run_cell
351 # override run_cell
357 self._original_run_cell = self.shell.run_cell
352 self._original_run_cell = self.shell.run_cell
358 self.shell.run_cell = self.pxrun_cell
353 self.shell.run_cell = self.pxrun_cell
359
354
360 self._autopx = True
355 self._autopx = True
361 print("%autopx enabled")
356 print("%autopx enabled")
362
357
363 def _disable_autopx(self):
358 def _disable_autopx(self):
364 """Disable %autopx by restoring the original InteractiveShell.run_cell.
359 """Disable %autopx by restoring the original InteractiveShell.run_cell.
365 """
360 """
366 if self._autopx:
361 if self._autopx:
367 self.shell.run_cell = self._original_run_cell
362 self.shell.run_cell = self._original_run_cell
368 self._autopx = False
363 self._autopx = False
369 print("%autopx disabled")
364 print("%autopx disabled")
370
365
371 def pxrun_cell(self, raw_cell, store_history=False, silent=False):
366 def pxrun_cell(self, raw_cell, store_history=False, silent=False):
372 """drop-in replacement for InteractiveShell.run_cell.
367 """drop-in replacement for InteractiveShell.run_cell.
373
368
374 This executes code remotely, instead of in the local namespace.
369 This executes code remotely, instead of in the local namespace.
375
370
376 See InteractiveShell.run_cell for details.
371 See InteractiveShell.run_cell for details.
377 """
372 """
378
373
379 if (not raw_cell) or raw_cell.isspace():
374 if (not raw_cell) or raw_cell.isspace():
380 return
375 return
381
376
382 ipself = self.shell
377 ipself = self.shell
383
378
384 with ipself.builtin_trap:
379 with ipself.builtin_trap:
385 cell = ipself.prefilter_manager.prefilter_lines(raw_cell)
380 cell = ipself.prefilter_manager.prefilter_lines(raw_cell)
386
381
387 # Store raw and processed history
382 # Store raw and processed history
388 if store_history:
383 if store_history:
389 ipself.history_manager.store_inputs(ipself.execution_count,
384 ipself.history_manager.store_inputs(ipself.execution_count,
390 cell, raw_cell)
385 cell, raw_cell)
391
386
392 # ipself.logger.log(cell, raw_cell)
387 # ipself.logger.log(cell, raw_cell)
393
388
394 cell_name = ipself.compile.cache(cell, ipself.execution_count)
389 cell_name = ipself.compile.cache(cell, ipself.execution_count)
395
390
396 try:
391 try:
397 ast.parse(cell, filename=cell_name)
392 ast.parse(cell, filename=cell_name)
398 except (OverflowError, SyntaxError, ValueError, TypeError,
393 except (OverflowError, SyntaxError, ValueError, TypeError,
399 MemoryError):
394 MemoryError):
400 # Case 1
395 # Case 1
401 ipself.showsyntaxerror()
396 ipself.showsyntaxerror()
402 ipself.execution_count += 1
397 ipself.execution_count += 1
403 return None
398 return None
404 except NameError:
399 except NameError:
405 # ignore name errors, because we don't know the remote keys
400 # ignore name errors, because we don't know the remote keys
406 pass
401 pass
407
402
408 if store_history:
403 if store_history:
409 # Write output to the database. Does nothing unless
404 # Write output to the database. Does nothing unless
410 # history output logging is enabled.
405 # history output logging is enabled.
411 ipself.history_manager.store_output(ipself.execution_count)
406 ipself.history_manager.store_output(ipself.execution_count)
412 # Each cell is a *single* input, regardless of how many lines it has
407 # Each cell is a *single* input, regardless of how many lines it has
413 ipself.execution_count += 1
408 ipself.execution_count += 1
414 if re.search(r'get_ipython\(\)\.magic\(u?["\']%?autopx', cell):
409 if re.search(r'get_ipython\(\)\.magic\(u?["\']%?autopx', cell):
415 self._disable_autopx()
410 self._disable_autopx()
416 return False
411 return False
417 else:
412 else:
418 try:
413 try:
419 result = self.view.execute(cell, silent=False, block=False)
414 result = self.view.execute(cell, silent=False, block=False)
420 except:
415 except:
421 ipself.showtraceback()
416 ipself.showtraceback()
422 return True
417 return True
423 else:
418 else:
424 if self.view.block:
419 if self.view.block:
425 try:
420 try:
426 result.get()
421 result.get()
427 except:
422 except:
428 self.shell.showtraceback()
423 self.shell.showtraceback()
429 return True
424 return True
430 else:
425 else:
431 with ipself.builtin_trap:
426 with ipself.builtin_trap:
432 result.display_outputs()
427 result.display_outputs()
433 return False
428 return False
434
429
435
430
436 __doc__ = __doc__.format(
431 __doc__ = __doc__.format(
437 AUTOPX_DOC = dedent(ParallelMagics.autopx.__doc__),
432 AUTOPX_DOC = dedent(ParallelMagics.autopx.__doc__),
438 PX_DOC = dedent(ParallelMagics.px.__doc__),
433 PX_DOC = dedent(ParallelMagics.px.__doc__),
439 RESULT_DOC = dedent(ParallelMagics.result.__doc__),
434 RESULT_DOC = dedent(ParallelMagics.result.__doc__),
440 CONFIG_DOC = dedent(ParallelMagics.pxconfig.__doc__),
435 CONFIG_DOC = dedent(ParallelMagics.pxconfig.__doc__),
441 )
436 )
@@ -1,276 +1,273 b''
1 """Remote Functions and decorators for Views."""
1 """Remote Functions and decorators for Views."""
2
2
3 # Copyright (c) IPython Development Team.
3 # Copyright (c) IPython Development Team.
4 # Distributed under the terms of the Modified BSD License.
4 # Distributed under the terms of the Modified BSD License.
5
5
6 from __future__ import division
6 from __future__ import division
7
7
8 import sys
8 import sys
9 import warnings
9 import warnings
10
10
11 from decorator import decorator
11 from decorator import decorator
12 from IPython.testing.skipdoctest import skip_doctest
13
12
14 from . import map as Map
13 from . import map as Map
15 from .asyncresult import AsyncMapResult
14 from .asyncresult import AsyncMapResult
16
15
17 #-----------------------------------------------------------------------------
16 #-----------------------------------------------------------------------------
18 # Functions and Decorators
17 # Functions and Decorators
19 #-----------------------------------------------------------------------------
18 #-----------------------------------------------------------------------------
20
19
21 @skip_doctest
22 def remote(view, block=None, **flags):
20 def remote(view, block=None, **flags):
23 """Turn a function into a remote function.
21 """Turn a function into a remote function.
24
22
25 This method can be used for map:
23 This method can be used for map:
26
24
27 In [1]: @remote(view,block=True)
25 In [1]: @remote(view,block=True)
28 ...: def func(a):
26 ...: def func(a):
29 ...: pass
27 ...: pass
30 """
28 """
31
29
32 def remote_function(f):
30 def remote_function(f):
33 return RemoteFunction(view, f, block=block, **flags)
31 return RemoteFunction(view, f, block=block, **flags)
34 return remote_function
32 return remote_function
35
33
36 @skip_doctest
37 def parallel(view, dist='b', block=None, ordered=True, **flags):
34 def parallel(view, dist='b', block=None, ordered=True, **flags):
38 """Turn a function into a parallel remote function.
35 """Turn a function into a parallel remote function.
39
36
40 This method can be used for map:
37 This method can be used for map:
41
38
42 In [1]: @parallel(view, block=True)
39 In [1]: @parallel(view, block=True)
43 ...: def func(a):
40 ...: def func(a):
44 ...: pass
41 ...: pass
45 """
42 """
46
43
47 def parallel_function(f):
44 def parallel_function(f):
48 return ParallelFunction(view, f, dist=dist, block=block, ordered=ordered, **flags)
45 return ParallelFunction(view, f, dist=dist, block=block, ordered=ordered, **flags)
49 return parallel_function
46 return parallel_function
50
47
51 def getname(f):
48 def getname(f):
52 """Get the name of an object.
49 """Get the name of an object.
53
50
54 For use in case of callables that are not functions, and
51 For use in case of callables that are not functions, and
55 thus may not have __name__ defined.
52 thus may not have __name__ defined.
56
53
57 Order: f.__name__ > f.name > str(f)
54 Order: f.__name__ > f.name > str(f)
58 """
55 """
59 try:
56 try:
60 return f.__name__
57 return f.__name__
61 except:
58 except:
62 pass
59 pass
63 try:
60 try:
64 return f.name
61 return f.name
65 except:
62 except:
66 pass
63 pass
67
64
68 return str(f)
65 return str(f)
69
66
70 @decorator
67 @decorator
71 def sync_view_results(f, self, *args, **kwargs):
68 def sync_view_results(f, self, *args, **kwargs):
72 """sync relevant results from self.client to our results attribute.
69 """sync relevant results from self.client to our results attribute.
73
70
74 This is a clone of view.sync_results, but for remote functions
71 This is a clone of view.sync_results, but for remote functions
75 """
72 """
76 view = self.view
73 view = self.view
77 if view._in_sync_results:
74 if view._in_sync_results:
78 return f(self, *args, **kwargs)
75 return f(self, *args, **kwargs)
79 view._in_sync_results = True
76 view._in_sync_results = True
80 try:
77 try:
81 ret = f(self, *args, **kwargs)
78 ret = f(self, *args, **kwargs)
82 finally:
79 finally:
83 view._in_sync_results = False
80 view._in_sync_results = False
84 view._sync_results()
81 view._sync_results()
85 return ret
82 return ret
86
83
87 #--------------------------------------------------------------------------
84 #--------------------------------------------------------------------------
88 # Classes
85 # Classes
89 #--------------------------------------------------------------------------
86 #--------------------------------------------------------------------------
90
87
91 class RemoteFunction(object):
88 class RemoteFunction(object):
92 """Turn an existing function into a remote function.
89 """Turn an existing function into a remote function.
93
90
94 Parameters
91 Parameters
95 ----------
92 ----------
96
93
97 view : View instance
94 view : View instance
98 The view to be used for execution
95 The view to be used for execution
99 f : callable
96 f : callable
100 The function to be wrapped into a remote function
97 The function to be wrapped into a remote function
101 block : bool [default: None]
98 block : bool [default: None]
102 Whether to wait for results or not. The default behavior is
99 Whether to wait for results or not. The default behavior is
103 to use the current `block` attribute of `view`
100 to use the current `block` attribute of `view`
104
101
105 **flags : remaining kwargs are passed to View.temp_flags
102 **flags : remaining kwargs are passed to View.temp_flags
106 """
103 """
107
104
108 view = None # the remote connection
105 view = None # the remote connection
109 func = None # the wrapped function
106 func = None # the wrapped function
110 block = None # whether to block
107 block = None # whether to block
111 flags = None # dict of extra kwargs for temp_flags
108 flags = None # dict of extra kwargs for temp_flags
112
109
113 def __init__(self, view, f, block=None, **flags):
110 def __init__(self, view, f, block=None, **flags):
114 self.view = view
111 self.view = view
115 self.func = f
112 self.func = f
116 self.block=block
113 self.block=block
117 self.flags=flags
114 self.flags=flags
118
115
119 def __call__(self, *args, **kwargs):
116 def __call__(self, *args, **kwargs):
120 block = self.view.block if self.block is None else self.block
117 block = self.view.block if self.block is None else self.block
121 with self.view.temp_flags(block=block, **self.flags):
118 with self.view.temp_flags(block=block, **self.flags):
122 return self.view.apply(self.func, *args, **kwargs)
119 return self.view.apply(self.func, *args, **kwargs)
123
120
124
121
125 class ParallelFunction(RemoteFunction):
122 class ParallelFunction(RemoteFunction):
126 """Class for mapping a function to sequences.
123 """Class for mapping a function to sequences.
127
124
128 This will distribute the sequences according the a mapper, and call
125 This will distribute the sequences according the a mapper, and call
129 the function on each sub-sequence. If called via map, then the function
126 the function on each sub-sequence. If called via map, then the function
130 will be called once on each element, rather that each sub-sequence.
127 will be called once on each element, rather that each sub-sequence.
131
128
132 Parameters
129 Parameters
133 ----------
130 ----------
134
131
135 view : View instance
132 view : View instance
136 The view to be used for execution
133 The view to be used for execution
137 f : callable
134 f : callable
138 The function to be wrapped into a remote function
135 The function to be wrapped into a remote function
139 dist : str [default: 'b']
136 dist : str [default: 'b']
140 The key for which mapObject to use to distribute sequences
137 The key for which mapObject to use to distribute sequences
141 options are:
138 options are:
142
139
143 * 'b' : use contiguous chunks in order
140 * 'b' : use contiguous chunks in order
144 * 'r' : use round-robin striping
141 * 'r' : use round-robin striping
145
142
146 block : bool [default: None]
143 block : bool [default: None]
147 Whether to wait for results or not. The default behavior is
144 Whether to wait for results or not. The default behavior is
148 to use the current `block` attribute of `view`
145 to use the current `block` attribute of `view`
149 chunksize : int or None
146 chunksize : int or None
150 The size of chunk to use when breaking up sequences in a load-balanced manner
147 The size of chunk to use when breaking up sequences in a load-balanced manner
151 ordered : bool [default: True]
148 ordered : bool [default: True]
152 Whether the result should be kept in order. If False,
149 Whether the result should be kept in order. If False,
153 results become available as they arrive, regardless of submission order.
150 results become available as they arrive, regardless of submission order.
154 **flags
151 **flags
155 remaining kwargs are passed to View.temp_flags
152 remaining kwargs are passed to View.temp_flags
156 """
153 """
157
154
158 chunksize = None
155 chunksize = None
159 ordered = None
156 ordered = None
160 mapObject = None
157 mapObject = None
161 _mapping = False
158 _mapping = False
162
159
163 def __init__(self, view, f, dist='b', block=None, chunksize=None, ordered=True, **flags):
160 def __init__(self, view, f, dist='b', block=None, chunksize=None, ordered=True, **flags):
164 super(ParallelFunction, self).__init__(view, f, block=block, **flags)
161 super(ParallelFunction, self).__init__(view, f, block=block, **flags)
165 self.chunksize = chunksize
162 self.chunksize = chunksize
166 self.ordered = ordered
163 self.ordered = ordered
167
164
168 mapClass = Map.dists[dist]
165 mapClass = Map.dists[dist]
169 self.mapObject = mapClass()
166 self.mapObject = mapClass()
170
167
171 @sync_view_results
168 @sync_view_results
172 def __call__(self, *sequences):
169 def __call__(self, *sequences):
173 client = self.view.client
170 client = self.view.client
174
171
175 lens = []
172 lens = []
176 maxlen = minlen = -1
173 maxlen = minlen = -1
177 for i, seq in enumerate(sequences):
174 for i, seq in enumerate(sequences):
178 try:
175 try:
179 n = len(seq)
176 n = len(seq)
180 except Exception:
177 except Exception:
181 seq = list(seq)
178 seq = list(seq)
182 if isinstance(sequences, tuple):
179 if isinstance(sequences, tuple):
183 # can't alter a tuple
180 # can't alter a tuple
184 sequences = list(sequences)
181 sequences = list(sequences)
185 sequences[i] = seq
182 sequences[i] = seq
186 n = len(seq)
183 n = len(seq)
187 if n > maxlen:
184 if n > maxlen:
188 maxlen = n
185 maxlen = n
189 if minlen == -1 or n < minlen:
186 if minlen == -1 or n < minlen:
190 minlen = n
187 minlen = n
191 lens.append(n)
188 lens.append(n)
192
189
193 if maxlen == 0:
190 if maxlen == 0:
194 # nothing to iterate over
191 # nothing to iterate over
195 return []
192 return []
196
193
197 # check that the length of sequences match
194 # check that the length of sequences match
198 if not self._mapping and minlen != maxlen:
195 if not self._mapping and minlen != maxlen:
199 msg = 'all sequences must have equal length, but have %s' % lens
196 msg = 'all sequences must have equal length, but have %s' % lens
200 raise ValueError(msg)
197 raise ValueError(msg)
201
198
202 balanced = 'Balanced' in self.view.__class__.__name__
199 balanced = 'Balanced' in self.view.__class__.__name__
203 if balanced:
200 if balanced:
204 if self.chunksize:
201 if self.chunksize:
205 nparts = maxlen // self.chunksize + int(maxlen % self.chunksize > 0)
202 nparts = maxlen // self.chunksize + int(maxlen % self.chunksize > 0)
206 else:
203 else:
207 nparts = maxlen
204 nparts = maxlen
208 targets = [None]*nparts
205 targets = [None]*nparts
209 else:
206 else:
210 if self.chunksize:
207 if self.chunksize:
211 warnings.warn("`chunksize` is ignored unless load balancing", UserWarning)
208 warnings.warn("`chunksize` is ignored unless load balancing", UserWarning)
212 # multiplexed:
209 # multiplexed:
213 targets = self.view.targets
210 targets = self.view.targets
214 # 'all' is lazily evaluated at execution time, which is now:
211 # 'all' is lazily evaluated at execution time, which is now:
215 if targets == 'all':
212 if targets == 'all':
216 targets = client._build_targets(targets)[1]
213 targets = client._build_targets(targets)[1]
217 elif isinstance(targets, int):
214 elif isinstance(targets, int):
218 # single-engine view, targets must be iterable
215 # single-engine view, targets must be iterable
219 targets = [targets]
216 targets = [targets]
220 nparts = len(targets)
217 nparts = len(targets)
221
218
222 msg_ids = []
219 msg_ids = []
223 for index, t in enumerate(targets):
220 for index, t in enumerate(targets):
224 args = []
221 args = []
225 for seq in sequences:
222 for seq in sequences:
226 part = self.mapObject.getPartition(seq, index, nparts, maxlen)
223 part = self.mapObject.getPartition(seq, index, nparts, maxlen)
227 args.append(part)
224 args.append(part)
228
225
229 if sum([len(arg) for arg in args]) == 0:
226 if sum([len(arg) for arg in args]) == 0:
230 continue
227 continue
231
228
232 if self._mapping:
229 if self._mapping:
233 if sys.version_info[0] >= 3:
230 if sys.version_info[0] >= 3:
234 f = lambda f, *sequences: list(map(f, *sequences))
231 f = lambda f, *sequences: list(map(f, *sequences))
235 else:
232 else:
236 f = map
233 f = map
237 args = [self.func] + args
234 args = [self.func] + args
238 else:
235 else:
239 f=self.func
236 f=self.func
240
237
241 view = self.view if balanced else client[t]
238 view = self.view if balanced else client[t]
242 with view.temp_flags(block=False, **self.flags):
239 with view.temp_flags(block=False, **self.flags):
243 ar = view.apply(f, *args)
240 ar = view.apply(f, *args)
244
241
245 msg_ids.extend(ar.msg_ids)
242 msg_ids.extend(ar.msg_ids)
246
243
247 r = AsyncMapResult(self.view.client, msg_ids, self.mapObject,
244 r = AsyncMapResult(self.view.client, msg_ids, self.mapObject,
248 fname=getname(self.func),
245 fname=getname(self.func),
249 ordered=self.ordered
246 ordered=self.ordered
250 )
247 )
251
248
252 if self.block:
249 if self.block:
253 try:
250 try:
254 return r.get()
251 return r.get()
255 except KeyboardInterrupt:
252 except KeyboardInterrupt:
256 return r
253 return r
257 else:
254 else:
258 return r
255 return r
259
256
260 def map(self, *sequences):
257 def map(self, *sequences):
261 """call a function on each element of one or more sequence(s) remotely.
258 """call a function on each element of one or more sequence(s) remotely.
262 This should behave very much like the builtin map, but return an AsyncMapResult
259 This should behave very much like the builtin map, but return an AsyncMapResult
263 if self.block is False.
260 if self.block is False.
264
261
265 That means it can take generators (will be cast to lists locally),
262 That means it can take generators (will be cast to lists locally),
266 and mismatched sequence lengths will be padded with None.
263 and mismatched sequence lengths will be padded with None.
267 """
264 """
268 # set _mapping as a flag for use inside self.__call__
265 # set _mapping as a flag for use inside self.__call__
269 self._mapping = True
266 self._mapping = True
270 try:
267 try:
271 ret = self(*sequences)
268 ret = self(*sequences)
272 finally:
269 finally:
273 self._mapping = False
270 self._mapping = False
274 return ret
271 return ret
275
272
276 __all__ = ['remote', 'parallel', 'RemoteFunction', 'ParallelFunction']
273 __all__ = ['remote', 'parallel', 'RemoteFunction', 'ParallelFunction']
@@ -1,1125 +1,1121 b''
1 """Views of remote engines."""
1 """Views of remote engines."""
2
2
3 # Copyright (c) IPython Development Team.
3 # Copyright (c) IPython Development Team.
4 # Distributed under the terms of the Modified BSD License.
4 # Distributed under the terms of the Modified BSD License.
5
5
6 from __future__ import print_function
6 from __future__ import print_function
7
7
8 import imp
8 import imp
9 import sys
9 import sys
10 import warnings
10 import warnings
11 from contextlib import contextmanager
11 from contextlib import contextmanager
12 from types import ModuleType
12 from types import ModuleType
13
13
14 import zmq
14 import zmq
15
15
16 from IPython.testing.skipdoctest import skip_doctest
17 from IPython.utils import pickleutil
16 from IPython.utils import pickleutil
18 from IPython.utils.traitlets import (
17 from IPython.utils.traitlets import (
19 HasTraits, Any, Bool, List, Dict, Set, Instance, CFloat, Integer
18 HasTraits, Any, Bool, List, Dict, Set, Instance, CFloat, Integer
20 )
19 )
21 from decorator import decorator
20 from decorator import decorator
22
21
23 from ipython_parallel import util
22 from ipython_parallel import util
24 from ipython_parallel.controller.dependency import Dependency, dependent
23 from ipython_parallel.controller.dependency import Dependency, dependent
25 from IPython.utils.py3compat import string_types, iteritems, PY3
24 from IPython.utils.py3compat import string_types, iteritems, PY3
26
25
27 from . import map as Map
26 from . import map as Map
28 from .asyncresult import AsyncResult, AsyncMapResult
27 from .asyncresult import AsyncResult, AsyncMapResult
29 from .remotefunction import ParallelFunction, parallel, remote, getname
28 from .remotefunction import ParallelFunction, parallel, remote, getname
30
29
31 #-----------------------------------------------------------------------------
30 #-----------------------------------------------------------------------------
32 # Decorators
31 # Decorators
33 #-----------------------------------------------------------------------------
32 #-----------------------------------------------------------------------------
34
33
35 @decorator
34 @decorator
36 def save_ids(f, self, *args, **kwargs):
35 def save_ids(f, self, *args, **kwargs):
37 """Keep our history and outstanding attributes up to date after a method call."""
36 """Keep our history and outstanding attributes up to date after a method call."""
38 n_previous = len(self.client.history)
37 n_previous = len(self.client.history)
39 try:
38 try:
40 ret = f(self, *args, **kwargs)
39 ret = f(self, *args, **kwargs)
41 finally:
40 finally:
42 nmsgs = len(self.client.history) - n_previous
41 nmsgs = len(self.client.history) - n_previous
43 msg_ids = self.client.history[-nmsgs:]
42 msg_ids = self.client.history[-nmsgs:]
44 self.history.extend(msg_ids)
43 self.history.extend(msg_ids)
45 self.outstanding.update(msg_ids)
44 self.outstanding.update(msg_ids)
46 return ret
45 return ret
47
46
48 @decorator
47 @decorator
49 def sync_results(f, self, *args, **kwargs):
48 def sync_results(f, self, *args, **kwargs):
50 """sync relevant results from self.client to our results attribute."""
49 """sync relevant results from self.client to our results attribute."""
51 if self._in_sync_results:
50 if self._in_sync_results:
52 return f(self, *args, **kwargs)
51 return f(self, *args, **kwargs)
53 self._in_sync_results = True
52 self._in_sync_results = True
54 try:
53 try:
55 ret = f(self, *args, **kwargs)
54 ret = f(self, *args, **kwargs)
56 finally:
55 finally:
57 self._in_sync_results = False
56 self._in_sync_results = False
58 self._sync_results()
57 self._sync_results()
59 return ret
58 return ret
60
59
61 @decorator
60 @decorator
62 def spin_after(f, self, *args, **kwargs):
61 def spin_after(f, self, *args, **kwargs):
63 """call spin after the method."""
62 """call spin after the method."""
64 ret = f(self, *args, **kwargs)
63 ret = f(self, *args, **kwargs)
65 self.spin()
64 self.spin()
66 return ret
65 return ret
67
66
68 #-----------------------------------------------------------------------------
67 #-----------------------------------------------------------------------------
69 # Classes
68 # Classes
70 #-----------------------------------------------------------------------------
69 #-----------------------------------------------------------------------------
71
70
72 @skip_doctest
73 class View(HasTraits):
71 class View(HasTraits):
74 """Base View class for more convenint apply(f,*args,**kwargs) syntax via attributes.
72 """Base View class for more convenint apply(f,*args,**kwargs) syntax via attributes.
75
73
76 Don't use this class, use subclasses.
74 Don't use this class, use subclasses.
77
75
78 Methods
76 Methods
79 -------
77 -------
80
78
81 spin
79 spin
82 flushes incoming results and registration state changes
80 flushes incoming results and registration state changes
83 control methods spin, and requesting `ids` also ensures up to date
81 control methods spin, and requesting `ids` also ensures up to date
84
82
85 wait
83 wait
86 wait on one or more msg_ids
84 wait on one or more msg_ids
87
85
88 execution methods
86 execution methods
89 apply
87 apply
90 legacy: execute, run
88 legacy: execute, run
91
89
92 data movement
90 data movement
93 push, pull, scatter, gather
91 push, pull, scatter, gather
94
92
95 query methods
93 query methods
96 get_result, queue_status, purge_results, result_status
94 get_result, queue_status, purge_results, result_status
97
95
98 control methods
96 control methods
99 abort, shutdown
97 abort, shutdown
100
98
101 """
99 """
102 # flags
100 # flags
103 block=Bool(False)
101 block=Bool(False)
104 track=Bool(True)
102 track=Bool(True)
105 targets = Any()
103 targets = Any()
106
104
107 history=List()
105 history=List()
108 outstanding = Set()
106 outstanding = Set()
109 results = Dict()
107 results = Dict()
110 client = Instance('ipython_parallel.Client', allow_none=True)
108 client = Instance('ipython_parallel.Client', allow_none=True)
111
109
112 _socket = Instance('zmq.Socket', allow_none=True)
110 _socket = Instance('zmq.Socket', allow_none=True)
113 _flag_names = List(['targets', 'block', 'track'])
111 _flag_names = List(['targets', 'block', 'track'])
114 _in_sync_results = Bool(False)
112 _in_sync_results = Bool(False)
115 _targets = Any()
113 _targets = Any()
116 _idents = Any()
114 _idents = Any()
117
115
118 def __init__(self, client=None, socket=None, **flags):
116 def __init__(self, client=None, socket=None, **flags):
119 super(View, self).__init__(client=client, _socket=socket)
117 super(View, self).__init__(client=client, _socket=socket)
120 self.results = client.results
118 self.results = client.results
121 self.block = client.block
119 self.block = client.block
122
120
123 self.set_flags(**flags)
121 self.set_flags(**flags)
124
122
125 assert not self.__class__ is View, "Don't use base View objects, use subclasses"
123 assert not self.__class__ is View, "Don't use base View objects, use subclasses"
126
124
127 def __repr__(self):
125 def __repr__(self):
128 strtargets = str(self.targets)
126 strtargets = str(self.targets)
129 if len(strtargets) > 16:
127 if len(strtargets) > 16:
130 strtargets = strtargets[:12]+'...]'
128 strtargets = strtargets[:12]+'...]'
131 return "<%s %s>"%(self.__class__.__name__, strtargets)
129 return "<%s %s>"%(self.__class__.__name__, strtargets)
132
130
133 def __len__(self):
131 def __len__(self):
134 if isinstance(self.targets, list):
132 if isinstance(self.targets, list):
135 return len(self.targets)
133 return len(self.targets)
136 elif isinstance(self.targets, int):
134 elif isinstance(self.targets, int):
137 return 1
135 return 1
138 else:
136 else:
139 return len(self.client)
137 return len(self.client)
140
138
141 def set_flags(self, **kwargs):
139 def set_flags(self, **kwargs):
142 """set my attribute flags by keyword.
140 """set my attribute flags by keyword.
143
141
144 Views determine behavior with a few attributes (`block`, `track`, etc.).
142 Views determine behavior with a few attributes (`block`, `track`, etc.).
145 These attributes can be set all at once by name with this method.
143 These attributes can be set all at once by name with this method.
146
144
147 Parameters
145 Parameters
148 ----------
146 ----------
149
147
150 block : bool
148 block : bool
151 whether to wait for results
149 whether to wait for results
152 track : bool
150 track : bool
153 whether to create a MessageTracker to allow the user to
151 whether to create a MessageTracker to allow the user to
154 safely edit after arrays and buffers during non-copying
152 safely edit after arrays and buffers during non-copying
155 sends.
153 sends.
156 """
154 """
157 for name, value in iteritems(kwargs):
155 for name, value in iteritems(kwargs):
158 if name not in self._flag_names:
156 if name not in self._flag_names:
159 raise KeyError("Invalid name: %r"%name)
157 raise KeyError("Invalid name: %r"%name)
160 else:
158 else:
161 setattr(self, name, value)
159 setattr(self, name, value)
162
160
163 @contextmanager
161 @contextmanager
164 def temp_flags(self, **kwargs):
162 def temp_flags(self, **kwargs):
165 """temporarily set flags, for use in `with` statements.
163 """temporarily set flags, for use in `with` statements.
166
164
167 See set_flags for permanent setting of flags
165 See set_flags for permanent setting of flags
168
166
169 Examples
167 Examples
170 --------
168 --------
171
169
172 >>> view.track=False
170 >>> view.track=False
173 ...
171 ...
174 >>> with view.temp_flags(track=True):
172 >>> with view.temp_flags(track=True):
175 ... ar = view.apply(dostuff, my_big_array)
173 ... ar = view.apply(dostuff, my_big_array)
176 ... ar.tracker.wait() # wait for send to finish
174 ... ar.tracker.wait() # wait for send to finish
177 >>> view.track
175 >>> view.track
178 False
176 False
179
177
180 """
178 """
181 # preflight: save flags, and set temporaries
179 # preflight: save flags, and set temporaries
182 saved_flags = {}
180 saved_flags = {}
183 for f in self._flag_names:
181 for f in self._flag_names:
184 saved_flags[f] = getattr(self, f)
182 saved_flags[f] = getattr(self, f)
185 self.set_flags(**kwargs)
183 self.set_flags(**kwargs)
186 # yield to the with-statement block
184 # yield to the with-statement block
187 try:
185 try:
188 yield
186 yield
189 finally:
187 finally:
190 # postflight: restore saved flags
188 # postflight: restore saved flags
191 self.set_flags(**saved_flags)
189 self.set_flags(**saved_flags)
192
190
193
191
194 #----------------------------------------------------------------
192 #----------------------------------------------------------------
195 # apply
193 # apply
196 #----------------------------------------------------------------
194 #----------------------------------------------------------------
197
195
198 def _sync_results(self):
196 def _sync_results(self):
199 """to be called by @sync_results decorator
197 """to be called by @sync_results decorator
200
198
201 after submitting any tasks.
199 after submitting any tasks.
202 """
200 """
203 delta = self.outstanding.difference(self.client.outstanding)
201 delta = self.outstanding.difference(self.client.outstanding)
204 completed = self.outstanding.intersection(delta)
202 completed = self.outstanding.intersection(delta)
205 self.outstanding = self.outstanding.difference(completed)
203 self.outstanding = self.outstanding.difference(completed)
206
204
207 @sync_results
205 @sync_results
208 @save_ids
206 @save_ids
209 def _really_apply(self, f, args, kwargs, block=None, **options):
207 def _really_apply(self, f, args, kwargs, block=None, **options):
210 """wrapper for client.send_apply_request"""
208 """wrapper for client.send_apply_request"""
211 raise NotImplementedError("Implement in subclasses")
209 raise NotImplementedError("Implement in subclasses")
212
210
213 def apply(self, f, *args, **kwargs):
211 def apply(self, f, *args, **kwargs):
214 """calls ``f(*args, **kwargs)`` on remote engines, returning the result.
212 """calls ``f(*args, **kwargs)`` on remote engines, returning the result.
215
213
216 This method sets all apply flags via this View's attributes.
214 This method sets all apply flags via this View's attributes.
217
215
218 Returns :class:`~IPython.parallel.client.asyncresult.AsyncResult`
216 Returns :class:`~IPython.parallel.client.asyncresult.AsyncResult`
219 instance if ``self.block`` is False, otherwise the return value of
217 instance if ``self.block`` is False, otherwise the return value of
220 ``f(*args, **kwargs)``.
218 ``f(*args, **kwargs)``.
221 """
219 """
222 return self._really_apply(f, args, kwargs)
220 return self._really_apply(f, args, kwargs)
223
221
224 def apply_async(self, f, *args, **kwargs):
222 def apply_async(self, f, *args, **kwargs):
225 """calls ``f(*args, **kwargs)`` on remote engines in a nonblocking manner.
223 """calls ``f(*args, **kwargs)`` on remote engines in a nonblocking manner.
226
224
227 Returns :class:`~IPython.parallel.client.asyncresult.AsyncResult` instance.
225 Returns :class:`~IPython.parallel.client.asyncresult.AsyncResult` instance.
228 """
226 """
229 return self._really_apply(f, args, kwargs, block=False)
227 return self._really_apply(f, args, kwargs, block=False)
230
228
231 @spin_after
229 @spin_after
232 def apply_sync(self, f, *args, **kwargs):
230 def apply_sync(self, f, *args, **kwargs):
233 """calls ``f(*args, **kwargs)`` on remote engines in a blocking manner,
231 """calls ``f(*args, **kwargs)`` on remote engines in a blocking manner,
234 returning the result.
232 returning the result.
235 """
233 """
236 return self._really_apply(f, args, kwargs, block=True)
234 return self._really_apply(f, args, kwargs, block=True)
237
235
238 #----------------------------------------------------------------
236 #----------------------------------------------------------------
239 # wrappers for client and control methods
237 # wrappers for client and control methods
240 #----------------------------------------------------------------
238 #----------------------------------------------------------------
241 @sync_results
239 @sync_results
242 def spin(self):
240 def spin(self):
243 """spin the client, and sync"""
241 """spin the client, and sync"""
244 self.client.spin()
242 self.client.spin()
245
243
246 @sync_results
244 @sync_results
247 def wait(self, jobs=None, timeout=-1):
245 def wait(self, jobs=None, timeout=-1):
248 """waits on one or more `jobs`, for up to `timeout` seconds.
246 """waits on one or more `jobs`, for up to `timeout` seconds.
249
247
250 Parameters
248 Parameters
251 ----------
249 ----------
252
250
253 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
251 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
254 ints are indices to self.history
252 ints are indices to self.history
255 strs are msg_ids
253 strs are msg_ids
256 default: wait on all outstanding messages
254 default: wait on all outstanding messages
257 timeout : float
255 timeout : float
258 a time in seconds, after which to give up.
256 a time in seconds, after which to give up.
259 default is -1, which means no timeout
257 default is -1, which means no timeout
260
258
261 Returns
259 Returns
262 -------
260 -------
263
261
264 True : when all msg_ids are done
262 True : when all msg_ids are done
265 False : timeout reached, some msg_ids still outstanding
263 False : timeout reached, some msg_ids still outstanding
266 """
264 """
267 if jobs is None:
265 if jobs is None:
268 jobs = self.history
266 jobs = self.history
269 return self.client.wait(jobs, timeout)
267 return self.client.wait(jobs, timeout)
270
268
271 def abort(self, jobs=None, targets=None, block=None):
269 def abort(self, jobs=None, targets=None, block=None):
272 """Abort jobs on my engines.
270 """Abort jobs on my engines.
273
271
274 Parameters
272 Parameters
275 ----------
273 ----------
276
274
277 jobs : None, str, list of strs, optional
275 jobs : None, str, list of strs, optional
278 if None: abort all jobs.
276 if None: abort all jobs.
279 else: abort specific msg_id(s).
277 else: abort specific msg_id(s).
280 """
278 """
281 block = block if block is not None else self.block
279 block = block if block is not None else self.block
282 targets = targets if targets is not None else self.targets
280 targets = targets if targets is not None else self.targets
283 jobs = jobs if jobs is not None else list(self.outstanding)
281 jobs = jobs if jobs is not None else list(self.outstanding)
284
282
285 return self.client.abort(jobs=jobs, targets=targets, block=block)
283 return self.client.abort(jobs=jobs, targets=targets, block=block)
286
284
287 def queue_status(self, targets=None, verbose=False):
285 def queue_status(self, targets=None, verbose=False):
288 """Fetch the Queue status of my engines"""
286 """Fetch the Queue status of my engines"""
289 targets = targets if targets is not None else self.targets
287 targets = targets if targets is not None else self.targets
290 return self.client.queue_status(targets=targets, verbose=verbose)
288 return self.client.queue_status(targets=targets, verbose=verbose)
291
289
292 def purge_results(self, jobs=[], targets=[]):
290 def purge_results(self, jobs=[], targets=[]):
293 """Instruct the controller to forget specific results."""
291 """Instruct the controller to forget specific results."""
294 if targets is None or targets == 'all':
292 if targets is None or targets == 'all':
295 targets = self.targets
293 targets = self.targets
296 return self.client.purge_results(jobs=jobs, targets=targets)
294 return self.client.purge_results(jobs=jobs, targets=targets)
297
295
298 def shutdown(self, targets=None, restart=False, hub=False, block=None):
296 def shutdown(self, targets=None, restart=False, hub=False, block=None):
299 """Terminates one or more engine processes, optionally including the hub.
297 """Terminates one or more engine processes, optionally including the hub.
300 """
298 """
301 block = self.block if block is None else block
299 block = self.block if block is None else block
302 if targets is None or targets == 'all':
300 if targets is None or targets == 'all':
303 targets = self.targets
301 targets = self.targets
304 return self.client.shutdown(targets=targets, restart=restart, hub=hub, block=block)
302 return self.client.shutdown(targets=targets, restart=restart, hub=hub, block=block)
305
303
306 @spin_after
304 @spin_after
307 def get_result(self, indices_or_msg_ids=None, block=None, owner=True):
305 def get_result(self, indices_or_msg_ids=None, block=None, owner=True):
308 """return one or more results, specified by history index or msg_id.
306 """return one or more results, specified by history index or msg_id.
309
307
310 See :meth:`IPython.parallel.client.client.Client.get_result` for details.
308 See :meth:`IPython.parallel.client.client.Client.get_result` for details.
311 """
309 """
312
310
313 if indices_or_msg_ids is None:
311 if indices_or_msg_ids is None:
314 indices_or_msg_ids = -1
312 indices_or_msg_ids = -1
315 if isinstance(indices_or_msg_ids, int):
313 if isinstance(indices_or_msg_ids, int):
316 indices_or_msg_ids = self.history[indices_or_msg_ids]
314 indices_or_msg_ids = self.history[indices_or_msg_ids]
317 elif isinstance(indices_or_msg_ids, (list,tuple,set)):
315 elif isinstance(indices_or_msg_ids, (list,tuple,set)):
318 indices_or_msg_ids = list(indices_or_msg_ids)
316 indices_or_msg_ids = list(indices_or_msg_ids)
319 for i,index in enumerate(indices_or_msg_ids):
317 for i,index in enumerate(indices_or_msg_ids):
320 if isinstance(index, int):
318 if isinstance(index, int):
321 indices_or_msg_ids[i] = self.history[index]
319 indices_or_msg_ids[i] = self.history[index]
322 return self.client.get_result(indices_or_msg_ids, block=block, owner=owner)
320 return self.client.get_result(indices_or_msg_ids, block=block, owner=owner)
323
321
324 #-------------------------------------------------------------------
322 #-------------------------------------------------------------------
325 # Map
323 # Map
326 #-------------------------------------------------------------------
324 #-------------------------------------------------------------------
327
325
328 @sync_results
326 @sync_results
329 def map(self, f, *sequences, **kwargs):
327 def map(self, f, *sequences, **kwargs):
330 """override in subclasses"""
328 """override in subclasses"""
331 raise NotImplementedError
329 raise NotImplementedError
332
330
333 def map_async(self, f, *sequences, **kwargs):
331 def map_async(self, f, *sequences, **kwargs):
334 """Parallel version of builtin :func:`python:map`, using this view's engines.
332 """Parallel version of builtin :func:`python:map`, using this view's engines.
335
333
336 This is equivalent to ``map(...block=False)``.
334 This is equivalent to ``map(...block=False)``.
337
335
338 See `self.map` for details.
336 See `self.map` for details.
339 """
337 """
340 if 'block' in kwargs:
338 if 'block' in kwargs:
341 raise TypeError("map_async doesn't take a `block` keyword argument.")
339 raise TypeError("map_async doesn't take a `block` keyword argument.")
342 kwargs['block'] = False
340 kwargs['block'] = False
343 return self.map(f,*sequences,**kwargs)
341 return self.map(f,*sequences,**kwargs)
344
342
345 def map_sync(self, f, *sequences, **kwargs):
343 def map_sync(self, f, *sequences, **kwargs):
346 """Parallel version of builtin :func:`python:map`, using this view's engines.
344 """Parallel version of builtin :func:`python:map`, using this view's engines.
347
345
348 This is equivalent to ``map(...block=True)``.
346 This is equivalent to ``map(...block=True)``.
349
347
350 See `self.map` for details.
348 See `self.map` for details.
351 """
349 """
352 if 'block' in kwargs:
350 if 'block' in kwargs:
353 raise TypeError("map_sync doesn't take a `block` keyword argument.")
351 raise TypeError("map_sync doesn't take a `block` keyword argument.")
354 kwargs['block'] = True
352 kwargs['block'] = True
355 return self.map(f,*sequences,**kwargs)
353 return self.map(f,*sequences,**kwargs)
356
354
357 def imap(self, f, *sequences, **kwargs):
355 def imap(self, f, *sequences, **kwargs):
358 """Parallel version of :func:`itertools.imap`.
356 """Parallel version of :func:`itertools.imap`.
359
357
360 See `self.map` for details.
358 See `self.map` for details.
361
359
362 """
360 """
363
361
364 return iter(self.map_async(f,*sequences, **kwargs))
362 return iter(self.map_async(f,*sequences, **kwargs))
365
363
366 #-------------------------------------------------------------------
364 #-------------------------------------------------------------------
367 # Decorators
365 # Decorators
368 #-------------------------------------------------------------------
366 #-------------------------------------------------------------------
369
367
370 def remote(self, block=None, **flags):
368 def remote(self, block=None, **flags):
371 """Decorator for making a RemoteFunction"""
369 """Decorator for making a RemoteFunction"""
372 block = self.block if block is None else block
370 block = self.block if block is None else block
373 return remote(self, block=block, **flags)
371 return remote(self, block=block, **flags)
374
372
375 def parallel(self, dist='b', block=None, **flags):
373 def parallel(self, dist='b', block=None, **flags):
376 """Decorator for making a ParallelFunction"""
374 """Decorator for making a ParallelFunction"""
377 block = self.block if block is None else block
375 block = self.block if block is None else block
378 return parallel(self, dist=dist, block=block, **flags)
376 return parallel(self, dist=dist, block=block, **flags)
379
377
380 @skip_doctest
381 class DirectView(View):
378 class DirectView(View):
382 """Direct Multiplexer View of one or more engines.
379 """Direct Multiplexer View of one or more engines.
383
380
384 These are created via indexed access to a client:
381 These are created via indexed access to a client:
385
382
386 >>> dv_1 = client[1]
383 >>> dv_1 = client[1]
387 >>> dv_all = client[:]
384 >>> dv_all = client[:]
388 >>> dv_even = client[::2]
385 >>> dv_even = client[::2]
389 >>> dv_some = client[1:3]
386 >>> dv_some = client[1:3]
390
387
391 This object provides dictionary access to engine namespaces:
388 This object provides dictionary access to engine namespaces:
392
389
393 # push a=5:
390 # push a=5:
394 >>> dv['a'] = 5
391 >>> dv['a'] = 5
395 # pull 'foo':
392 # pull 'foo':
396 >>> dv['foo']
393 >>> dv['foo']
397
394
398 """
395 """
399
396
400 def __init__(self, client=None, socket=None, targets=None):
397 def __init__(self, client=None, socket=None, targets=None):
401 super(DirectView, self).__init__(client=client, socket=socket, targets=targets)
398 super(DirectView, self).__init__(client=client, socket=socket, targets=targets)
402
399
403 @property
400 @property
404 def importer(self):
401 def importer(self):
405 """sync_imports(local=True) as a property.
402 """sync_imports(local=True) as a property.
406
403
407 See sync_imports for details.
404 See sync_imports for details.
408
405
409 """
406 """
410 return self.sync_imports(True)
407 return self.sync_imports(True)
411
408
412 @contextmanager
409 @contextmanager
413 def sync_imports(self, local=True, quiet=False):
410 def sync_imports(self, local=True, quiet=False):
414 """Context Manager for performing simultaneous local and remote imports.
411 """Context Manager for performing simultaneous local and remote imports.
415
412
416 'import x as y' will *not* work. The 'as y' part will simply be ignored.
413 'import x as y' will *not* work. The 'as y' part will simply be ignored.
417
414
418 If `local=True`, then the package will also be imported locally.
415 If `local=True`, then the package will also be imported locally.
419
416
420 If `quiet=True`, no output will be produced when attempting remote
417 If `quiet=True`, no output will be produced when attempting remote
421 imports.
418 imports.
422
419
423 Note that remote-only (`local=False`) imports have not been implemented.
420 Note that remote-only (`local=False`) imports have not been implemented.
424
421
425 >>> with view.sync_imports():
422 >>> with view.sync_imports():
426 ... from numpy import recarray
423 ... from numpy import recarray
427 importing recarray from numpy on engine(s)
424 importing recarray from numpy on engine(s)
428
425
429 """
426 """
430 from IPython.utils.py3compat import builtin_mod
427 from IPython.utils.py3compat import builtin_mod
431 local_import = builtin_mod.__import__
428 local_import = builtin_mod.__import__
432 modules = set()
429 modules = set()
433 results = []
430 results = []
434 @util.interactive
431 @util.interactive
435 def remote_import(name, fromlist, level):
432 def remote_import(name, fromlist, level):
436 """the function to be passed to apply, that actually performs the import
433 """the function to be passed to apply, that actually performs the import
437 on the engine, and loads up the user namespace.
434 on the engine, and loads up the user namespace.
438 """
435 """
439 import sys
436 import sys
440 user_ns = globals()
437 user_ns = globals()
441 mod = __import__(name, fromlist=fromlist, level=level)
438 mod = __import__(name, fromlist=fromlist, level=level)
442 if fromlist:
439 if fromlist:
443 for key in fromlist:
440 for key in fromlist:
444 user_ns[key] = getattr(mod, key)
441 user_ns[key] = getattr(mod, key)
445 else:
442 else:
446 user_ns[name] = sys.modules[name]
443 user_ns[name] = sys.modules[name]
447
444
448 def view_import(name, globals={}, locals={}, fromlist=[], level=0):
445 def view_import(name, globals={}, locals={}, fromlist=[], level=0):
449 """the drop-in replacement for __import__, that optionally imports
446 """the drop-in replacement for __import__, that optionally imports
450 locally as well.
447 locally as well.
451 """
448 """
452 # don't override nested imports
449 # don't override nested imports
453 save_import = builtin_mod.__import__
450 save_import = builtin_mod.__import__
454 builtin_mod.__import__ = local_import
451 builtin_mod.__import__ = local_import
455
452
456 if imp.lock_held():
453 if imp.lock_held():
457 # this is a side-effect import, don't do it remotely, or even
454 # this is a side-effect import, don't do it remotely, or even
458 # ignore the local effects
455 # ignore the local effects
459 return local_import(name, globals, locals, fromlist, level)
456 return local_import(name, globals, locals, fromlist, level)
460
457
461 imp.acquire_lock()
458 imp.acquire_lock()
462 if local:
459 if local:
463 mod = local_import(name, globals, locals, fromlist, level)
460 mod = local_import(name, globals, locals, fromlist, level)
464 else:
461 else:
465 raise NotImplementedError("remote-only imports not yet implemented")
462 raise NotImplementedError("remote-only imports not yet implemented")
466 imp.release_lock()
463 imp.release_lock()
467
464
468 key = name+':'+','.join(fromlist or [])
465 key = name+':'+','.join(fromlist or [])
469 if level <= 0 and key not in modules:
466 if level <= 0 and key not in modules:
470 modules.add(key)
467 modules.add(key)
471 if not quiet:
468 if not quiet:
472 if fromlist:
469 if fromlist:
473 print("importing %s from %s on engine(s)"%(','.join(fromlist), name))
470 print("importing %s from %s on engine(s)"%(','.join(fromlist), name))
474 else:
471 else:
475 print("importing %s on engine(s)"%name)
472 print("importing %s on engine(s)"%name)
476 results.append(self.apply_async(remote_import, name, fromlist, level))
473 results.append(self.apply_async(remote_import, name, fromlist, level))
477 # restore override
474 # restore override
478 builtin_mod.__import__ = save_import
475 builtin_mod.__import__ = save_import
479
476
480 return mod
477 return mod
481
478
482 # override __import__
479 # override __import__
483 builtin_mod.__import__ = view_import
480 builtin_mod.__import__ = view_import
484 try:
481 try:
485 # enter the block
482 # enter the block
486 yield
483 yield
487 except ImportError:
484 except ImportError:
488 if local:
485 if local:
489 raise
486 raise
490 else:
487 else:
491 # ignore import errors if not doing local imports
488 # ignore import errors if not doing local imports
492 pass
489 pass
493 finally:
490 finally:
494 # always restore __import__
491 # always restore __import__
495 builtin_mod.__import__ = local_import
492 builtin_mod.__import__ = local_import
496
493
497 for r in results:
494 for r in results:
498 # raise possible remote ImportErrors here
495 # raise possible remote ImportErrors here
499 r.get()
496 r.get()
500
497
501 def use_dill(self):
498 def use_dill(self):
502 """Expand serialization support with dill
499 """Expand serialization support with dill
503
500
504 adds support for closures, etc.
501 adds support for closures, etc.
505
502
506 This calls ipython_kernel.pickleutil.use_dill() here and on each engine.
503 This calls ipython_kernel.pickleutil.use_dill() here and on each engine.
507 """
504 """
508 pickleutil.use_dill()
505 pickleutil.use_dill()
509 return self.apply(pickleutil.use_dill)
506 return self.apply(pickleutil.use_dill)
510
507
511 def use_cloudpickle(self):
508 def use_cloudpickle(self):
512 """Expand serialization support with cloudpickle.
509 """Expand serialization support with cloudpickle.
513 """
510 """
514 pickleutil.use_cloudpickle()
511 pickleutil.use_cloudpickle()
515 return self.apply(pickleutil.use_cloudpickle)
512 return self.apply(pickleutil.use_cloudpickle)
516
513
517
514
518 @sync_results
515 @sync_results
519 @save_ids
516 @save_ids
520 def _really_apply(self, f, args=None, kwargs=None, targets=None, block=None, track=None):
517 def _really_apply(self, f, args=None, kwargs=None, targets=None, block=None, track=None):
521 """calls f(*args, **kwargs) on remote engines, returning the result.
518 """calls f(*args, **kwargs) on remote engines, returning the result.
522
519
523 This method sets all of `apply`'s flags via this View's attributes.
520 This method sets all of `apply`'s flags via this View's attributes.
524
521
525 Parameters
522 Parameters
526 ----------
523 ----------
527
524
528 f : callable
525 f : callable
529
526
530 args : list [default: empty]
527 args : list [default: empty]
531
528
532 kwargs : dict [default: empty]
529 kwargs : dict [default: empty]
533
530
534 targets : target list [default: self.targets]
531 targets : target list [default: self.targets]
535 where to run
532 where to run
536 block : bool [default: self.block]
533 block : bool [default: self.block]
537 whether to block
534 whether to block
538 track : bool [default: self.track]
535 track : bool [default: self.track]
539 whether to ask zmq to track the message, for safe non-copying sends
536 whether to ask zmq to track the message, for safe non-copying sends
540
537
541 Returns
538 Returns
542 -------
539 -------
543
540
544 if self.block is False:
541 if self.block is False:
545 returns AsyncResult
542 returns AsyncResult
546 else:
543 else:
547 returns actual result of f(*args, **kwargs) on the engine(s)
544 returns actual result of f(*args, **kwargs) on the engine(s)
548 This will be a list of self.targets is also a list (even length 1), or
545 This will be a list of self.targets is also a list (even length 1), or
549 the single result if self.targets is an integer engine id
546 the single result if self.targets is an integer engine id
550 """
547 """
551 args = [] if args is None else args
548 args = [] if args is None else args
552 kwargs = {} if kwargs is None else kwargs
549 kwargs = {} if kwargs is None else kwargs
553 block = self.block if block is None else block
550 block = self.block if block is None else block
554 track = self.track if track is None else track
551 track = self.track if track is None else track
555 targets = self.targets if targets is None else targets
552 targets = self.targets if targets is None else targets
556
553
557 _idents, _targets = self.client._build_targets(targets)
554 _idents, _targets = self.client._build_targets(targets)
558 msg_ids = []
555 msg_ids = []
559 trackers = []
556 trackers = []
560 for ident in _idents:
557 for ident in _idents:
561 msg = self.client.send_apply_request(self._socket, f, args, kwargs, track=track,
558 msg = self.client.send_apply_request(self._socket, f, args, kwargs, track=track,
562 ident=ident)
559 ident=ident)
563 if track:
560 if track:
564 trackers.append(msg['tracker'])
561 trackers.append(msg['tracker'])
565 msg_ids.append(msg['header']['msg_id'])
562 msg_ids.append(msg['header']['msg_id'])
566 if isinstance(targets, int):
563 if isinstance(targets, int):
567 msg_ids = msg_ids[0]
564 msg_ids = msg_ids[0]
568 tracker = None if track is False else zmq.MessageTracker(*trackers)
565 tracker = None if track is False else zmq.MessageTracker(*trackers)
569 ar = AsyncResult(self.client, msg_ids, fname=getname(f), targets=_targets,
566 ar = AsyncResult(self.client, msg_ids, fname=getname(f), targets=_targets,
570 tracker=tracker, owner=True,
567 tracker=tracker, owner=True,
571 )
568 )
572 if block:
569 if block:
573 try:
570 try:
574 return ar.get()
571 return ar.get()
575 except KeyboardInterrupt:
572 except KeyboardInterrupt:
576 pass
573 pass
577 return ar
574 return ar
578
575
579
576
580 @sync_results
577 @sync_results
581 def map(self, f, *sequences, **kwargs):
578 def map(self, f, *sequences, **kwargs):
582 """``view.map(f, *sequences, block=self.block)`` => list|AsyncMapResult
579 """``view.map(f, *sequences, block=self.block)`` => list|AsyncMapResult
583
580
584 Parallel version of builtin `map`, using this View's `targets`.
581 Parallel version of builtin `map`, using this View's `targets`.
585
582
586 There will be one task per target, so work will be chunked
583 There will be one task per target, so work will be chunked
587 if the sequences are longer than `targets`.
584 if the sequences are longer than `targets`.
588
585
589 Results can be iterated as they are ready, but will become available in chunks.
586 Results can be iterated as they are ready, but will become available in chunks.
590
587
591 Parameters
588 Parameters
592 ----------
589 ----------
593
590
594 f : callable
591 f : callable
595 function to be mapped
592 function to be mapped
596 *sequences: one or more sequences of matching length
593 *sequences: one or more sequences of matching length
597 the sequences to be distributed and passed to `f`
594 the sequences to be distributed and passed to `f`
598 block : bool
595 block : bool
599 whether to wait for the result or not [default self.block]
596 whether to wait for the result or not [default self.block]
600
597
601 Returns
598 Returns
602 -------
599 -------
603
600
604
601
605 If block=False
602 If block=False
606 An :class:`~ipython_parallel.client.asyncresult.AsyncMapResult` instance.
603 An :class:`~ipython_parallel.client.asyncresult.AsyncMapResult` instance.
607 An object like AsyncResult, but which reassembles the sequence of results
604 An object like AsyncResult, but which reassembles the sequence of results
608 into a single list. AsyncMapResults can be iterated through before all
605 into a single list. AsyncMapResults can be iterated through before all
609 results are complete.
606 results are complete.
610 else
607 else
611 A list, the result of ``map(f,*sequences)``
608 A list, the result of ``map(f,*sequences)``
612 """
609 """
613
610
614 block = kwargs.pop('block', self.block)
611 block = kwargs.pop('block', self.block)
615 for k in kwargs.keys():
612 for k in kwargs.keys():
616 if k not in ['block', 'track']:
613 if k not in ['block', 'track']:
617 raise TypeError("invalid keyword arg, %r"%k)
614 raise TypeError("invalid keyword arg, %r"%k)
618
615
619 assert len(sequences) > 0, "must have some sequences to map onto!"
616 assert len(sequences) > 0, "must have some sequences to map onto!"
620 pf = ParallelFunction(self, f, block=block, **kwargs)
617 pf = ParallelFunction(self, f, block=block, **kwargs)
621 return pf.map(*sequences)
618 return pf.map(*sequences)
622
619
623 @sync_results
620 @sync_results
624 @save_ids
621 @save_ids
625 def execute(self, code, silent=True, targets=None, block=None):
622 def execute(self, code, silent=True, targets=None, block=None):
626 """Executes `code` on `targets` in blocking or nonblocking manner.
623 """Executes `code` on `targets` in blocking or nonblocking manner.
627
624
628 ``execute`` is always `bound` (affects engine namespace)
625 ``execute`` is always `bound` (affects engine namespace)
629
626
630 Parameters
627 Parameters
631 ----------
628 ----------
632
629
633 code : str
630 code : str
634 the code string to be executed
631 the code string to be executed
635 block : bool
632 block : bool
636 whether or not to wait until done to return
633 whether or not to wait until done to return
637 default: self.block
634 default: self.block
638 """
635 """
639 block = self.block if block is None else block
636 block = self.block if block is None else block
640 targets = self.targets if targets is None else targets
637 targets = self.targets if targets is None else targets
641
638
642 _idents, _targets = self.client._build_targets(targets)
639 _idents, _targets = self.client._build_targets(targets)
643 msg_ids = []
640 msg_ids = []
644 trackers = []
641 trackers = []
645 for ident in _idents:
642 for ident in _idents:
646 msg = self.client.send_execute_request(self._socket, code, silent=silent, ident=ident)
643 msg = self.client.send_execute_request(self._socket, code, silent=silent, ident=ident)
647 msg_ids.append(msg['header']['msg_id'])
644 msg_ids.append(msg['header']['msg_id'])
648 if isinstance(targets, int):
645 if isinstance(targets, int):
649 msg_ids = msg_ids[0]
646 msg_ids = msg_ids[0]
650 ar = AsyncResult(self.client, msg_ids, fname='execute', targets=_targets, owner=True)
647 ar = AsyncResult(self.client, msg_ids, fname='execute', targets=_targets, owner=True)
651 if block:
648 if block:
652 try:
649 try:
653 ar.get()
650 ar.get()
654 except KeyboardInterrupt:
651 except KeyboardInterrupt:
655 pass
652 pass
656 return ar
653 return ar
657
654
658 def run(self, filename, targets=None, block=None):
655 def run(self, filename, targets=None, block=None):
659 """Execute contents of `filename` on my engine(s).
656 """Execute contents of `filename` on my engine(s).
660
657
661 This simply reads the contents of the file and calls `execute`.
658 This simply reads the contents of the file and calls `execute`.
662
659
663 Parameters
660 Parameters
664 ----------
661 ----------
665
662
666 filename : str
663 filename : str
667 The path to the file
664 The path to the file
668 targets : int/str/list of ints/strs
665 targets : int/str/list of ints/strs
669 the engines on which to execute
666 the engines on which to execute
670 default : all
667 default : all
671 block : bool
668 block : bool
672 whether or not to wait until done
669 whether or not to wait until done
673 default: self.block
670 default: self.block
674
671
675 """
672 """
676 with open(filename, 'r') as f:
673 with open(filename, 'r') as f:
677 # add newline in case of trailing indented whitespace
674 # add newline in case of trailing indented whitespace
678 # which will cause SyntaxError
675 # which will cause SyntaxError
679 code = f.read()+'\n'
676 code = f.read()+'\n'
680 return self.execute(code, block=block, targets=targets)
677 return self.execute(code, block=block, targets=targets)
681
678
682 def update(self, ns):
679 def update(self, ns):
683 """update remote namespace with dict `ns`
680 """update remote namespace with dict `ns`
684
681
685 See `push` for details.
682 See `push` for details.
686 """
683 """
687 return self.push(ns, block=self.block, track=self.track)
684 return self.push(ns, block=self.block, track=self.track)
688
685
689 def push(self, ns, targets=None, block=None, track=None):
686 def push(self, ns, targets=None, block=None, track=None):
690 """update remote namespace with dict `ns`
687 """update remote namespace with dict `ns`
691
688
692 Parameters
689 Parameters
693 ----------
690 ----------
694
691
695 ns : dict
692 ns : dict
696 dict of keys with which to update engine namespace(s)
693 dict of keys with which to update engine namespace(s)
697 block : bool [default : self.block]
694 block : bool [default : self.block]
698 whether to wait to be notified of engine receipt
695 whether to wait to be notified of engine receipt
699
696
700 """
697 """
701
698
702 block = block if block is not None else self.block
699 block = block if block is not None else self.block
703 track = track if track is not None else self.track
700 track = track if track is not None else self.track
704 targets = targets if targets is not None else self.targets
701 targets = targets if targets is not None else self.targets
705 # applier = self.apply_sync if block else self.apply_async
702 # applier = self.apply_sync if block else self.apply_async
706 if not isinstance(ns, dict):
703 if not isinstance(ns, dict):
707 raise TypeError("Must be a dict, not %s"%type(ns))
704 raise TypeError("Must be a dict, not %s"%type(ns))
708 return self._really_apply(util._push, kwargs=ns, block=block, track=track, targets=targets)
705 return self._really_apply(util._push, kwargs=ns, block=block, track=track, targets=targets)
709
706
710 def get(self, key_s):
707 def get(self, key_s):
711 """get object(s) by `key_s` from remote namespace
708 """get object(s) by `key_s` from remote namespace
712
709
713 see `pull` for details.
710 see `pull` for details.
714 """
711 """
715 # block = block if block is not None else self.block
712 # block = block if block is not None else self.block
716 return self.pull(key_s, block=True)
713 return self.pull(key_s, block=True)
717
714
718 def pull(self, names, targets=None, block=None):
715 def pull(self, names, targets=None, block=None):
719 """get object(s) by `name` from remote namespace
716 """get object(s) by `name` from remote namespace
720
717
721 will return one object if it is a key.
718 will return one object if it is a key.
722 can also take a list of keys, in which case it will return a list of objects.
719 can also take a list of keys, in which case it will return a list of objects.
723 """
720 """
724 block = block if block is not None else self.block
721 block = block if block is not None else self.block
725 targets = targets if targets is not None else self.targets
722 targets = targets if targets is not None else self.targets
726 applier = self.apply_sync if block else self.apply_async
723 applier = self.apply_sync if block else self.apply_async
727 if isinstance(names, string_types):
724 if isinstance(names, string_types):
728 pass
725 pass
729 elif isinstance(names, (list,tuple,set)):
726 elif isinstance(names, (list,tuple,set)):
730 for key in names:
727 for key in names:
731 if not isinstance(key, string_types):
728 if not isinstance(key, string_types):
732 raise TypeError("keys must be str, not type %r"%type(key))
729 raise TypeError("keys must be str, not type %r"%type(key))
733 else:
730 else:
734 raise TypeError("names must be strs, not %r"%names)
731 raise TypeError("names must be strs, not %r"%names)
735 return self._really_apply(util._pull, (names,), block=block, targets=targets)
732 return self._really_apply(util._pull, (names,), block=block, targets=targets)
736
733
737 def scatter(self, key, seq, dist='b', flatten=False, targets=None, block=None, track=None):
734 def scatter(self, key, seq, dist='b', flatten=False, targets=None, block=None, track=None):
738 """
735 """
739 Partition a Python sequence and send the partitions to a set of engines.
736 Partition a Python sequence and send the partitions to a set of engines.
740 """
737 """
741 block = block if block is not None else self.block
738 block = block if block is not None else self.block
742 track = track if track is not None else self.track
739 track = track if track is not None else self.track
743 targets = targets if targets is not None else self.targets
740 targets = targets if targets is not None else self.targets
744
741
745 # construct integer ID list:
742 # construct integer ID list:
746 targets = self.client._build_targets(targets)[1]
743 targets = self.client._build_targets(targets)[1]
747
744
748 mapObject = Map.dists[dist]()
745 mapObject = Map.dists[dist]()
749 nparts = len(targets)
746 nparts = len(targets)
750 msg_ids = []
747 msg_ids = []
751 trackers = []
748 trackers = []
752 for index, engineid in enumerate(targets):
749 for index, engineid in enumerate(targets):
753 partition = mapObject.getPartition(seq, index, nparts)
750 partition = mapObject.getPartition(seq, index, nparts)
754 if flatten and len(partition) == 1:
751 if flatten and len(partition) == 1:
755 ns = {key: partition[0]}
752 ns = {key: partition[0]}
756 else:
753 else:
757 ns = {key: partition}
754 ns = {key: partition}
758 r = self.push(ns, block=False, track=track, targets=engineid)
755 r = self.push(ns, block=False, track=track, targets=engineid)
759 msg_ids.extend(r.msg_ids)
756 msg_ids.extend(r.msg_ids)
760 if track:
757 if track:
761 trackers.append(r._tracker)
758 trackers.append(r._tracker)
762
759
763 if track:
760 if track:
764 tracker = zmq.MessageTracker(*trackers)
761 tracker = zmq.MessageTracker(*trackers)
765 else:
762 else:
766 tracker = None
763 tracker = None
767
764
768 r = AsyncResult(self.client, msg_ids, fname='scatter', targets=targets,
765 r = AsyncResult(self.client, msg_ids, fname='scatter', targets=targets,
769 tracker=tracker, owner=True,
766 tracker=tracker, owner=True,
770 )
767 )
771 if block:
768 if block:
772 r.wait()
769 r.wait()
773 else:
770 else:
774 return r
771 return r
775
772
776 @sync_results
773 @sync_results
777 @save_ids
774 @save_ids
778 def gather(self, key, dist='b', targets=None, block=None):
775 def gather(self, key, dist='b', targets=None, block=None):
779 """
776 """
780 Gather a partitioned sequence on a set of engines as a single local seq.
777 Gather a partitioned sequence on a set of engines as a single local seq.
781 """
778 """
782 block = block if block is not None else self.block
779 block = block if block is not None else self.block
783 targets = targets if targets is not None else self.targets
780 targets = targets if targets is not None else self.targets
784 mapObject = Map.dists[dist]()
781 mapObject = Map.dists[dist]()
785 msg_ids = []
782 msg_ids = []
786
783
787 # construct integer ID list:
784 # construct integer ID list:
788 targets = self.client._build_targets(targets)[1]
785 targets = self.client._build_targets(targets)[1]
789
786
790 for index, engineid in enumerate(targets):
787 for index, engineid in enumerate(targets):
791 msg_ids.extend(self.pull(key, block=False, targets=engineid).msg_ids)
788 msg_ids.extend(self.pull(key, block=False, targets=engineid).msg_ids)
792
789
793 r = AsyncMapResult(self.client, msg_ids, mapObject, fname='gather')
790 r = AsyncMapResult(self.client, msg_ids, mapObject, fname='gather')
794
791
795 if block:
792 if block:
796 try:
793 try:
797 return r.get()
794 return r.get()
798 except KeyboardInterrupt:
795 except KeyboardInterrupt:
799 pass
796 pass
800 return r
797 return r
801
798
802 def __getitem__(self, key):
799 def __getitem__(self, key):
803 return self.get(key)
800 return self.get(key)
804
801
805 def __setitem__(self,key, value):
802 def __setitem__(self,key, value):
806 self.update({key:value})
803 self.update({key:value})
807
804
808 def clear(self, targets=None, block=None):
805 def clear(self, targets=None, block=None):
809 """Clear the remote namespaces on my engines."""
806 """Clear the remote namespaces on my engines."""
810 block = block if block is not None else self.block
807 block = block if block is not None else self.block
811 targets = targets if targets is not None else self.targets
808 targets = targets if targets is not None else self.targets
812 return self.client.clear(targets=targets, block=block)
809 return self.client.clear(targets=targets, block=block)
813
810
814 #----------------------------------------
811 #----------------------------------------
815 # activate for %px, %autopx, etc. magics
812 # activate for %px, %autopx, etc. magics
816 #----------------------------------------
813 #----------------------------------------
817
814
818 def activate(self, suffix=''):
815 def activate(self, suffix=''):
819 """Activate IPython magics associated with this View
816 """Activate IPython magics associated with this View
820
817
821 Defines the magics `%px, %autopx, %pxresult, %%px, %pxconfig`
818 Defines the magics `%px, %autopx, %pxresult, %%px, %pxconfig`
822
819
823 Parameters
820 Parameters
824 ----------
821 ----------
825
822
826 suffix: str [default: '']
823 suffix: str [default: '']
827 The suffix, if any, for the magics. This allows you to have
824 The suffix, if any, for the magics. This allows you to have
828 multiple views associated with parallel magics at the same time.
825 multiple views associated with parallel magics at the same time.
829
826
830 e.g. ``rc[::2].activate(suffix='_even')`` will give you
827 e.g. ``rc[::2].activate(suffix='_even')`` will give you
831 the magics ``%px_even``, ``%pxresult_even``, etc. for running magics
828 the magics ``%px_even``, ``%pxresult_even``, etc. for running magics
832 on the even engines.
829 on the even engines.
833 """
830 """
834
831
835 from IPython.parallel.client.magics import ParallelMagics
832 from IPython.parallel.client.magics import ParallelMagics
836
833
837 try:
834 try:
838 # This is injected into __builtins__.
835 # This is injected into __builtins__.
839 ip = get_ipython()
836 ip = get_ipython()
840 except NameError:
837 except NameError:
841 print("The IPython parallel magics (%px, etc.) only work within IPython.")
838 print("The IPython parallel magics (%px, etc.) only work within IPython.")
842 return
839 return
843
840
844 M = ParallelMagics(ip, self, suffix)
841 M = ParallelMagics(ip, self, suffix)
845 ip.magics_manager.register(M)
842 ip.magics_manager.register(M)
846
843
847
844
848 @skip_doctest
849 class LoadBalancedView(View):
845 class LoadBalancedView(View):
850 """An load-balancing View that only executes via the Task scheduler.
846 """An load-balancing View that only executes via the Task scheduler.
851
847
852 Load-balanced views can be created with the client's `view` method:
848 Load-balanced views can be created with the client's `view` method:
853
849
854 >>> v = client.load_balanced_view()
850 >>> v = client.load_balanced_view()
855
851
856 or targets can be specified, to restrict the potential destinations:
852 or targets can be specified, to restrict the potential destinations:
857
853
858 >>> v = client.load_balanced_view([1,3])
854 >>> v = client.load_balanced_view([1,3])
859
855
860 which would restrict loadbalancing to between engines 1 and 3.
856 which would restrict loadbalancing to between engines 1 and 3.
861
857
862 """
858 """
863
859
864 follow=Any()
860 follow=Any()
865 after=Any()
861 after=Any()
866 timeout=CFloat()
862 timeout=CFloat()
867 retries = Integer(0)
863 retries = Integer(0)
868
864
869 _task_scheme = Any()
865 _task_scheme = Any()
870 _flag_names = List(['targets', 'block', 'track', 'follow', 'after', 'timeout', 'retries'])
866 _flag_names = List(['targets', 'block', 'track', 'follow', 'after', 'timeout', 'retries'])
871
867
872 def __init__(self, client=None, socket=None, **flags):
868 def __init__(self, client=None, socket=None, **flags):
873 super(LoadBalancedView, self).__init__(client=client, socket=socket, **flags)
869 super(LoadBalancedView, self).__init__(client=client, socket=socket, **flags)
874 self._task_scheme=client._task_scheme
870 self._task_scheme=client._task_scheme
875
871
876 def _validate_dependency(self, dep):
872 def _validate_dependency(self, dep):
877 """validate a dependency.
873 """validate a dependency.
878
874
879 For use in `set_flags`.
875 For use in `set_flags`.
880 """
876 """
881 if dep is None or isinstance(dep, string_types + (AsyncResult, Dependency)):
877 if dep is None or isinstance(dep, string_types + (AsyncResult, Dependency)):
882 return True
878 return True
883 elif isinstance(dep, (list,set, tuple)):
879 elif isinstance(dep, (list,set, tuple)):
884 for d in dep:
880 for d in dep:
885 if not isinstance(d, string_types + (AsyncResult,)):
881 if not isinstance(d, string_types + (AsyncResult,)):
886 return False
882 return False
887 elif isinstance(dep, dict):
883 elif isinstance(dep, dict):
888 if set(dep.keys()) != set(Dependency().as_dict().keys()):
884 if set(dep.keys()) != set(Dependency().as_dict().keys()):
889 return False
885 return False
890 if not isinstance(dep['msg_ids'], list):
886 if not isinstance(dep['msg_ids'], list):
891 return False
887 return False
892 for d in dep['msg_ids']:
888 for d in dep['msg_ids']:
893 if not isinstance(d, string_types):
889 if not isinstance(d, string_types):
894 return False
890 return False
895 else:
891 else:
896 return False
892 return False
897
893
898 return True
894 return True
899
895
900 def _render_dependency(self, dep):
896 def _render_dependency(self, dep):
901 """helper for building jsonable dependencies from various input forms."""
897 """helper for building jsonable dependencies from various input forms."""
902 if isinstance(dep, Dependency):
898 if isinstance(dep, Dependency):
903 return dep.as_dict()
899 return dep.as_dict()
904 elif isinstance(dep, AsyncResult):
900 elif isinstance(dep, AsyncResult):
905 return dep.msg_ids
901 return dep.msg_ids
906 elif dep is None:
902 elif dep is None:
907 return []
903 return []
908 else:
904 else:
909 # pass to Dependency constructor
905 # pass to Dependency constructor
910 return list(Dependency(dep))
906 return list(Dependency(dep))
911
907
912 def set_flags(self, **kwargs):
908 def set_flags(self, **kwargs):
913 """set my attribute flags by keyword.
909 """set my attribute flags by keyword.
914
910
915 A View is a wrapper for the Client's apply method, but with attributes
911 A View is a wrapper for the Client's apply method, but with attributes
916 that specify keyword arguments, those attributes can be set by keyword
912 that specify keyword arguments, those attributes can be set by keyword
917 argument with this method.
913 argument with this method.
918
914
919 Parameters
915 Parameters
920 ----------
916 ----------
921
917
922 block : bool
918 block : bool
923 whether to wait for results
919 whether to wait for results
924 track : bool
920 track : bool
925 whether to create a MessageTracker to allow the user to
921 whether to create a MessageTracker to allow the user to
926 safely edit after arrays and buffers during non-copying
922 safely edit after arrays and buffers during non-copying
927 sends.
923 sends.
928
924
929 after : Dependency or collection of msg_ids
925 after : Dependency or collection of msg_ids
930 Only for load-balanced execution (targets=None)
926 Only for load-balanced execution (targets=None)
931 Specify a list of msg_ids as a time-based dependency.
927 Specify a list of msg_ids as a time-based dependency.
932 This job will only be run *after* the dependencies
928 This job will only be run *after* the dependencies
933 have been met.
929 have been met.
934
930
935 follow : Dependency or collection of msg_ids
931 follow : Dependency or collection of msg_ids
936 Only for load-balanced execution (targets=None)
932 Only for load-balanced execution (targets=None)
937 Specify a list of msg_ids as a location-based dependency.
933 Specify a list of msg_ids as a location-based dependency.
938 This job will only be run on an engine where this dependency
934 This job will only be run on an engine where this dependency
939 is met.
935 is met.
940
936
941 timeout : float/int or None
937 timeout : float/int or None
942 Only for load-balanced execution (targets=None)
938 Only for load-balanced execution (targets=None)
943 Specify an amount of time (in seconds) for the scheduler to
939 Specify an amount of time (in seconds) for the scheduler to
944 wait for dependencies to be met before failing with a
940 wait for dependencies to be met before failing with a
945 DependencyTimeout.
941 DependencyTimeout.
946
942
947 retries : int
943 retries : int
948 Number of times a task will be retried on failure.
944 Number of times a task will be retried on failure.
949 """
945 """
950
946
951 super(LoadBalancedView, self).set_flags(**kwargs)
947 super(LoadBalancedView, self).set_flags(**kwargs)
952 for name in ('follow', 'after'):
948 for name in ('follow', 'after'):
953 if name in kwargs:
949 if name in kwargs:
954 value = kwargs[name]
950 value = kwargs[name]
955 if self._validate_dependency(value):
951 if self._validate_dependency(value):
956 setattr(self, name, value)
952 setattr(self, name, value)
957 else:
953 else:
958 raise ValueError("Invalid dependency: %r"%value)
954 raise ValueError("Invalid dependency: %r"%value)
959 if 'timeout' in kwargs:
955 if 'timeout' in kwargs:
960 t = kwargs['timeout']
956 t = kwargs['timeout']
961 if not isinstance(t, (int, float, type(None))):
957 if not isinstance(t, (int, float, type(None))):
962 if (not PY3) and (not isinstance(t, long)):
958 if (not PY3) and (not isinstance(t, long)):
963 raise TypeError("Invalid type for timeout: %r"%type(t))
959 raise TypeError("Invalid type for timeout: %r"%type(t))
964 if t is not None:
960 if t is not None:
965 if t < 0:
961 if t < 0:
966 raise ValueError("Invalid timeout: %s"%t)
962 raise ValueError("Invalid timeout: %s"%t)
967 self.timeout = t
963 self.timeout = t
968
964
969 @sync_results
965 @sync_results
970 @save_ids
966 @save_ids
971 def _really_apply(self, f, args=None, kwargs=None, block=None, track=None,
967 def _really_apply(self, f, args=None, kwargs=None, block=None, track=None,
972 after=None, follow=None, timeout=None,
968 after=None, follow=None, timeout=None,
973 targets=None, retries=None):
969 targets=None, retries=None):
974 """calls f(*args, **kwargs) on a remote engine, returning the result.
970 """calls f(*args, **kwargs) on a remote engine, returning the result.
975
971
976 This method temporarily sets all of `apply`'s flags for a single call.
972 This method temporarily sets all of `apply`'s flags for a single call.
977
973
978 Parameters
974 Parameters
979 ----------
975 ----------
980
976
981 f : callable
977 f : callable
982
978
983 args : list [default: empty]
979 args : list [default: empty]
984
980
985 kwargs : dict [default: empty]
981 kwargs : dict [default: empty]
986
982
987 block : bool [default: self.block]
983 block : bool [default: self.block]
988 whether to block
984 whether to block
989 track : bool [default: self.track]
985 track : bool [default: self.track]
990 whether to ask zmq to track the message, for safe non-copying sends
986 whether to ask zmq to track the message, for safe non-copying sends
991
987
992 !!!!!! TODO: THE REST HERE !!!!
988 !!!!!! TODO: THE REST HERE !!!!
993
989
994 Returns
990 Returns
995 -------
991 -------
996
992
997 if self.block is False:
993 if self.block is False:
998 returns AsyncResult
994 returns AsyncResult
999 else:
995 else:
1000 returns actual result of f(*args, **kwargs) on the engine(s)
996 returns actual result of f(*args, **kwargs) on the engine(s)
1001 This will be a list of self.targets is also a list (even length 1), or
997 This will be a list of self.targets is also a list (even length 1), or
1002 the single result if self.targets is an integer engine id
998 the single result if self.targets is an integer engine id
1003 """
999 """
1004
1000
1005 # validate whether we can run
1001 # validate whether we can run
1006 if self._socket.closed:
1002 if self._socket.closed:
1007 msg = "Task farming is disabled"
1003 msg = "Task farming is disabled"
1008 if self._task_scheme == 'pure':
1004 if self._task_scheme == 'pure':
1009 msg += " because the pure ZMQ scheduler cannot handle"
1005 msg += " because the pure ZMQ scheduler cannot handle"
1010 msg += " disappearing engines."
1006 msg += " disappearing engines."
1011 raise RuntimeError(msg)
1007 raise RuntimeError(msg)
1012
1008
1013 if self._task_scheme == 'pure':
1009 if self._task_scheme == 'pure':
1014 # pure zmq scheme doesn't support extra features
1010 # pure zmq scheme doesn't support extra features
1015 msg = "Pure ZMQ scheduler doesn't support the following flags:"
1011 msg = "Pure ZMQ scheduler doesn't support the following flags:"
1016 "follow, after, retries, targets, timeout"
1012 "follow, after, retries, targets, timeout"
1017 if (follow or after or retries or targets or timeout):
1013 if (follow or after or retries or targets or timeout):
1018 # hard fail on Scheduler flags
1014 # hard fail on Scheduler flags
1019 raise RuntimeError(msg)
1015 raise RuntimeError(msg)
1020 if isinstance(f, dependent):
1016 if isinstance(f, dependent):
1021 # soft warn on functional dependencies
1017 # soft warn on functional dependencies
1022 warnings.warn(msg, RuntimeWarning)
1018 warnings.warn(msg, RuntimeWarning)
1023
1019
1024 # build args
1020 # build args
1025 args = [] if args is None else args
1021 args = [] if args is None else args
1026 kwargs = {} if kwargs is None else kwargs
1022 kwargs = {} if kwargs is None else kwargs
1027 block = self.block if block is None else block
1023 block = self.block if block is None else block
1028 track = self.track if track is None else track
1024 track = self.track if track is None else track
1029 after = self.after if after is None else after
1025 after = self.after if after is None else after
1030 retries = self.retries if retries is None else retries
1026 retries = self.retries if retries is None else retries
1031 follow = self.follow if follow is None else follow
1027 follow = self.follow if follow is None else follow
1032 timeout = self.timeout if timeout is None else timeout
1028 timeout = self.timeout if timeout is None else timeout
1033 targets = self.targets if targets is None else targets
1029 targets = self.targets if targets is None else targets
1034
1030
1035 if not isinstance(retries, int):
1031 if not isinstance(retries, int):
1036 raise TypeError('retries must be int, not %r'%type(retries))
1032 raise TypeError('retries must be int, not %r'%type(retries))
1037
1033
1038 if targets is None:
1034 if targets is None:
1039 idents = []
1035 idents = []
1040 else:
1036 else:
1041 idents = self.client._build_targets(targets)[0]
1037 idents = self.client._build_targets(targets)[0]
1042 # ensure *not* bytes
1038 # ensure *not* bytes
1043 idents = [ ident.decode() for ident in idents ]
1039 idents = [ ident.decode() for ident in idents ]
1044
1040
1045 after = self._render_dependency(after)
1041 after = self._render_dependency(after)
1046 follow = self._render_dependency(follow)
1042 follow = self._render_dependency(follow)
1047 metadata = dict(after=after, follow=follow, timeout=timeout, targets=idents, retries=retries)
1043 metadata = dict(after=after, follow=follow, timeout=timeout, targets=idents, retries=retries)
1048
1044
1049 msg = self.client.send_apply_request(self._socket, f, args, kwargs, track=track,
1045 msg = self.client.send_apply_request(self._socket, f, args, kwargs, track=track,
1050 metadata=metadata)
1046 metadata=metadata)
1051 tracker = None if track is False else msg['tracker']
1047 tracker = None if track is False else msg['tracker']
1052
1048
1053 ar = AsyncResult(self.client, msg['header']['msg_id'], fname=getname(f),
1049 ar = AsyncResult(self.client, msg['header']['msg_id'], fname=getname(f),
1054 targets=None, tracker=tracker, owner=True,
1050 targets=None, tracker=tracker, owner=True,
1055 )
1051 )
1056 if block:
1052 if block:
1057 try:
1053 try:
1058 return ar.get()
1054 return ar.get()
1059 except KeyboardInterrupt:
1055 except KeyboardInterrupt:
1060 pass
1056 pass
1061 return ar
1057 return ar
1062
1058
1063 @sync_results
1059 @sync_results
1064 @save_ids
1060 @save_ids
1065 def map(self, f, *sequences, **kwargs):
1061 def map(self, f, *sequences, **kwargs):
1066 """``view.map(f, *sequences, block=self.block, chunksize=1, ordered=True)`` => list|AsyncMapResult
1062 """``view.map(f, *sequences, block=self.block, chunksize=1, ordered=True)`` => list|AsyncMapResult
1067
1063
1068 Parallel version of builtin `map`, load-balanced by this View.
1064 Parallel version of builtin `map`, load-balanced by this View.
1069
1065
1070 `block`, and `chunksize` can be specified by keyword only.
1066 `block`, and `chunksize` can be specified by keyword only.
1071
1067
1072 Each `chunksize` elements will be a separate task, and will be
1068 Each `chunksize` elements will be a separate task, and will be
1073 load-balanced. This lets individual elements be available for iteration
1069 load-balanced. This lets individual elements be available for iteration
1074 as soon as they arrive.
1070 as soon as they arrive.
1075
1071
1076 Parameters
1072 Parameters
1077 ----------
1073 ----------
1078
1074
1079 f : callable
1075 f : callable
1080 function to be mapped
1076 function to be mapped
1081 *sequences: one or more sequences of matching length
1077 *sequences: one or more sequences of matching length
1082 the sequences to be distributed and passed to `f`
1078 the sequences to be distributed and passed to `f`
1083 block : bool [default self.block]
1079 block : bool [default self.block]
1084 whether to wait for the result or not
1080 whether to wait for the result or not
1085 track : bool
1081 track : bool
1086 whether to create a MessageTracker to allow the user to
1082 whether to create a MessageTracker to allow the user to
1087 safely edit after arrays and buffers during non-copying
1083 safely edit after arrays and buffers during non-copying
1088 sends.
1084 sends.
1089 chunksize : int [default 1]
1085 chunksize : int [default 1]
1090 how many elements should be in each task.
1086 how many elements should be in each task.
1091 ordered : bool [default True]
1087 ordered : bool [default True]
1092 Whether the results should be gathered as they arrive, or enforce
1088 Whether the results should be gathered as they arrive, or enforce
1093 the order of submission.
1089 the order of submission.
1094
1090
1095 Only applies when iterating through AsyncMapResult as results arrive.
1091 Only applies when iterating through AsyncMapResult as results arrive.
1096 Has no effect when block=True.
1092 Has no effect when block=True.
1097
1093
1098 Returns
1094 Returns
1099 -------
1095 -------
1100
1096
1101 if block=False
1097 if block=False
1102 An :class:`~ipython_parallel.client.asyncresult.AsyncMapResult` instance.
1098 An :class:`~ipython_parallel.client.asyncresult.AsyncMapResult` instance.
1103 An object like AsyncResult, but which reassembles the sequence of results
1099 An object like AsyncResult, but which reassembles the sequence of results
1104 into a single list. AsyncMapResults can be iterated through before all
1100 into a single list. AsyncMapResults can be iterated through before all
1105 results are complete.
1101 results are complete.
1106 else
1102 else
1107 A list, the result of ``map(f,*sequences)``
1103 A list, the result of ``map(f,*sequences)``
1108 """
1104 """
1109
1105
1110 # default
1106 # default
1111 block = kwargs.get('block', self.block)
1107 block = kwargs.get('block', self.block)
1112 chunksize = kwargs.get('chunksize', 1)
1108 chunksize = kwargs.get('chunksize', 1)
1113 ordered = kwargs.get('ordered', True)
1109 ordered = kwargs.get('ordered', True)
1114
1110
1115 keyset = set(kwargs.keys())
1111 keyset = set(kwargs.keys())
1116 extra_keys = keyset.difference_update(set(['block', 'chunksize']))
1112 extra_keys = keyset.difference_update(set(['block', 'chunksize']))
1117 if extra_keys:
1113 if extra_keys:
1118 raise TypeError("Invalid kwargs: %s"%list(extra_keys))
1114 raise TypeError("Invalid kwargs: %s"%list(extra_keys))
1119
1115
1120 assert len(sequences) > 0, "must have some sequences to map onto!"
1116 assert len(sequences) > 0, "must have some sequences to map onto!"
1121
1117
1122 pf = ParallelFunction(self, f, block=block, chunksize=chunksize, ordered=ordered)
1118 pf = ParallelFunction(self, f, block=block, chunksize=chunksize, ordered=ordered)
1123 return pf.map(*sequences)
1119 return pf.map(*sequences)
1124
1120
1125 __all__ = ['LoadBalancedView', 'DirectView']
1121 __all__ = ['LoadBalancedView', 'DirectView']
@@ -1,1878 +1,1875 b''
1 # encoding: utf-8
1 # encoding: utf-8
2 """
2 """
3 A lightweight Traits like module.
3 A lightweight Traits like module.
4
4
5 This is designed to provide a lightweight, simple, pure Python version of
5 This is designed to provide a lightweight, simple, pure Python version of
6 many of the capabilities of enthought.traits. This includes:
6 many of the capabilities of enthought.traits. This includes:
7
7
8 * Validation
8 * Validation
9 * Type specification with defaults
9 * Type specification with defaults
10 * Static and dynamic notification
10 * Static and dynamic notification
11 * Basic predefined types
11 * Basic predefined types
12 * An API that is similar to enthought.traits
12 * An API that is similar to enthought.traits
13
13
14 We don't support:
14 We don't support:
15
15
16 * Delegation
16 * Delegation
17 * Automatic GUI generation
17 * Automatic GUI generation
18 * A full set of trait types. Most importantly, we don't provide container
18 * A full set of trait types. Most importantly, we don't provide container
19 traits (list, dict, tuple) that can trigger notifications if their
19 traits (list, dict, tuple) that can trigger notifications if their
20 contents change.
20 contents change.
21 * API compatibility with enthought.traits
21 * API compatibility with enthought.traits
22
22
23 There are also some important difference in our design:
23 There are also some important difference in our design:
24
24
25 * enthought.traits does not validate default values. We do.
25 * enthought.traits does not validate default values. We do.
26
26
27 We choose to create this module because we need these capabilities, but
27 We choose to create this module because we need these capabilities, but
28 we need them to be pure Python so they work in all Python implementations,
28 we need them to be pure Python so they work in all Python implementations,
29 including Jython and IronPython.
29 including Jython and IronPython.
30
30
31 Inheritance diagram:
31 Inheritance diagram:
32
32
33 .. inheritance-diagram:: IPython.utils.traitlets
33 .. inheritance-diagram:: IPython.utils.traitlets
34 :parts: 3
34 :parts: 3
35 """
35 """
36
36
37 # Copyright (c) IPython Development Team.
37 # Copyright (c) IPython Development Team.
38 # Distributed under the terms of the Modified BSD License.
38 # Distributed under the terms of the Modified BSD License.
39 #
39 #
40 # Adapted from enthought.traits, Copyright (c) Enthought, Inc.,
40 # Adapted from enthought.traits, Copyright (c) Enthought, Inc.,
41 # also under the terms of the Modified BSD License.
41 # also under the terms of the Modified BSD License.
42
42
43 import contextlib
43 import contextlib
44 import inspect
44 import inspect
45 import re
45 import re
46 import sys
46 import sys
47 import types
47 import types
48 from types import FunctionType
48 from types import FunctionType
49 try:
49 try:
50 from types import ClassType, InstanceType
50 from types import ClassType, InstanceType
51 ClassTypes = (ClassType, type)
51 ClassTypes = (ClassType, type)
52 except:
52 except:
53 ClassTypes = (type,)
53 ClassTypes = (type,)
54 from warnings import warn
54 from warnings import warn
55
55
56 from IPython.utils import py3compat
56 from IPython.utils import py3compat
57 from IPython.utils import eventful
57 from IPython.utils import eventful
58 from IPython.utils.getargspec import getargspec
58 from IPython.utils.getargspec import getargspec
59 from IPython.utils.importstring import import_item
59 from IPython.utils.importstring import import_item
60 from IPython.utils.py3compat import iteritems, string_types
60 from IPython.utils.py3compat import iteritems, string_types
61 from IPython.testing.skipdoctest import skip_doctest
62
61
63 from .sentinel import Sentinel
62 from .sentinel import Sentinel
64 SequenceTypes = (list, tuple, set, frozenset)
63 SequenceTypes = (list, tuple, set, frozenset)
65
64
66 #-----------------------------------------------------------------------------
65 #-----------------------------------------------------------------------------
67 # Basic classes
66 # Basic classes
68 #-----------------------------------------------------------------------------
67 #-----------------------------------------------------------------------------
69
68
70
69
71 NoDefaultSpecified = Sentinel('NoDefaultSpecified', __name__,
70 NoDefaultSpecified = Sentinel('NoDefaultSpecified', __name__,
72 '''
71 '''
73 Used in Traitlets to specify that no defaults are set in kwargs
72 Used in Traitlets to specify that no defaults are set in kwargs
74 '''
73 '''
75 )
74 )
76
75
77
76
78 class Undefined ( object ): pass
77 class Undefined ( object ): pass
79 Undefined = Undefined()
78 Undefined = Undefined()
80
79
81 class TraitError(Exception):
80 class TraitError(Exception):
82 pass
81 pass
83
82
84 #-----------------------------------------------------------------------------
83 #-----------------------------------------------------------------------------
85 # Utilities
84 # Utilities
86 #-----------------------------------------------------------------------------
85 #-----------------------------------------------------------------------------
87
86
88
87
89 def class_of ( object ):
88 def class_of ( object ):
90 """ Returns a string containing the class name of an object with the
89 """ Returns a string containing the class name of an object with the
91 correct indefinite article ('a' or 'an') preceding it (e.g., 'an Image',
90 correct indefinite article ('a' or 'an') preceding it (e.g., 'an Image',
92 'a PlotValue').
91 'a PlotValue').
93 """
92 """
94 if isinstance( object, py3compat.string_types ):
93 if isinstance( object, py3compat.string_types ):
95 return add_article( object )
94 return add_article( object )
96
95
97 return add_article( object.__class__.__name__ )
96 return add_article( object.__class__.__name__ )
98
97
99
98
100 def add_article ( name ):
99 def add_article ( name ):
101 """ Returns a string containing the correct indefinite article ('a' or 'an')
100 """ Returns a string containing the correct indefinite article ('a' or 'an')
102 prefixed to the specified string.
101 prefixed to the specified string.
103 """
102 """
104 if name[:1].lower() in 'aeiou':
103 if name[:1].lower() in 'aeiou':
105 return 'an ' + name
104 return 'an ' + name
106
105
107 return 'a ' + name
106 return 'a ' + name
108
107
109
108
110 def repr_type(obj):
109 def repr_type(obj):
111 """ Return a string representation of a value and its type for readable
110 """ Return a string representation of a value and its type for readable
112 error messages.
111 error messages.
113 """
112 """
114 the_type = type(obj)
113 the_type = type(obj)
115 if (not py3compat.PY3) and the_type is InstanceType:
114 if (not py3compat.PY3) and the_type is InstanceType:
116 # Old-style class.
115 # Old-style class.
117 the_type = obj.__class__
116 the_type = obj.__class__
118 msg = '%r %r' % (obj, the_type)
117 msg = '%r %r' % (obj, the_type)
119 return msg
118 return msg
120
119
121
120
122 def is_trait(t):
121 def is_trait(t):
123 """ Returns whether the given value is an instance or subclass of TraitType.
122 """ Returns whether the given value is an instance or subclass of TraitType.
124 """
123 """
125 return (isinstance(t, TraitType) or
124 return (isinstance(t, TraitType) or
126 (isinstance(t, type) and issubclass(t, TraitType)))
125 (isinstance(t, type) and issubclass(t, TraitType)))
127
126
128
127
129 def parse_notifier_name(name):
128 def parse_notifier_name(name):
130 """Convert the name argument to a list of names.
129 """Convert the name argument to a list of names.
131
130
132 Examples
131 Examples
133 --------
132 --------
134
133
135 >>> parse_notifier_name('a')
134 >>> parse_notifier_name('a')
136 ['a']
135 ['a']
137 >>> parse_notifier_name(['a','b'])
136 >>> parse_notifier_name(['a','b'])
138 ['a', 'b']
137 ['a', 'b']
139 >>> parse_notifier_name(None)
138 >>> parse_notifier_name(None)
140 ['anytrait']
139 ['anytrait']
141 """
140 """
142 if isinstance(name, string_types):
141 if isinstance(name, string_types):
143 return [name]
142 return [name]
144 elif name is None:
143 elif name is None:
145 return ['anytrait']
144 return ['anytrait']
146 elif isinstance(name, (list, tuple)):
145 elif isinstance(name, (list, tuple)):
147 for n in name:
146 for n in name:
148 assert isinstance(n, string_types), "names must be strings"
147 assert isinstance(n, string_types), "names must be strings"
149 return name
148 return name
150
149
151
150
152 class _SimpleTest:
151 class _SimpleTest:
153 def __init__ ( self, value ): self.value = value
152 def __init__ ( self, value ): self.value = value
154 def __call__ ( self, test ):
153 def __call__ ( self, test ):
155 return test == self.value
154 return test == self.value
156 def __repr__(self):
155 def __repr__(self):
157 return "<SimpleTest(%r)" % self.value
156 return "<SimpleTest(%r)" % self.value
158 def __str__(self):
157 def __str__(self):
159 return self.__repr__()
158 return self.__repr__()
160
159
161
160
162 def getmembers(object, predicate=None):
161 def getmembers(object, predicate=None):
163 """A safe version of inspect.getmembers that handles missing attributes.
162 """A safe version of inspect.getmembers that handles missing attributes.
164
163
165 This is useful when there are descriptor based attributes that for
164 This is useful when there are descriptor based attributes that for
166 some reason raise AttributeError even though they exist. This happens
165 some reason raise AttributeError even though they exist. This happens
167 in zope.inteface with the __provides__ attribute.
166 in zope.inteface with the __provides__ attribute.
168 """
167 """
169 results = []
168 results = []
170 for key in dir(object):
169 for key in dir(object):
171 try:
170 try:
172 value = getattr(object, key)
171 value = getattr(object, key)
173 except AttributeError:
172 except AttributeError:
174 pass
173 pass
175 else:
174 else:
176 if not predicate or predicate(value):
175 if not predicate or predicate(value):
177 results.append((key, value))
176 results.append((key, value))
178 results.sort()
177 results.sort()
179 return results
178 return results
180
179
181 def _validate_link(*tuples):
180 def _validate_link(*tuples):
182 """Validate arguments for traitlet link functions"""
181 """Validate arguments for traitlet link functions"""
183 for t in tuples:
182 for t in tuples:
184 if not len(t) == 2:
183 if not len(t) == 2:
185 raise TypeError("Each linked traitlet must be specified as (HasTraits, 'trait_name'), not %r" % t)
184 raise TypeError("Each linked traitlet must be specified as (HasTraits, 'trait_name'), not %r" % t)
186 obj, trait_name = t
185 obj, trait_name = t
187 if not isinstance(obj, HasTraits):
186 if not isinstance(obj, HasTraits):
188 raise TypeError("Each object must be HasTraits, not %r" % type(obj))
187 raise TypeError("Each object must be HasTraits, not %r" % type(obj))
189 if not trait_name in obj.traits():
188 if not trait_name in obj.traits():
190 raise TypeError("%r has no trait %r" % (obj, trait_name))
189 raise TypeError("%r has no trait %r" % (obj, trait_name))
191
190
192 @skip_doctest
193 class link(object):
191 class link(object):
194 """Link traits from different objects together so they remain in sync.
192 """Link traits from different objects together so they remain in sync.
195
193
196 Parameters
194 Parameters
197 ----------
195 ----------
198 *args : pairs of objects/attributes
196 *args : pairs of objects/attributes
199
197
200 Examples
198 Examples
201 --------
199 --------
202
200
203 >>> c = link((obj1, 'value'), (obj2, 'value'), (obj3, 'value'))
201 >>> c = link((obj1, 'value'), (obj2, 'value'), (obj3, 'value'))
204 >>> obj1.value = 5 # updates other objects as well
202 >>> obj1.value = 5 # updates other objects as well
205 """
203 """
206 updating = False
204 updating = False
207 def __init__(self, *args):
205 def __init__(self, *args):
208 if len(args) < 2:
206 if len(args) < 2:
209 raise TypeError('At least two traitlets must be provided.')
207 raise TypeError('At least two traitlets must be provided.')
210 _validate_link(*args)
208 _validate_link(*args)
211
209
212 self.objects = {}
210 self.objects = {}
213
211
214 initial = getattr(args[0][0], args[0][1])
212 initial = getattr(args[0][0], args[0][1])
215 for obj, attr in args:
213 for obj, attr in args:
216 setattr(obj, attr, initial)
214 setattr(obj, attr, initial)
217
215
218 callback = self._make_closure(obj, attr)
216 callback = self._make_closure(obj, attr)
219 obj.on_trait_change(callback, attr)
217 obj.on_trait_change(callback, attr)
220 self.objects[(obj, attr)] = callback
218 self.objects[(obj, attr)] = callback
221
219
222 @contextlib.contextmanager
220 @contextlib.contextmanager
223 def _busy_updating(self):
221 def _busy_updating(self):
224 self.updating = True
222 self.updating = True
225 try:
223 try:
226 yield
224 yield
227 finally:
225 finally:
228 self.updating = False
226 self.updating = False
229
227
230 def _make_closure(self, sending_obj, sending_attr):
228 def _make_closure(self, sending_obj, sending_attr):
231 def update(name, old, new):
229 def update(name, old, new):
232 self._update(sending_obj, sending_attr, new)
230 self._update(sending_obj, sending_attr, new)
233 return update
231 return update
234
232
235 def _update(self, sending_obj, sending_attr, new):
233 def _update(self, sending_obj, sending_attr, new):
236 if self.updating:
234 if self.updating:
237 return
235 return
238 with self._busy_updating():
236 with self._busy_updating():
239 for obj, attr in self.objects.keys():
237 for obj, attr in self.objects.keys():
240 setattr(obj, attr, new)
238 setattr(obj, attr, new)
241
239
242 def unlink(self):
240 def unlink(self):
243 for key, callback in self.objects.items():
241 for key, callback in self.objects.items():
244 (obj, attr) = key
242 (obj, attr) = key
245 obj.on_trait_change(callback, attr, remove=True)
243 obj.on_trait_change(callback, attr, remove=True)
246
244
247 @skip_doctest
248 class directional_link(object):
245 class directional_link(object):
249 """Link the trait of a source object with traits of target objects.
246 """Link the trait of a source object with traits of target objects.
250
247
251 Parameters
248 Parameters
252 ----------
249 ----------
253 source : pair of object, name
250 source : pair of object, name
254 targets : pairs of objects/attributes
251 targets : pairs of objects/attributes
255
252
256 Examples
253 Examples
257 --------
254 --------
258
255
259 >>> c = directional_link((src, 'value'), (tgt1, 'value'), (tgt2, 'value'))
256 >>> c = directional_link((src, 'value'), (tgt1, 'value'), (tgt2, 'value'))
260 >>> src.value = 5 # updates target objects
257 >>> src.value = 5 # updates target objects
261 >>> tgt1.value = 6 # does not update other objects
258 >>> tgt1.value = 6 # does not update other objects
262 """
259 """
263 updating = False
260 updating = False
264
261
265 def __init__(self, source, *targets):
262 def __init__(self, source, *targets):
266 if len(targets) < 1:
263 if len(targets) < 1:
267 raise TypeError('At least two traitlets must be provided.')
264 raise TypeError('At least two traitlets must be provided.')
268 _validate_link(source, *targets)
265 _validate_link(source, *targets)
269 self.source = source
266 self.source = source
270 self.targets = targets
267 self.targets = targets
271
268
272 # Update current value
269 # Update current value
273 src_attr_value = getattr(source[0], source[1])
270 src_attr_value = getattr(source[0], source[1])
274 for obj, attr in targets:
271 for obj, attr in targets:
275 setattr(obj, attr, src_attr_value)
272 setattr(obj, attr, src_attr_value)
276
273
277 # Wire
274 # Wire
278 self.source[0].on_trait_change(self._update, self.source[1])
275 self.source[0].on_trait_change(self._update, self.source[1])
279
276
280 @contextlib.contextmanager
277 @contextlib.contextmanager
281 def _busy_updating(self):
278 def _busy_updating(self):
282 self.updating = True
279 self.updating = True
283 try:
280 try:
284 yield
281 yield
285 finally:
282 finally:
286 self.updating = False
283 self.updating = False
287
284
288 def _update(self, name, old, new):
285 def _update(self, name, old, new):
289 if self.updating:
286 if self.updating:
290 return
287 return
291 with self._busy_updating():
288 with self._busy_updating():
292 for obj, attr in self.targets:
289 for obj, attr in self.targets:
293 setattr(obj, attr, new)
290 setattr(obj, attr, new)
294
291
295 def unlink(self):
292 def unlink(self):
296 self.source[0].on_trait_change(self._update, self.source[1], remove=True)
293 self.source[0].on_trait_change(self._update, self.source[1], remove=True)
297 self.source = None
294 self.source = None
298 self.targets = []
295 self.targets = []
299
296
300 dlink = directional_link
297 dlink = directional_link
301
298
302
299
303 #-----------------------------------------------------------------------------
300 #-----------------------------------------------------------------------------
304 # Base TraitType for all traits
301 # Base TraitType for all traits
305 #-----------------------------------------------------------------------------
302 #-----------------------------------------------------------------------------
306
303
307
304
308 class TraitType(object):
305 class TraitType(object):
309 """A base class for all trait descriptors.
306 """A base class for all trait descriptors.
310
307
311 Notes
308 Notes
312 -----
309 -----
313 Our implementation of traits is based on Python's descriptor
310 Our implementation of traits is based on Python's descriptor
314 prototol. This class is the base class for all such descriptors. The
311 prototol. This class is the base class for all such descriptors. The
315 only magic we use is a custom metaclass for the main :class:`HasTraits`
312 only magic we use is a custom metaclass for the main :class:`HasTraits`
316 class that does the following:
313 class that does the following:
317
314
318 1. Sets the :attr:`name` attribute of every :class:`TraitType`
315 1. Sets the :attr:`name` attribute of every :class:`TraitType`
319 instance in the class dict to the name of the attribute.
316 instance in the class dict to the name of the attribute.
320 2. Sets the :attr:`this_class` attribute of every :class:`TraitType`
317 2. Sets the :attr:`this_class` attribute of every :class:`TraitType`
321 instance in the class dict to the *class* that declared the trait.
318 instance in the class dict to the *class* that declared the trait.
322 This is used by the :class:`This` trait to allow subclasses to
319 This is used by the :class:`This` trait to allow subclasses to
323 accept superclasses for :class:`This` values.
320 accept superclasses for :class:`This` values.
324 """
321 """
325
322
326 metadata = {}
323 metadata = {}
327 default_value = Undefined
324 default_value = Undefined
328 allow_none = False
325 allow_none = False
329 info_text = 'any value'
326 info_text = 'any value'
330
327
331 def __init__(self, default_value=NoDefaultSpecified, allow_none=None, **metadata):
328 def __init__(self, default_value=NoDefaultSpecified, allow_none=None, **metadata):
332 """Create a TraitType.
329 """Create a TraitType.
333 """
330 """
334 if default_value is not NoDefaultSpecified:
331 if default_value is not NoDefaultSpecified:
335 self.default_value = default_value
332 self.default_value = default_value
336 if allow_none is not None:
333 if allow_none is not None:
337 self.allow_none = allow_none
334 self.allow_none = allow_none
338
335
339 if 'default' in metadata:
336 if 'default' in metadata:
340 # Warn the user that they probably meant default_value.
337 # Warn the user that they probably meant default_value.
341 warn(
338 warn(
342 "Parameter 'default' passed to TraitType. "
339 "Parameter 'default' passed to TraitType. "
343 "Did you mean 'default_value'?"
340 "Did you mean 'default_value'?"
344 )
341 )
345
342
346 if len(metadata) > 0:
343 if len(metadata) > 0:
347 if len(self.metadata) > 0:
344 if len(self.metadata) > 0:
348 self._metadata = self.metadata.copy()
345 self._metadata = self.metadata.copy()
349 self._metadata.update(metadata)
346 self._metadata.update(metadata)
350 else:
347 else:
351 self._metadata = metadata
348 self._metadata = metadata
352 else:
349 else:
353 self._metadata = self.metadata
350 self._metadata = self.metadata
354
351
355 self.init()
352 self.init()
356
353
357 def init(self):
354 def init(self):
358 pass
355 pass
359
356
360 def get_default_value(self):
357 def get_default_value(self):
361 """Create a new instance of the default value."""
358 """Create a new instance of the default value."""
362 return self.default_value
359 return self.default_value
363
360
364 def instance_init(self):
361 def instance_init(self):
365 """Part of the initialization which may depends on the underlying
362 """Part of the initialization which may depends on the underlying
366 HasTraits instance.
363 HasTraits instance.
367
364
368 It is typically overloaded for specific trait types.
365 It is typically overloaded for specific trait types.
369
366
370 This method is called by :meth:`HasTraits.__new__` and in the
367 This method is called by :meth:`HasTraits.__new__` and in the
371 :meth:`TraitType.instance_init` method of trait types holding
368 :meth:`TraitType.instance_init` method of trait types holding
372 other trait types.
369 other trait types.
373 """
370 """
374 pass
371 pass
375
372
376 def init_default_value(self, obj):
373 def init_default_value(self, obj):
377 """Instantiate the default value for the trait type.
374 """Instantiate the default value for the trait type.
378
375
379 This method is called by :meth:`TraitType.set_default_value` in the
376 This method is called by :meth:`TraitType.set_default_value` in the
380 case a default value is provided at construction time or later when
377 case a default value is provided at construction time or later when
381 accessing the trait value for the first time in
378 accessing the trait value for the first time in
382 :meth:`HasTraits.__get__`.
379 :meth:`HasTraits.__get__`.
383 """
380 """
384 value = self.get_default_value()
381 value = self.get_default_value()
385 value = self._validate(obj, value)
382 value = self._validate(obj, value)
386 obj._trait_values[self.name] = value
383 obj._trait_values[self.name] = value
387 return value
384 return value
388
385
389 def set_default_value(self, obj):
386 def set_default_value(self, obj):
390 """Set the default value on a per instance basis.
387 """Set the default value on a per instance basis.
391
388
392 This method is called by :meth:`HasTraits.__new__` to instantiate and
389 This method is called by :meth:`HasTraits.__new__` to instantiate and
393 validate the default value. The creation and validation of
390 validate the default value. The creation and validation of
394 default values must be delayed until the parent :class:`HasTraits`
391 default values must be delayed until the parent :class:`HasTraits`
395 class has been instantiated.
392 class has been instantiated.
396 Parameters
393 Parameters
397 ----------
394 ----------
398 obj : :class:`HasTraits` instance
395 obj : :class:`HasTraits` instance
399 The parent :class:`HasTraits` instance that has just been
396 The parent :class:`HasTraits` instance that has just been
400 created.
397 created.
401 """
398 """
402 # Check for a deferred initializer defined in the same class as the
399 # Check for a deferred initializer defined in the same class as the
403 # trait declaration or above.
400 # trait declaration or above.
404 mro = type(obj).mro()
401 mro = type(obj).mro()
405 meth_name = '_%s_default' % self.name
402 meth_name = '_%s_default' % self.name
406 for cls in mro[:mro.index(self.this_class)+1]:
403 for cls in mro[:mro.index(self.this_class)+1]:
407 if meth_name in cls.__dict__:
404 if meth_name in cls.__dict__:
408 break
405 break
409 else:
406 else:
410 # We didn't find one. Do static initialization.
407 # We didn't find one. Do static initialization.
411 self.init_default_value(obj)
408 self.init_default_value(obj)
412 return
409 return
413 # Complete the dynamic initialization.
410 # Complete the dynamic initialization.
414 obj._trait_dyn_inits[self.name] = meth_name
411 obj._trait_dyn_inits[self.name] = meth_name
415
412
416 def __get__(self, obj, cls=None):
413 def __get__(self, obj, cls=None):
417 """Get the value of the trait by self.name for the instance.
414 """Get the value of the trait by self.name for the instance.
418
415
419 Default values are instantiated when :meth:`HasTraits.__new__`
416 Default values are instantiated when :meth:`HasTraits.__new__`
420 is called. Thus by the time this method gets called either the
417 is called. Thus by the time this method gets called either the
421 default value or a user defined value (they called :meth:`__set__`)
418 default value or a user defined value (they called :meth:`__set__`)
422 is in the :class:`HasTraits` instance.
419 is in the :class:`HasTraits` instance.
423 """
420 """
424 if obj is None:
421 if obj is None:
425 return self
422 return self
426 else:
423 else:
427 try:
424 try:
428 value = obj._trait_values[self.name]
425 value = obj._trait_values[self.name]
429 except KeyError:
426 except KeyError:
430 # Check for a dynamic initializer.
427 # Check for a dynamic initializer.
431 if self.name in obj._trait_dyn_inits:
428 if self.name in obj._trait_dyn_inits:
432 method = getattr(obj, obj._trait_dyn_inits[self.name])
429 method = getattr(obj, obj._trait_dyn_inits[self.name])
433 value = method()
430 value = method()
434 # FIXME: Do we really validate here?
431 # FIXME: Do we really validate here?
435 value = self._validate(obj, value)
432 value = self._validate(obj, value)
436 obj._trait_values[self.name] = value
433 obj._trait_values[self.name] = value
437 return value
434 return value
438 else:
435 else:
439 return self.init_default_value(obj)
436 return self.init_default_value(obj)
440 except Exception:
437 except Exception:
441 # HasTraits should call set_default_value to populate
438 # HasTraits should call set_default_value to populate
442 # this. So this should never be reached.
439 # this. So this should never be reached.
443 raise TraitError('Unexpected error in TraitType: '
440 raise TraitError('Unexpected error in TraitType: '
444 'default value not set properly')
441 'default value not set properly')
445 else:
442 else:
446 return value
443 return value
447
444
448 def __set__(self, obj, value):
445 def __set__(self, obj, value):
449 new_value = self._validate(obj, value)
446 new_value = self._validate(obj, value)
450 try:
447 try:
451 old_value = obj._trait_values[self.name]
448 old_value = obj._trait_values[self.name]
452 except KeyError:
449 except KeyError:
453 old_value = Undefined
450 old_value = Undefined
454
451
455 obj._trait_values[self.name] = new_value
452 obj._trait_values[self.name] = new_value
456 try:
453 try:
457 silent = bool(old_value == new_value)
454 silent = bool(old_value == new_value)
458 except:
455 except:
459 # if there is an error in comparing, default to notify
456 # if there is an error in comparing, default to notify
460 silent = False
457 silent = False
461 if silent is not True:
458 if silent is not True:
462 # we explicitly compare silent to True just in case the equality
459 # we explicitly compare silent to True just in case the equality
463 # comparison above returns something other than True/False
460 # comparison above returns something other than True/False
464 obj._notify_trait(self.name, old_value, new_value)
461 obj._notify_trait(self.name, old_value, new_value)
465
462
466 def _validate(self, obj, value):
463 def _validate(self, obj, value):
467 if value is None and self.allow_none:
464 if value is None and self.allow_none:
468 return value
465 return value
469 if hasattr(self, 'validate'):
466 if hasattr(self, 'validate'):
470 value = self.validate(obj, value)
467 value = self.validate(obj, value)
471 if obj._cross_validation_lock is False:
468 if obj._cross_validation_lock is False:
472 value = self._cross_validate(obj, value)
469 value = self._cross_validate(obj, value)
473 return value
470 return value
474
471
475 def _cross_validate(self, obj, value):
472 def _cross_validate(self, obj, value):
476 if hasattr(obj, '_%s_validate' % self.name):
473 if hasattr(obj, '_%s_validate' % self.name):
477 cross_validate = getattr(obj, '_%s_validate' % self.name)
474 cross_validate = getattr(obj, '_%s_validate' % self.name)
478 value = cross_validate(value, self)
475 value = cross_validate(value, self)
479 return value
476 return value
480
477
481 def __or__(self, other):
478 def __or__(self, other):
482 if isinstance(other, Union):
479 if isinstance(other, Union):
483 return Union([self] + other.trait_types)
480 return Union([self] + other.trait_types)
484 else:
481 else:
485 return Union([self, other])
482 return Union([self, other])
486
483
487 def info(self):
484 def info(self):
488 return self.info_text
485 return self.info_text
489
486
490 def error(self, obj, value):
487 def error(self, obj, value):
491 if obj is not None:
488 if obj is not None:
492 e = "The '%s' trait of %s instance must be %s, but a value of %s was specified." \
489 e = "The '%s' trait of %s instance must be %s, but a value of %s was specified." \
493 % (self.name, class_of(obj),
490 % (self.name, class_of(obj),
494 self.info(), repr_type(value))
491 self.info(), repr_type(value))
495 else:
492 else:
496 e = "The '%s' trait must be %s, but a value of %r was specified." \
493 e = "The '%s' trait must be %s, but a value of %r was specified." \
497 % (self.name, self.info(), repr_type(value))
494 % (self.name, self.info(), repr_type(value))
498 raise TraitError(e)
495 raise TraitError(e)
499
496
500 def get_metadata(self, key, default=None):
497 def get_metadata(self, key, default=None):
501 return getattr(self, '_metadata', {}).get(key, default)
498 return getattr(self, '_metadata', {}).get(key, default)
502
499
503 def set_metadata(self, key, value):
500 def set_metadata(self, key, value):
504 getattr(self, '_metadata', {})[key] = value
501 getattr(self, '_metadata', {})[key] = value
505
502
506
503
507 #-----------------------------------------------------------------------------
504 #-----------------------------------------------------------------------------
508 # The HasTraits implementation
505 # The HasTraits implementation
509 #-----------------------------------------------------------------------------
506 #-----------------------------------------------------------------------------
510
507
511
508
512 class MetaHasTraits(type):
509 class MetaHasTraits(type):
513 """A metaclass for HasTraits.
510 """A metaclass for HasTraits.
514
511
515 This metaclass makes sure that any TraitType class attributes are
512 This metaclass makes sure that any TraitType class attributes are
516 instantiated and sets their name attribute.
513 instantiated and sets their name attribute.
517 """
514 """
518
515
519 def __new__(mcls, name, bases, classdict):
516 def __new__(mcls, name, bases, classdict):
520 """Create the HasTraits class.
517 """Create the HasTraits class.
521
518
522 This instantiates all TraitTypes in the class dict and sets their
519 This instantiates all TraitTypes in the class dict and sets their
523 :attr:`name` attribute.
520 :attr:`name` attribute.
524 """
521 """
525 # print "MetaHasTraitlets (mcls, name): ", mcls, name
522 # print "MetaHasTraitlets (mcls, name): ", mcls, name
526 # print "MetaHasTraitlets (bases): ", bases
523 # print "MetaHasTraitlets (bases): ", bases
527 # print "MetaHasTraitlets (classdict): ", classdict
524 # print "MetaHasTraitlets (classdict): ", classdict
528 for k,v in iteritems(classdict):
525 for k,v in iteritems(classdict):
529 if isinstance(v, TraitType):
526 if isinstance(v, TraitType):
530 v.name = k
527 v.name = k
531 elif inspect.isclass(v):
528 elif inspect.isclass(v):
532 if issubclass(v, TraitType):
529 if issubclass(v, TraitType):
533 vinst = v()
530 vinst = v()
534 vinst.name = k
531 vinst.name = k
535 classdict[k] = vinst
532 classdict[k] = vinst
536 return super(MetaHasTraits, mcls).__new__(mcls, name, bases, classdict)
533 return super(MetaHasTraits, mcls).__new__(mcls, name, bases, classdict)
537
534
538 def __init__(cls, name, bases, classdict):
535 def __init__(cls, name, bases, classdict):
539 """Finish initializing the HasTraits class.
536 """Finish initializing the HasTraits class.
540
537
541 This sets the :attr:`this_class` attribute of each TraitType in the
538 This sets the :attr:`this_class` attribute of each TraitType in the
542 class dict to the newly created class ``cls``.
539 class dict to the newly created class ``cls``.
543 """
540 """
544 for k, v in iteritems(classdict):
541 for k, v in iteritems(classdict):
545 if isinstance(v, TraitType):
542 if isinstance(v, TraitType):
546 v.this_class = cls
543 v.this_class = cls
547 super(MetaHasTraits, cls).__init__(name, bases, classdict)
544 super(MetaHasTraits, cls).__init__(name, bases, classdict)
548
545
549
546
550 class HasTraits(py3compat.with_metaclass(MetaHasTraits, object)):
547 class HasTraits(py3compat.with_metaclass(MetaHasTraits, object)):
551
548
552 def __new__(cls, *args, **kw):
549 def __new__(cls, *args, **kw):
553 # This is needed because object.__new__ only accepts
550 # This is needed because object.__new__ only accepts
554 # the cls argument.
551 # the cls argument.
555 new_meth = super(HasTraits, cls).__new__
552 new_meth = super(HasTraits, cls).__new__
556 if new_meth is object.__new__:
553 if new_meth is object.__new__:
557 inst = new_meth(cls)
554 inst = new_meth(cls)
558 else:
555 else:
559 inst = new_meth(cls, **kw)
556 inst = new_meth(cls, **kw)
560 inst._trait_values = {}
557 inst._trait_values = {}
561 inst._trait_notifiers = {}
558 inst._trait_notifiers = {}
562 inst._trait_dyn_inits = {}
559 inst._trait_dyn_inits = {}
563 inst._cross_validation_lock = True
560 inst._cross_validation_lock = True
564 # Here we tell all the TraitType instances to set their default
561 # Here we tell all the TraitType instances to set their default
565 # values on the instance.
562 # values on the instance.
566 for key in dir(cls):
563 for key in dir(cls):
567 # Some descriptors raise AttributeError like zope.interface's
564 # Some descriptors raise AttributeError like zope.interface's
568 # __provides__ attributes even though they exist. This causes
565 # __provides__ attributes even though they exist. This causes
569 # AttributeErrors even though they are listed in dir(cls).
566 # AttributeErrors even though they are listed in dir(cls).
570 try:
567 try:
571 value = getattr(cls, key)
568 value = getattr(cls, key)
572 except AttributeError:
569 except AttributeError:
573 pass
570 pass
574 else:
571 else:
575 if isinstance(value, TraitType):
572 if isinstance(value, TraitType):
576 value.instance_init()
573 value.instance_init()
577 if key not in kw:
574 if key not in kw:
578 value.set_default_value(inst)
575 value.set_default_value(inst)
579 inst._cross_validation_lock = False
576 inst._cross_validation_lock = False
580 return inst
577 return inst
581
578
582 def __init__(self, *args, **kw):
579 def __init__(self, *args, **kw):
583 # Allow trait values to be set using keyword arguments.
580 # Allow trait values to be set using keyword arguments.
584 # We need to use setattr for this to trigger validation and
581 # We need to use setattr for this to trigger validation and
585 # notifications.
582 # notifications.
586 with self.hold_trait_notifications():
583 with self.hold_trait_notifications():
587 for key, value in iteritems(kw):
584 for key, value in iteritems(kw):
588 setattr(self, key, value)
585 setattr(self, key, value)
589
586
590 @contextlib.contextmanager
587 @contextlib.contextmanager
591 def hold_trait_notifications(self):
588 def hold_trait_notifications(self):
592 """Context manager for bundling trait change notifications and cross
589 """Context manager for bundling trait change notifications and cross
593 validation.
590 validation.
594
591
595 Use this when doing multiple trait assignments (init, config), to avoid
592 Use this when doing multiple trait assignments (init, config), to avoid
596 race conditions in trait notifiers requesting other trait values.
593 race conditions in trait notifiers requesting other trait values.
597 All trait notifications will fire after all values have been assigned.
594 All trait notifications will fire after all values have been assigned.
598 """
595 """
599 if self._cross_validation_lock is True:
596 if self._cross_validation_lock is True:
600 yield
597 yield
601 return
598 return
602 else:
599 else:
603 self._cross_validation_lock = True
600 self._cross_validation_lock = True
604 cache = {}
601 cache = {}
605 notifications = {}
602 notifications = {}
606 _notify_trait = self._notify_trait
603 _notify_trait = self._notify_trait
607
604
608 def cache_values(*a):
605 def cache_values(*a):
609 cache[a[0]] = a
606 cache[a[0]] = a
610
607
611 def hold_notifications(*a):
608 def hold_notifications(*a):
612 notifications[a[0]] = a
609 notifications[a[0]] = a
613
610
614 self._notify_trait = cache_values
611 self._notify_trait = cache_values
615
612
616 try:
613 try:
617 yield
614 yield
618 finally:
615 finally:
619 try:
616 try:
620 self._notify_trait = hold_notifications
617 self._notify_trait = hold_notifications
621 for name in cache:
618 for name in cache:
622 if hasattr(self, '_%s_validate' % name):
619 if hasattr(self, '_%s_validate' % name):
623 cross_validate = getattr(self, '_%s_validate' % name)
620 cross_validate = getattr(self, '_%s_validate' % name)
624 setattr(self, name, cross_validate(getattr(self, name), self))
621 setattr(self, name, cross_validate(getattr(self, name), self))
625 except TraitError as e:
622 except TraitError as e:
626 self._notify_trait = lambda *x: None
623 self._notify_trait = lambda *x: None
627 for name in cache:
624 for name in cache:
628 if cache[name][1] is not Undefined:
625 if cache[name][1] is not Undefined:
629 setattr(self, name, cache[name][1])
626 setattr(self, name, cache[name][1])
630 else:
627 else:
631 delattr(self, name)
628 delattr(self, name)
632 cache = {}
629 cache = {}
633 notifications = {}
630 notifications = {}
634 raise e
631 raise e
635 finally:
632 finally:
636 self._notify_trait = _notify_trait
633 self._notify_trait = _notify_trait
637 self._cross_validation_lock = False
634 self._cross_validation_lock = False
638 if isinstance(_notify_trait, types.MethodType):
635 if isinstance(_notify_trait, types.MethodType):
639 # FIXME: remove when support is bumped to 3.4.
636 # FIXME: remove when support is bumped to 3.4.
640 # when original method is restored,
637 # when original method is restored,
641 # remove the redundant value from __dict__
638 # remove the redundant value from __dict__
642 # (only used to preserve pickleability on Python < 3.4)
639 # (only used to preserve pickleability on Python < 3.4)
643 self.__dict__.pop('_notify_trait', None)
640 self.__dict__.pop('_notify_trait', None)
644 # trigger delayed notifications
641 # trigger delayed notifications
645 for v in dict(cache, **notifications).values():
642 for v in dict(cache, **notifications).values():
646 self._notify_trait(*v)
643 self._notify_trait(*v)
647
644
648 def _notify_trait(self, name, old_value, new_value):
645 def _notify_trait(self, name, old_value, new_value):
649
646
650 # First dynamic ones
647 # First dynamic ones
651 callables = []
648 callables = []
652 callables.extend(self._trait_notifiers.get(name,[]))
649 callables.extend(self._trait_notifiers.get(name,[]))
653 callables.extend(self._trait_notifiers.get('anytrait',[]))
650 callables.extend(self._trait_notifiers.get('anytrait',[]))
654
651
655 # Now static ones
652 # Now static ones
656 try:
653 try:
657 cb = getattr(self, '_%s_changed' % name)
654 cb = getattr(self, '_%s_changed' % name)
658 except:
655 except:
659 pass
656 pass
660 else:
657 else:
661 callables.append(cb)
658 callables.append(cb)
662
659
663 # Call them all now
660 # Call them all now
664 for c in callables:
661 for c in callables:
665 # Traits catches and logs errors here. I allow them to raise
662 # Traits catches and logs errors here. I allow them to raise
666 if callable(c):
663 if callable(c):
667 argspec = getargspec(c)
664 argspec = getargspec(c)
668
665
669 nargs = len(argspec[0])
666 nargs = len(argspec[0])
670 # Bound methods have an additional 'self' argument
667 # Bound methods have an additional 'self' argument
671 # I don't know how to treat unbound methods, but they
668 # I don't know how to treat unbound methods, but they
672 # can't really be used for callbacks.
669 # can't really be used for callbacks.
673 if isinstance(c, types.MethodType):
670 if isinstance(c, types.MethodType):
674 offset = -1
671 offset = -1
675 else:
672 else:
676 offset = 0
673 offset = 0
677 if nargs + offset == 0:
674 if nargs + offset == 0:
678 c()
675 c()
679 elif nargs + offset == 1:
676 elif nargs + offset == 1:
680 c(name)
677 c(name)
681 elif nargs + offset == 2:
678 elif nargs + offset == 2:
682 c(name, new_value)
679 c(name, new_value)
683 elif nargs + offset == 3:
680 elif nargs + offset == 3:
684 c(name, old_value, new_value)
681 c(name, old_value, new_value)
685 else:
682 else:
686 raise TraitError('a trait changed callback '
683 raise TraitError('a trait changed callback '
687 'must have 0-3 arguments.')
684 'must have 0-3 arguments.')
688 else:
685 else:
689 raise TraitError('a trait changed callback '
686 raise TraitError('a trait changed callback '
690 'must be callable.')
687 'must be callable.')
691
688
692
689
693 def _add_notifiers(self, handler, name):
690 def _add_notifiers(self, handler, name):
694 if name not in self._trait_notifiers:
691 if name not in self._trait_notifiers:
695 nlist = []
692 nlist = []
696 self._trait_notifiers[name] = nlist
693 self._trait_notifiers[name] = nlist
697 else:
694 else:
698 nlist = self._trait_notifiers[name]
695 nlist = self._trait_notifiers[name]
699 if handler not in nlist:
696 if handler not in nlist:
700 nlist.append(handler)
697 nlist.append(handler)
701
698
702 def _remove_notifiers(self, handler, name):
699 def _remove_notifiers(self, handler, name):
703 if name in self._trait_notifiers:
700 if name in self._trait_notifiers:
704 nlist = self._trait_notifiers[name]
701 nlist = self._trait_notifiers[name]
705 try:
702 try:
706 index = nlist.index(handler)
703 index = nlist.index(handler)
707 except ValueError:
704 except ValueError:
708 pass
705 pass
709 else:
706 else:
710 del nlist[index]
707 del nlist[index]
711
708
712 def on_trait_change(self, handler, name=None, remove=False):
709 def on_trait_change(self, handler, name=None, remove=False):
713 """Setup a handler to be called when a trait changes.
710 """Setup a handler to be called when a trait changes.
714
711
715 This is used to setup dynamic notifications of trait changes.
712 This is used to setup dynamic notifications of trait changes.
716
713
717 Static handlers can be created by creating methods on a HasTraits
714 Static handlers can be created by creating methods on a HasTraits
718 subclass with the naming convention '_[traitname]_changed'. Thus,
715 subclass with the naming convention '_[traitname]_changed'. Thus,
719 to create static handler for the trait 'a', create the method
716 to create static handler for the trait 'a', create the method
720 _a_changed(self, name, old, new) (fewer arguments can be used, see
717 _a_changed(self, name, old, new) (fewer arguments can be used, see
721 below).
718 below).
722
719
723 Parameters
720 Parameters
724 ----------
721 ----------
725 handler : callable
722 handler : callable
726 A callable that is called when a trait changes. Its
723 A callable that is called when a trait changes. Its
727 signature can be handler(), handler(name), handler(name, new)
724 signature can be handler(), handler(name), handler(name, new)
728 or handler(name, old, new).
725 or handler(name, old, new).
729 name : list, str, None
726 name : list, str, None
730 If None, the handler will apply to all traits. If a list
727 If None, the handler will apply to all traits. If a list
731 of str, handler will apply to all names in the list. If a
728 of str, handler will apply to all names in the list. If a
732 str, the handler will apply just to that name.
729 str, the handler will apply just to that name.
733 remove : bool
730 remove : bool
734 If False (the default), then install the handler. If True
731 If False (the default), then install the handler. If True
735 then unintall it.
732 then unintall it.
736 """
733 """
737 if remove:
734 if remove:
738 names = parse_notifier_name(name)
735 names = parse_notifier_name(name)
739 for n in names:
736 for n in names:
740 self._remove_notifiers(handler, n)
737 self._remove_notifiers(handler, n)
741 else:
738 else:
742 names = parse_notifier_name(name)
739 names = parse_notifier_name(name)
743 for n in names:
740 for n in names:
744 self._add_notifiers(handler, n)
741 self._add_notifiers(handler, n)
745
742
746 @classmethod
743 @classmethod
747 def class_trait_names(cls, **metadata):
744 def class_trait_names(cls, **metadata):
748 """Get a list of all the names of this class' traits.
745 """Get a list of all the names of this class' traits.
749
746
750 This method is just like the :meth:`trait_names` method,
747 This method is just like the :meth:`trait_names` method,
751 but is unbound.
748 but is unbound.
752 """
749 """
753 return cls.class_traits(**metadata).keys()
750 return cls.class_traits(**metadata).keys()
754
751
755 @classmethod
752 @classmethod
756 def class_traits(cls, **metadata):
753 def class_traits(cls, **metadata):
757 """Get a `dict` of all the traits of this class. The dictionary
754 """Get a `dict` of all the traits of this class. The dictionary
758 is keyed on the name and the values are the TraitType objects.
755 is keyed on the name and the values are the TraitType objects.
759
756
760 This method is just like the :meth:`traits` method, but is unbound.
757 This method is just like the :meth:`traits` method, but is unbound.
761
758
762 The TraitTypes returned don't know anything about the values
759 The TraitTypes returned don't know anything about the values
763 that the various HasTrait's instances are holding.
760 that the various HasTrait's instances are holding.
764
761
765 The metadata kwargs allow functions to be passed in which
762 The metadata kwargs allow functions to be passed in which
766 filter traits based on metadata values. The functions should
763 filter traits based on metadata values. The functions should
767 take a single value as an argument and return a boolean. If
764 take a single value as an argument and return a boolean. If
768 any function returns False, then the trait is not included in
765 any function returns False, then the trait is not included in
769 the output. This does not allow for any simple way of
766 the output. This does not allow for any simple way of
770 testing that a metadata name exists and has any
767 testing that a metadata name exists and has any
771 value because get_metadata returns None if a metadata key
768 value because get_metadata returns None if a metadata key
772 doesn't exist.
769 doesn't exist.
773 """
770 """
774 traits = dict([memb for memb in getmembers(cls) if
771 traits = dict([memb for memb in getmembers(cls) if
775 isinstance(memb[1], TraitType)])
772 isinstance(memb[1], TraitType)])
776
773
777 if len(metadata) == 0:
774 if len(metadata) == 0:
778 return traits
775 return traits
779
776
780 for meta_name, meta_eval in metadata.items():
777 for meta_name, meta_eval in metadata.items():
781 if type(meta_eval) is not FunctionType:
778 if type(meta_eval) is not FunctionType:
782 metadata[meta_name] = _SimpleTest(meta_eval)
779 metadata[meta_name] = _SimpleTest(meta_eval)
783
780
784 result = {}
781 result = {}
785 for name, trait in traits.items():
782 for name, trait in traits.items():
786 for meta_name, meta_eval in metadata.items():
783 for meta_name, meta_eval in metadata.items():
787 if not meta_eval(trait.get_metadata(meta_name)):
784 if not meta_eval(trait.get_metadata(meta_name)):
788 break
785 break
789 else:
786 else:
790 result[name] = trait
787 result[name] = trait
791
788
792 return result
789 return result
793
790
794 def trait_names(self, **metadata):
791 def trait_names(self, **metadata):
795 """Get a list of all the names of this class' traits."""
792 """Get a list of all the names of this class' traits."""
796 return self.traits(**metadata).keys()
793 return self.traits(**metadata).keys()
797
794
798 def traits(self, **metadata):
795 def traits(self, **metadata):
799 """Get a `dict` of all the traits of this class. The dictionary
796 """Get a `dict` of all the traits of this class. The dictionary
800 is keyed on the name and the values are the TraitType objects.
797 is keyed on the name and the values are the TraitType objects.
801
798
802 The TraitTypes returned don't know anything about the values
799 The TraitTypes returned don't know anything about the values
803 that the various HasTrait's instances are holding.
800 that the various HasTrait's instances are holding.
804
801
805 The metadata kwargs allow functions to be passed in which
802 The metadata kwargs allow functions to be passed in which
806 filter traits based on metadata values. The functions should
803 filter traits based on metadata values. The functions should
807 take a single value as an argument and return a boolean. If
804 take a single value as an argument and return a boolean. If
808 any function returns False, then the trait is not included in
805 any function returns False, then the trait is not included in
809 the output. This does not allow for any simple way of
806 the output. This does not allow for any simple way of
810 testing that a metadata name exists and has any
807 testing that a metadata name exists and has any
811 value because get_metadata returns None if a metadata key
808 value because get_metadata returns None if a metadata key
812 doesn't exist.
809 doesn't exist.
813 """
810 """
814 traits = dict([memb for memb in getmembers(self.__class__) if
811 traits = dict([memb for memb in getmembers(self.__class__) if
815 isinstance(memb[1], TraitType)])
812 isinstance(memb[1], TraitType)])
816
813
817 if len(metadata) == 0:
814 if len(metadata) == 0:
818 return traits
815 return traits
819
816
820 for meta_name, meta_eval in metadata.items():
817 for meta_name, meta_eval in metadata.items():
821 if type(meta_eval) is not FunctionType:
818 if type(meta_eval) is not FunctionType:
822 metadata[meta_name] = _SimpleTest(meta_eval)
819 metadata[meta_name] = _SimpleTest(meta_eval)
823
820
824 result = {}
821 result = {}
825 for name, trait in traits.items():
822 for name, trait in traits.items():
826 for meta_name, meta_eval in metadata.items():
823 for meta_name, meta_eval in metadata.items():
827 if not meta_eval(trait.get_metadata(meta_name)):
824 if not meta_eval(trait.get_metadata(meta_name)):
828 break
825 break
829 else:
826 else:
830 result[name] = trait
827 result[name] = trait
831
828
832 return result
829 return result
833
830
834 def trait_metadata(self, traitname, key, default=None):
831 def trait_metadata(self, traitname, key, default=None):
835 """Get metadata values for trait by key."""
832 """Get metadata values for trait by key."""
836 try:
833 try:
837 trait = getattr(self.__class__, traitname)
834 trait = getattr(self.__class__, traitname)
838 except AttributeError:
835 except AttributeError:
839 raise TraitError("Class %s does not have a trait named %s" %
836 raise TraitError("Class %s does not have a trait named %s" %
840 (self.__class__.__name__, traitname))
837 (self.__class__.__name__, traitname))
841 else:
838 else:
842 return trait.get_metadata(key, default)
839 return trait.get_metadata(key, default)
843
840
844 def add_trait(self, traitname, trait):
841 def add_trait(self, traitname, trait):
845 """Dynamically add a trait attribute to the HasTraits instance."""
842 """Dynamically add a trait attribute to the HasTraits instance."""
846 self.__class__ = type(self.__class__.__name__, (self.__class__,),
843 self.__class__ = type(self.__class__.__name__, (self.__class__,),
847 {traitname: trait})
844 {traitname: trait})
848 trait.set_default_value(self)
845 trait.set_default_value(self)
849
846
850 #-----------------------------------------------------------------------------
847 #-----------------------------------------------------------------------------
851 # Actual TraitTypes implementations/subclasses
848 # Actual TraitTypes implementations/subclasses
852 #-----------------------------------------------------------------------------
849 #-----------------------------------------------------------------------------
853
850
854 #-----------------------------------------------------------------------------
851 #-----------------------------------------------------------------------------
855 # TraitTypes subclasses for handling classes and instances of classes
852 # TraitTypes subclasses for handling classes and instances of classes
856 #-----------------------------------------------------------------------------
853 #-----------------------------------------------------------------------------
857
854
858
855
859 class ClassBasedTraitType(TraitType):
856 class ClassBasedTraitType(TraitType):
860 """
857 """
861 A trait with error reporting and string -> type resolution for Type,
858 A trait with error reporting and string -> type resolution for Type,
862 Instance and This.
859 Instance and This.
863 """
860 """
864
861
865 def _resolve_string(self, string):
862 def _resolve_string(self, string):
866 """
863 """
867 Resolve a string supplied for a type into an actual object.
864 Resolve a string supplied for a type into an actual object.
868 """
865 """
869 return import_item(string)
866 return import_item(string)
870
867
871 def error(self, obj, value):
868 def error(self, obj, value):
872 kind = type(value)
869 kind = type(value)
873 if (not py3compat.PY3) and kind is InstanceType:
870 if (not py3compat.PY3) and kind is InstanceType:
874 msg = 'class %s' % value.__class__.__name__
871 msg = 'class %s' % value.__class__.__name__
875 else:
872 else:
876 msg = '%s (i.e. %s)' % ( str( kind )[1:-1], repr( value ) )
873 msg = '%s (i.e. %s)' % ( str( kind )[1:-1], repr( value ) )
877
874
878 if obj is not None:
875 if obj is not None:
879 e = "The '%s' trait of %s instance must be %s, but a value of %s was specified." \
876 e = "The '%s' trait of %s instance must be %s, but a value of %s was specified." \
880 % (self.name, class_of(obj),
877 % (self.name, class_of(obj),
881 self.info(), msg)
878 self.info(), msg)
882 else:
879 else:
883 e = "The '%s' trait must be %s, but a value of %r was specified." \
880 e = "The '%s' trait must be %s, but a value of %r was specified." \
884 % (self.name, self.info(), msg)
881 % (self.name, self.info(), msg)
885
882
886 raise TraitError(e)
883 raise TraitError(e)
887
884
888
885
889 class Type(ClassBasedTraitType):
886 class Type(ClassBasedTraitType):
890 """A trait whose value must be a subclass of a specified class."""
887 """A trait whose value must be a subclass of a specified class."""
891
888
892 def __init__ (self, default_value=None, klass=None, allow_none=False,
889 def __init__ (self, default_value=None, klass=None, allow_none=False,
893 **metadata):
890 **metadata):
894 """Construct a Type trait
891 """Construct a Type trait
895
892
896 A Type trait specifies that its values must be subclasses of
893 A Type trait specifies that its values must be subclasses of
897 a particular class.
894 a particular class.
898
895
899 If only ``default_value`` is given, it is used for the ``klass`` as
896 If only ``default_value`` is given, it is used for the ``klass`` as
900 well.
897 well.
901
898
902 Parameters
899 Parameters
903 ----------
900 ----------
904 default_value : class, str or None
901 default_value : class, str or None
905 The default value must be a subclass of klass. If an str,
902 The default value must be a subclass of klass. If an str,
906 the str must be a fully specified class name, like 'foo.bar.Bah'.
903 the str must be a fully specified class name, like 'foo.bar.Bah'.
907 The string is resolved into real class, when the parent
904 The string is resolved into real class, when the parent
908 :class:`HasTraits` class is instantiated.
905 :class:`HasTraits` class is instantiated.
909 klass : class, str, None
906 klass : class, str, None
910 Values of this trait must be a subclass of klass. The klass
907 Values of this trait must be a subclass of klass. The klass
911 may be specified in a string like: 'foo.bar.MyClass'.
908 may be specified in a string like: 'foo.bar.MyClass'.
912 The string is resolved into real class, when the parent
909 The string is resolved into real class, when the parent
913 :class:`HasTraits` class is instantiated.
910 :class:`HasTraits` class is instantiated.
914 allow_none : bool [ default True ]
911 allow_none : bool [ default True ]
915 Indicates whether None is allowed as an assignable value. Even if
912 Indicates whether None is allowed as an assignable value. Even if
916 ``False``, the default value may be ``None``.
913 ``False``, the default value may be ``None``.
917 """
914 """
918 if default_value is None:
915 if default_value is None:
919 if klass is None:
916 if klass is None:
920 klass = object
917 klass = object
921 elif klass is None:
918 elif klass is None:
922 klass = default_value
919 klass = default_value
923
920
924 if not (inspect.isclass(klass) or isinstance(klass, py3compat.string_types)):
921 if not (inspect.isclass(klass) or isinstance(klass, py3compat.string_types)):
925 raise TraitError("A Type trait must specify a class.")
922 raise TraitError("A Type trait must specify a class.")
926
923
927 self.klass = klass
924 self.klass = klass
928
925
929 super(Type, self).__init__(default_value, allow_none=allow_none, **metadata)
926 super(Type, self).__init__(default_value, allow_none=allow_none, **metadata)
930
927
931 def validate(self, obj, value):
928 def validate(self, obj, value):
932 """Validates that the value is a valid object instance."""
929 """Validates that the value is a valid object instance."""
933 if isinstance(value, py3compat.string_types):
930 if isinstance(value, py3compat.string_types):
934 try:
931 try:
935 value = self._resolve_string(value)
932 value = self._resolve_string(value)
936 except ImportError:
933 except ImportError:
937 raise TraitError("The '%s' trait of %s instance must be a type, but "
934 raise TraitError("The '%s' trait of %s instance must be a type, but "
938 "%r could not be imported" % (self.name, obj, value))
935 "%r could not be imported" % (self.name, obj, value))
939 try:
936 try:
940 if issubclass(value, self.klass):
937 if issubclass(value, self.klass):
941 return value
938 return value
942 except:
939 except:
943 pass
940 pass
944
941
945 self.error(obj, value)
942 self.error(obj, value)
946
943
947 def info(self):
944 def info(self):
948 """ Returns a description of the trait."""
945 """ Returns a description of the trait."""
949 if isinstance(self.klass, py3compat.string_types):
946 if isinstance(self.klass, py3compat.string_types):
950 klass = self.klass
947 klass = self.klass
951 else:
948 else:
952 klass = self.klass.__name__
949 klass = self.klass.__name__
953 result = 'a subclass of ' + klass
950 result = 'a subclass of ' + klass
954 if self.allow_none:
951 if self.allow_none:
955 return result + ' or None'
952 return result + ' or None'
956 return result
953 return result
957
954
958 def instance_init(self):
955 def instance_init(self):
959 self._resolve_classes()
956 self._resolve_classes()
960 super(Type, self).instance_init()
957 super(Type, self).instance_init()
961
958
962 def _resolve_classes(self):
959 def _resolve_classes(self):
963 if isinstance(self.klass, py3compat.string_types):
960 if isinstance(self.klass, py3compat.string_types):
964 self.klass = self._resolve_string(self.klass)
961 self.klass = self._resolve_string(self.klass)
965 if isinstance(self.default_value, py3compat.string_types):
962 if isinstance(self.default_value, py3compat.string_types):
966 self.default_value = self._resolve_string(self.default_value)
963 self.default_value = self._resolve_string(self.default_value)
967
964
968 def get_default_value(self):
965 def get_default_value(self):
969 return self.default_value
966 return self.default_value
970
967
971
968
972 class DefaultValueGenerator(object):
969 class DefaultValueGenerator(object):
973 """A class for generating new default value instances."""
970 """A class for generating new default value instances."""
974
971
975 def __init__(self, *args, **kw):
972 def __init__(self, *args, **kw):
976 self.args = args
973 self.args = args
977 self.kw = kw
974 self.kw = kw
978
975
979 def generate(self, klass):
976 def generate(self, klass):
980 return klass(*self.args, **self.kw)
977 return klass(*self.args, **self.kw)
981
978
982
979
983 class Instance(ClassBasedTraitType):
980 class Instance(ClassBasedTraitType):
984 """A trait whose value must be an instance of a specified class.
981 """A trait whose value must be an instance of a specified class.
985
982
986 The value can also be an instance of a subclass of the specified class.
983 The value can also be an instance of a subclass of the specified class.
987
984
988 Subclasses can declare default classes by overriding the klass attribute
985 Subclasses can declare default classes by overriding the klass attribute
989 """
986 """
990
987
991 klass = None
988 klass = None
992
989
993 def __init__(self, klass=None, args=None, kw=None, allow_none=False,
990 def __init__(self, klass=None, args=None, kw=None, allow_none=False,
994 **metadata ):
991 **metadata ):
995 """Construct an Instance trait.
992 """Construct an Instance trait.
996
993
997 This trait allows values that are instances of a particular
994 This trait allows values that are instances of a particular
998 class or its subclasses. Our implementation is quite different
995 class or its subclasses. Our implementation is quite different
999 from that of enthough.traits as we don't allow instances to be used
996 from that of enthough.traits as we don't allow instances to be used
1000 for klass and we handle the ``args`` and ``kw`` arguments differently.
997 for klass and we handle the ``args`` and ``kw`` arguments differently.
1001
998
1002 Parameters
999 Parameters
1003 ----------
1000 ----------
1004 klass : class, str
1001 klass : class, str
1005 The class that forms the basis for the trait. Class names
1002 The class that forms the basis for the trait. Class names
1006 can also be specified as strings, like 'foo.bar.Bar'.
1003 can also be specified as strings, like 'foo.bar.Bar'.
1007 args : tuple
1004 args : tuple
1008 Positional arguments for generating the default value.
1005 Positional arguments for generating the default value.
1009 kw : dict
1006 kw : dict
1010 Keyword arguments for generating the default value.
1007 Keyword arguments for generating the default value.
1011 allow_none : bool [default True]
1008 allow_none : bool [default True]
1012 Indicates whether None is allowed as a value.
1009 Indicates whether None is allowed as a value.
1013
1010
1014 Notes
1011 Notes
1015 -----
1012 -----
1016 If both ``args`` and ``kw`` are None, then the default value is None.
1013 If both ``args`` and ``kw`` are None, then the default value is None.
1017 If ``args`` is a tuple and ``kw`` is a dict, then the default is
1014 If ``args`` is a tuple and ``kw`` is a dict, then the default is
1018 created as ``klass(*args, **kw)``. If exactly one of ``args`` or ``kw`` is
1015 created as ``klass(*args, **kw)``. If exactly one of ``args`` or ``kw`` is
1019 None, the None is replaced by ``()`` or ``{}``, respectively.
1016 None, the None is replaced by ``()`` or ``{}``, respectively.
1020 """
1017 """
1021 if klass is None:
1018 if klass is None:
1022 klass = self.klass
1019 klass = self.klass
1023
1020
1024 if (klass is not None) and (inspect.isclass(klass) or isinstance(klass, py3compat.string_types)):
1021 if (klass is not None) and (inspect.isclass(klass) or isinstance(klass, py3compat.string_types)):
1025 self.klass = klass
1022 self.klass = klass
1026 else:
1023 else:
1027 raise TraitError('The klass attribute must be a class'
1024 raise TraitError('The klass attribute must be a class'
1028 ' not: %r' % klass)
1025 ' not: %r' % klass)
1029
1026
1030 # self.klass is a class, so handle default_value
1027 # self.klass is a class, so handle default_value
1031 if args is None and kw is None:
1028 if args is None and kw is None:
1032 default_value = None
1029 default_value = None
1033 else:
1030 else:
1034 if args is None:
1031 if args is None:
1035 # kw is not None
1032 # kw is not None
1036 args = ()
1033 args = ()
1037 elif kw is None:
1034 elif kw is None:
1038 # args is not None
1035 # args is not None
1039 kw = {}
1036 kw = {}
1040
1037
1041 if not isinstance(kw, dict):
1038 if not isinstance(kw, dict):
1042 raise TraitError("The 'kw' argument must be a dict or None.")
1039 raise TraitError("The 'kw' argument must be a dict or None.")
1043 if not isinstance(args, tuple):
1040 if not isinstance(args, tuple):
1044 raise TraitError("The 'args' argument must be a tuple or None.")
1041 raise TraitError("The 'args' argument must be a tuple or None.")
1045
1042
1046 default_value = DefaultValueGenerator(*args, **kw)
1043 default_value = DefaultValueGenerator(*args, **kw)
1047
1044
1048 super(Instance, self).__init__(default_value, allow_none=allow_none, **metadata)
1045 super(Instance, self).__init__(default_value, allow_none=allow_none, **metadata)
1049
1046
1050 def validate(self, obj, value):
1047 def validate(self, obj, value):
1051 if isinstance(value, self.klass):
1048 if isinstance(value, self.klass):
1052 return value
1049 return value
1053 else:
1050 else:
1054 self.error(obj, value)
1051 self.error(obj, value)
1055
1052
1056 def info(self):
1053 def info(self):
1057 if isinstance(self.klass, py3compat.string_types):
1054 if isinstance(self.klass, py3compat.string_types):
1058 klass = self.klass
1055 klass = self.klass
1059 else:
1056 else:
1060 klass = self.klass.__name__
1057 klass = self.klass.__name__
1061 result = class_of(klass)
1058 result = class_of(klass)
1062 if self.allow_none:
1059 if self.allow_none:
1063 return result + ' or None'
1060 return result + ' or None'
1064
1061
1065 return result
1062 return result
1066
1063
1067 def instance_init(self):
1064 def instance_init(self):
1068 self._resolve_classes()
1065 self._resolve_classes()
1069 super(Instance, self).instance_init()
1066 super(Instance, self).instance_init()
1070
1067
1071 def _resolve_classes(self):
1068 def _resolve_classes(self):
1072 if isinstance(self.klass, py3compat.string_types):
1069 if isinstance(self.klass, py3compat.string_types):
1073 self.klass = self._resolve_string(self.klass)
1070 self.klass = self._resolve_string(self.klass)
1074
1071
1075 def get_default_value(self):
1072 def get_default_value(self):
1076 """Instantiate a default value instance.
1073 """Instantiate a default value instance.
1077
1074
1078 This is called when the containing HasTraits classes'
1075 This is called when the containing HasTraits classes'
1079 :meth:`__new__` method is called to ensure that a unique instance
1076 :meth:`__new__` method is called to ensure that a unique instance
1080 is created for each HasTraits instance.
1077 is created for each HasTraits instance.
1081 """
1078 """
1082 dv = self.default_value
1079 dv = self.default_value
1083 if isinstance(dv, DefaultValueGenerator):
1080 if isinstance(dv, DefaultValueGenerator):
1084 return dv.generate(self.klass)
1081 return dv.generate(self.klass)
1085 else:
1082 else:
1086 return dv
1083 return dv
1087
1084
1088
1085
1089 class ForwardDeclaredMixin(object):
1086 class ForwardDeclaredMixin(object):
1090 """
1087 """
1091 Mixin for forward-declared versions of Instance and Type.
1088 Mixin for forward-declared versions of Instance and Type.
1092 """
1089 """
1093 def _resolve_string(self, string):
1090 def _resolve_string(self, string):
1094 """
1091 """
1095 Find the specified class name by looking for it in the module in which
1092 Find the specified class name by looking for it in the module in which
1096 our this_class attribute was defined.
1093 our this_class attribute was defined.
1097 """
1094 """
1098 modname = self.this_class.__module__
1095 modname = self.this_class.__module__
1099 return import_item('.'.join([modname, string]))
1096 return import_item('.'.join([modname, string]))
1100
1097
1101
1098
1102 class ForwardDeclaredType(ForwardDeclaredMixin, Type):
1099 class ForwardDeclaredType(ForwardDeclaredMixin, Type):
1103 """
1100 """
1104 Forward-declared version of Type.
1101 Forward-declared version of Type.
1105 """
1102 """
1106 pass
1103 pass
1107
1104
1108
1105
1109 class ForwardDeclaredInstance(ForwardDeclaredMixin, Instance):
1106 class ForwardDeclaredInstance(ForwardDeclaredMixin, Instance):
1110 """
1107 """
1111 Forward-declared version of Instance.
1108 Forward-declared version of Instance.
1112 """
1109 """
1113 pass
1110 pass
1114
1111
1115
1112
1116 class This(ClassBasedTraitType):
1113 class This(ClassBasedTraitType):
1117 """A trait for instances of the class containing this trait.
1114 """A trait for instances of the class containing this trait.
1118
1115
1119 Because how how and when class bodies are executed, the ``This``
1116 Because how how and when class bodies are executed, the ``This``
1120 trait can only have a default value of None. This, and because we
1117 trait can only have a default value of None. This, and because we
1121 always validate default values, ``allow_none`` is *always* true.
1118 always validate default values, ``allow_none`` is *always* true.
1122 """
1119 """
1123
1120
1124 info_text = 'an instance of the same type as the receiver or None'
1121 info_text = 'an instance of the same type as the receiver or None'
1125
1122
1126 def __init__(self, **metadata):
1123 def __init__(self, **metadata):
1127 super(This, self).__init__(None, **metadata)
1124 super(This, self).__init__(None, **metadata)
1128
1125
1129 def validate(self, obj, value):
1126 def validate(self, obj, value):
1130 # What if value is a superclass of obj.__class__? This is
1127 # What if value is a superclass of obj.__class__? This is
1131 # complicated if it was the superclass that defined the This
1128 # complicated if it was the superclass that defined the This
1132 # trait.
1129 # trait.
1133 if isinstance(value, self.this_class) or (value is None):
1130 if isinstance(value, self.this_class) or (value is None):
1134 return value
1131 return value
1135 else:
1132 else:
1136 self.error(obj, value)
1133 self.error(obj, value)
1137
1134
1138
1135
1139 class Union(TraitType):
1136 class Union(TraitType):
1140 """A trait type representing a Union type."""
1137 """A trait type representing a Union type."""
1141
1138
1142 def __init__(self, trait_types, **metadata):
1139 def __init__(self, trait_types, **metadata):
1143 """Construct a Union trait.
1140 """Construct a Union trait.
1144
1141
1145 This trait allows values that are allowed by at least one of the
1142 This trait allows values that are allowed by at least one of the
1146 specified trait types. A Union traitlet cannot have metadata on
1143 specified trait types. A Union traitlet cannot have metadata on
1147 its own, besides the metadata of the listed types.
1144 its own, besides the metadata of the listed types.
1148
1145
1149 Parameters
1146 Parameters
1150 ----------
1147 ----------
1151 trait_types: sequence
1148 trait_types: sequence
1152 The list of trait types of length at least 1.
1149 The list of trait types of length at least 1.
1153
1150
1154 Notes
1151 Notes
1155 -----
1152 -----
1156 Union([Float(), Bool(), Int()]) attempts to validate the provided values
1153 Union([Float(), Bool(), Int()]) attempts to validate the provided values
1157 with the validation function of Float, then Bool, and finally Int.
1154 with the validation function of Float, then Bool, and finally Int.
1158 """
1155 """
1159 self.trait_types = trait_types
1156 self.trait_types = trait_types
1160 self.info_text = " or ".join([tt.info_text for tt in self.trait_types])
1157 self.info_text = " or ".join([tt.info_text for tt in self.trait_types])
1161 self.default_value = self.trait_types[0].get_default_value()
1158 self.default_value = self.trait_types[0].get_default_value()
1162 super(Union, self).__init__(**metadata)
1159 super(Union, self).__init__(**metadata)
1163
1160
1164 def instance_init(self):
1161 def instance_init(self):
1165 for trait_type in self.trait_types:
1162 for trait_type in self.trait_types:
1166 trait_type.name = self.name
1163 trait_type.name = self.name
1167 trait_type.this_class = self.this_class
1164 trait_type.this_class = self.this_class
1168 trait_type.instance_init()
1165 trait_type.instance_init()
1169 super(Union, self).instance_init()
1166 super(Union, self).instance_init()
1170
1167
1171 def validate(self, obj, value):
1168 def validate(self, obj, value):
1172 for trait_type in self.trait_types:
1169 for trait_type in self.trait_types:
1173 try:
1170 try:
1174 v = trait_type._validate(obj, value)
1171 v = trait_type._validate(obj, value)
1175 self._metadata = trait_type._metadata
1172 self._metadata = trait_type._metadata
1176 return v
1173 return v
1177 except TraitError:
1174 except TraitError:
1178 continue
1175 continue
1179 self.error(obj, value)
1176 self.error(obj, value)
1180
1177
1181 def __or__(self, other):
1178 def __or__(self, other):
1182 if isinstance(other, Union):
1179 if isinstance(other, Union):
1183 return Union(self.trait_types + other.trait_types)
1180 return Union(self.trait_types + other.trait_types)
1184 else:
1181 else:
1185 return Union(self.trait_types + [other])
1182 return Union(self.trait_types + [other])
1186
1183
1187 #-----------------------------------------------------------------------------
1184 #-----------------------------------------------------------------------------
1188 # Basic TraitTypes implementations/subclasses
1185 # Basic TraitTypes implementations/subclasses
1189 #-----------------------------------------------------------------------------
1186 #-----------------------------------------------------------------------------
1190
1187
1191
1188
1192 class Any(TraitType):
1189 class Any(TraitType):
1193 default_value = None
1190 default_value = None
1194 info_text = 'any value'
1191 info_text = 'any value'
1195
1192
1196
1193
1197 class Int(TraitType):
1194 class Int(TraitType):
1198 """An int trait."""
1195 """An int trait."""
1199
1196
1200 default_value = 0
1197 default_value = 0
1201 info_text = 'an int'
1198 info_text = 'an int'
1202
1199
1203 def validate(self, obj, value):
1200 def validate(self, obj, value):
1204 if isinstance(value, int):
1201 if isinstance(value, int):
1205 return value
1202 return value
1206 self.error(obj, value)
1203 self.error(obj, value)
1207
1204
1208 class CInt(Int):
1205 class CInt(Int):
1209 """A casting version of the int trait."""
1206 """A casting version of the int trait."""
1210
1207
1211 def validate(self, obj, value):
1208 def validate(self, obj, value):
1212 try:
1209 try:
1213 return int(value)
1210 return int(value)
1214 except:
1211 except:
1215 self.error(obj, value)
1212 self.error(obj, value)
1216
1213
1217 if py3compat.PY3:
1214 if py3compat.PY3:
1218 Long, CLong = Int, CInt
1215 Long, CLong = Int, CInt
1219 Integer = Int
1216 Integer = Int
1220 else:
1217 else:
1221 class Long(TraitType):
1218 class Long(TraitType):
1222 """A long integer trait."""
1219 """A long integer trait."""
1223
1220
1224 default_value = 0
1221 default_value = 0
1225 info_text = 'a long'
1222 info_text = 'a long'
1226
1223
1227 def validate(self, obj, value):
1224 def validate(self, obj, value):
1228 if isinstance(value, long):
1225 if isinstance(value, long):
1229 return value
1226 return value
1230 if isinstance(value, int):
1227 if isinstance(value, int):
1231 return long(value)
1228 return long(value)
1232 self.error(obj, value)
1229 self.error(obj, value)
1233
1230
1234
1231
1235 class CLong(Long):
1232 class CLong(Long):
1236 """A casting version of the long integer trait."""
1233 """A casting version of the long integer trait."""
1237
1234
1238 def validate(self, obj, value):
1235 def validate(self, obj, value):
1239 try:
1236 try:
1240 return long(value)
1237 return long(value)
1241 except:
1238 except:
1242 self.error(obj, value)
1239 self.error(obj, value)
1243
1240
1244 class Integer(TraitType):
1241 class Integer(TraitType):
1245 """An integer trait.
1242 """An integer trait.
1246
1243
1247 Longs that are unnecessary (<= sys.maxint) are cast to ints."""
1244 Longs that are unnecessary (<= sys.maxint) are cast to ints."""
1248
1245
1249 default_value = 0
1246 default_value = 0
1250 info_text = 'an integer'
1247 info_text = 'an integer'
1251
1248
1252 def validate(self, obj, value):
1249 def validate(self, obj, value):
1253 if isinstance(value, int):
1250 if isinstance(value, int):
1254 return value
1251 return value
1255 if isinstance(value, long):
1252 if isinstance(value, long):
1256 # downcast longs that fit in int:
1253 # downcast longs that fit in int:
1257 # note that int(n > sys.maxint) returns a long, so
1254 # note that int(n > sys.maxint) returns a long, so
1258 # we don't need a condition on this cast
1255 # we don't need a condition on this cast
1259 return int(value)
1256 return int(value)
1260 if sys.platform == "cli":
1257 if sys.platform == "cli":
1261 from System import Int64
1258 from System import Int64
1262 if isinstance(value, Int64):
1259 if isinstance(value, Int64):
1263 return int(value)
1260 return int(value)
1264 self.error(obj, value)
1261 self.error(obj, value)
1265
1262
1266
1263
1267 class Float(TraitType):
1264 class Float(TraitType):
1268 """A float trait."""
1265 """A float trait."""
1269
1266
1270 default_value = 0.0
1267 default_value = 0.0
1271 info_text = 'a float'
1268 info_text = 'a float'
1272
1269
1273 def validate(self, obj, value):
1270 def validate(self, obj, value):
1274 if isinstance(value, float):
1271 if isinstance(value, float):
1275 return value
1272 return value
1276 if isinstance(value, int):
1273 if isinstance(value, int):
1277 return float(value)
1274 return float(value)
1278 self.error(obj, value)
1275 self.error(obj, value)
1279
1276
1280
1277
1281 class CFloat(Float):
1278 class CFloat(Float):
1282 """A casting version of the float trait."""
1279 """A casting version of the float trait."""
1283
1280
1284 def validate(self, obj, value):
1281 def validate(self, obj, value):
1285 try:
1282 try:
1286 return float(value)
1283 return float(value)
1287 except:
1284 except:
1288 self.error(obj, value)
1285 self.error(obj, value)
1289
1286
1290 class Complex(TraitType):
1287 class Complex(TraitType):
1291 """A trait for complex numbers."""
1288 """A trait for complex numbers."""
1292
1289
1293 default_value = 0.0 + 0.0j
1290 default_value = 0.0 + 0.0j
1294 info_text = 'a complex number'
1291 info_text = 'a complex number'
1295
1292
1296 def validate(self, obj, value):
1293 def validate(self, obj, value):
1297 if isinstance(value, complex):
1294 if isinstance(value, complex):
1298 return value
1295 return value
1299 if isinstance(value, (float, int)):
1296 if isinstance(value, (float, int)):
1300 return complex(value)
1297 return complex(value)
1301 self.error(obj, value)
1298 self.error(obj, value)
1302
1299
1303
1300
1304 class CComplex(Complex):
1301 class CComplex(Complex):
1305 """A casting version of the complex number trait."""
1302 """A casting version of the complex number trait."""
1306
1303
1307 def validate (self, obj, value):
1304 def validate (self, obj, value):
1308 try:
1305 try:
1309 return complex(value)
1306 return complex(value)
1310 except:
1307 except:
1311 self.error(obj, value)
1308 self.error(obj, value)
1312
1309
1313 # We should always be explicit about whether we're using bytes or unicode, both
1310 # We should always be explicit about whether we're using bytes or unicode, both
1314 # for Python 3 conversion and for reliable unicode behaviour on Python 2. So
1311 # for Python 3 conversion and for reliable unicode behaviour on Python 2. So
1315 # we don't have a Str type.
1312 # we don't have a Str type.
1316 class Bytes(TraitType):
1313 class Bytes(TraitType):
1317 """A trait for byte strings."""
1314 """A trait for byte strings."""
1318
1315
1319 default_value = b''
1316 default_value = b''
1320 info_text = 'a bytes object'
1317 info_text = 'a bytes object'
1321
1318
1322 def validate(self, obj, value):
1319 def validate(self, obj, value):
1323 if isinstance(value, bytes):
1320 if isinstance(value, bytes):
1324 return value
1321 return value
1325 self.error(obj, value)
1322 self.error(obj, value)
1326
1323
1327
1324
1328 class CBytes(Bytes):
1325 class CBytes(Bytes):
1329 """A casting version of the byte string trait."""
1326 """A casting version of the byte string trait."""
1330
1327
1331 def validate(self, obj, value):
1328 def validate(self, obj, value):
1332 try:
1329 try:
1333 return bytes(value)
1330 return bytes(value)
1334 except:
1331 except:
1335 self.error(obj, value)
1332 self.error(obj, value)
1336
1333
1337
1334
1338 class Unicode(TraitType):
1335 class Unicode(TraitType):
1339 """A trait for unicode strings."""
1336 """A trait for unicode strings."""
1340
1337
1341 default_value = u''
1338 default_value = u''
1342 info_text = 'a unicode string'
1339 info_text = 'a unicode string'
1343
1340
1344 def validate(self, obj, value):
1341 def validate(self, obj, value):
1345 if isinstance(value, py3compat.unicode_type):
1342 if isinstance(value, py3compat.unicode_type):
1346 return value
1343 return value
1347 if isinstance(value, bytes):
1344 if isinstance(value, bytes):
1348 try:
1345 try:
1349 return value.decode('ascii', 'strict')
1346 return value.decode('ascii', 'strict')
1350 except UnicodeDecodeError:
1347 except UnicodeDecodeError:
1351 msg = "Could not decode {!r} for unicode trait '{}' of {} instance."
1348 msg = "Could not decode {!r} for unicode trait '{}' of {} instance."
1352 raise TraitError(msg.format(value, self.name, class_of(obj)))
1349 raise TraitError(msg.format(value, self.name, class_of(obj)))
1353 self.error(obj, value)
1350 self.error(obj, value)
1354
1351
1355
1352
1356 class CUnicode(Unicode):
1353 class CUnicode(Unicode):
1357 """A casting version of the unicode trait."""
1354 """A casting version of the unicode trait."""
1358
1355
1359 def validate(self, obj, value):
1356 def validate(self, obj, value):
1360 try:
1357 try:
1361 return py3compat.unicode_type(value)
1358 return py3compat.unicode_type(value)
1362 except:
1359 except:
1363 self.error(obj, value)
1360 self.error(obj, value)
1364
1361
1365
1362
1366 class ObjectName(TraitType):
1363 class ObjectName(TraitType):
1367 """A string holding a valid object name in this version of Python.
1364 """A string holding a valid object name in this version of Python.
1368
1365
1369 This does not check that the name exists in any scope."""
1366 This does not check that the name exists in any scope."""
1370 info_text = "a valid object identifier in Python"
1367 info_text = "a valid object identifier in Python"
1371
1368
1372 if py3compat.PY3:
1369 if py3compat.PY3:
1373 # Python 3:
1370 # Python 3:
1374 coerce_str = staticmethod(lambda _,s: s)
1371 coerce_str = staticmethod(lambda _,s: s)
1375
1372
1376 else:
1373 else:
1377 # Python 2:
1374 # Python 2:
1378 def coerce_str(self, obj, value):
1375 def coerce_str(self, obj, value):
1379 "In Python 2, coerce ascii-only unicode to str"
1376 "In Python 2, coerce ascii-only unicode to str"
1380 if isinstance(value, unicode):
1377 if isinstance(value, unicode):
1381 try:
1378 try:
1382 return str(value)
1379 return str(value)
1383 except UnicodeEncodeError:
1380 except UnicodeEncodeError:
1384 self.error(obj, value)
1381 self.error(obj, value)
1385 return value
1382 return value
1386
1383
1387 def validate(self, obj, value):
1384 def validate(self, obj, value):
1388 value = self.coerce_str(obj, value)
1385 value = self.coerce_str(obj, value)
1389
1386
1390 if isinstance(value, string_types) and py3compat.isidentifier(value):
1387 if isinstance(value, string_types) and py3compat.isidentifier(value):
1391 return value
1388 return value
1392 self.error(obj, value)
1389 self.error(obj, value)
1393
1390
1394 class DottedObjectName(ObjectName):
1391 class DottedObjectName(ObjectName):
1395 """A string holding a valid dotted object name in Python, such as A.b3._c"""
1392 """A string holding a valid dotted object name in Python, such as A.b3._c"""
1396 def validate(self, obj, value):
1393 def validate(self, obj, value):
1397 value = self.coerce_str(obj, value)
1394 value = self.coerce_str(obj, value)
1398
1395
1399 if isinstance(value, string_types) and py3compat.isidentifier(value, dotted=True):
1396 if isinstance(value, string_types) and py3compat.isidentifier(value, dotted=True):
1400 return value
1397 return value
1401 self.error(obj, value)
1398 self.error(obj, value)
1402
1399
1403
1400
1404 class Bool(TraitType):
1401 class Bool(TraitType):
1405 """A boolean (True, False) trait."""
1402 """A boolean (True, False) trait."""
1406
1403
1407 default_value = False
1404 default_value = False
1408 info_text = 'a boolean'
1405 info_text = 'a boolean'
1409
1406
1410 def validate(self, obj, value):
1407 def validate(self, obj, value):
1411 if isinstance(value, bool):
1408 if isinstance(value, bool):
1412 return value
1409 return value
1413 self.error(obj, value)
1410 self.error(obj, value)
1414
1411
1415
1412
1416 class CBool(Bool):
1413 class CBool(Bool):
1417 """A casting version of the boolean trait."""
1414 """A casting version of the boolean trait."""
1418
1415
1419 def validate(self, obj, value):
1416 def validate(self, obj, value):
1420 try:
1417 try:
1421 return bool(value)
1418 return bool(value)
1422 except:
1419 except:
1423 self.error(obj, value)
1420 self.error(obj, value)
1424
1421
1425
1422
1426 class Enum(TraitType):
1423 class Enum(TraitType):
1427 """An enum that whose value must be in a given sequence."""
1424 """An enum that whose value must be in a given sequence."""
1428
1425
1429 def __init__(self, values, default_value=None, **metadata):
1426 def __init__(self, values, default_value=None, **metadata):
1430 self.values = values
1427 self.values = values
1431 super(Enum, self).__init__(default_value, **metadata)
1428 super(Enum, self).__init__(default_value, **metadata)
1432
1429
1433 def validate(self, obj, value):
1430 def validate(self, obj, value):
1434 if value in self.values:
1431 if value in self.values:
1435 return value
1432 return value
1436 self.error(obj, value)
1433 self.error(obj, value)
1437
1434
1438 def info(self):
1435 def info(self):
1439 """ Returns a description of the trait."""
1436 """ Returns a description of the trait."""
1440 result = 'any of ' + repr(self.values)
1437 result = 'any of ' + repr(self.values)
1441 if self.allow_none:
1438 if self.allow_none:
1442 return result + ' or None'
1439 return result + ' or None'
1443 return result
1440 return result
1444
1441
1445 class CaselessStrEnum(Enum):
1442 class CaselessStrEnum(Enum):
1446 """An enum of strings that are caseless in validate."""
1443 """An enum of strings that are caseless in validate."""
1447
1444
1448 def validate(self, obj, value):
1445 def validate(self, obj, value):
1449 if not isinstance(value, py3compat.string_types):
1446 if not isinstance(value, py3compat.string_types):
1450 self.error(obj, value)
1447 self.error(obj, value)
1451
1448
1452 for v in self.values:
1449 for v in self.values:
1453 if v.lower() == value.lower():
1450 if v.lower() == value.lower():
1454 return v
1451 return v
1455 self.error(obj, value)
1452 self.error(obj, value)
1456
1453
1457 class Container(Instance):
1454 class Container(Instance):
1458 """An instance of a container (list, set, etc.)
1455 """An instance of a container (list, set, etc.)
1459
1456
1460 To be subclassed by overriding klass.
1457 To be subclassed by overriding klass.
1461 """
1458 """
1462 klass = None
1459 klass = None
1463 _cast_types = ()
1460 _cast_types = ()
1464 _valid_defaults = SequenceTypes
1461 _valid_defaults = SequenceTypes
1465 _trait = None
1462 _trait = None
1466
1463
1467 def __init__(self, trait=None, default_value=None, allow_none=False,
1464 def __init__(self, trait=None, default_value=None, allow_none=False,
1468 **metadata):
1465 **metadata):
1469 """Create a container trait type from a list, set, or tuple.
1466 """Create a container trait type from a list, set, or tuple.
1470
1467
1471 The default value is created by doing ``List(default_value)``,
1468 The default value is created by doing ``List(default_value)``,
1472 which creates a copy of the ``default_value``.
1469 which creates a copy of the ``default_value``.
1473
1470
1474 ``trait`` can be specified, which restricts the type of elements
1471 ``trait`` can be specified, which restricts the type of elements
1475 in the container to that TraitType.
1472 in the container to that TraitType.
1476
1473
1477 If only one arg is given and it is not a Trait, it is taken as
1474 If only one arg is given and it is not a Trait, it is taken as
1478 ``default_value``:
1475 ``default_value``:
1479
1476
1480 ``c = List([1,2,3])``
1477 ``c = List([1,2,3])``
1481
1478
1482 Parameters
1479 Parameters
1483 ----------
1480 ----------
1484
1481
1485 trait : TraitType [ optional ]
1482 trait : TraitType [ optional ]
1486 the type for restricting the contents of the Container. If unspecified,
1483 the type for restricting the contents of the Container. If unspecified,
1487 types are not checked.
1484 types are not checked.
1488
1485
1489 default_value : SequenceType [ optional ]
1486 default_value : SequenceType [ optional ]
1490 The default value for the Trait. Must be list/tuple/set, and
1487 The default value for the Trait. Must be list/tuple/set, and
1491 will be cast to the container type.
1488 will be cast to the container type.
1492
1489
1493 allow_none : bool [ default False ]
1490 allow_none : bool [ default False ]
1494 Whether to allow the value to be None
1491 Whether to allow the value to be None
1495
1492
1496 **metadata : any
1493 **metadata : any
1497 further keys for extensions to the Trait (e.g. config)
1494 further keys for extensions to the Trait (e.g. config)
1498
1495
1499 """
1496 """
1500 # allow List([values]):
1497 # allow List([values]):
1501 if default_value is None and not is_trait(trait):
1498 if default_value is None and not is_trait(trait):
1502 default_value = trait
1499 default_value = trait
1503 trait = None
1500 trait = None
1504
1501
1505 if default_value is None:
1502 if default_value is None:
1506 args = ()
1503 args = ()
1507 elif isinstance(default_value, self._valid_defaults):
1504 elif isinstance(default_value, self._valid_defaults):
1508 args = (default_value,)
1505 args = (default_value,)
1509 else:
1506 else:
1510 raise TypeError('default value of %s was %s' %(self.__class__.__name__, default_value))
1507 raise TypeError('default value of %s was %s' %(self.__class__.__name__, default_value))
1511
1508
1512 if is_trait(trait):
1509 if is_trait(trait):
1513 self._trait = trait() if isinstance(trait, type) else trait
1510 self._trait = trait() if isinstance(trait, type) else trait
1514 self._trait.name = 'element'
1511 self._trait.name = 'element'
1515 elif trait is not None:
1512 elif trait is not None:
1516 raise TypeError("`trait` must be a Trait or None, got %s"%repr_type(trait))
1513 raise TypeError("`trait` must be a Trait or None, got %s"%repr_type(trait))
1517
1514
1518 super(Container,self).__init__(klass=self.klass, args=args,
1515 super(Container,self).__init__(klass=self.klass, args=args,
1519 allow_none=allow_none, **metadata)
1516 allow_none=allow_none, **metadata)
1520
1517
1521 def element_error(self, obj, element, validator):
1518 def element_error(self, obj, element, validator):
1522 e = "Element of the '%s' trait of %s instance must be %s, but a value of %s was specified." \
1519 e = "Element of the '%s' trait of %s instance must be %s, but a value of %s was specified." \
1523 % (self.name, class_of(obj), validator.info(), repr_type(element))
1520 % (self.name, class_of(obj), validator.info(), repr_type(element))
1524 raise TraitError(e)
1521 raise TraitError(e)
1525
1522
1526 def validate(self, obj, value):
1523 def validate(self, obj, value):
1527 if isinstance(value, self._cast_types):
1524 if isinstance(value, self._cast_types):
1528 value = self.klass(value)
1525 value = self.klass(value)
1529 value = super(Container, self).validate(obj, value)
1526 value = super(Container, self).validate(obj, value)
1530 if value is None:
1527 if value is None:
1531 return value
1528 return value
1532
1529
1533 value = self.validate_elements(obj, value)
1530 value = self.validate_elements(obj, value)
1534
1531
1535 return value
1532 return value
1536
1533
1537 def validate_elements(self, obj, value):
1534 def validate_elements(self, obj, value):
1538 validated = []
1535 validated = []
1539 if self._trait is None or isinstance(self._trait, Any):
1536 if self._trait is None or isinstance(self._trait, Any):
1540 return value
1537 return value
1541 for v in value:
1538 for v in value:
1542 try:
1539 try:
1543 v = self._trait._validate(obj, v)
1540 v = self._trait._validate(obj, v)
1544 except TraitError:
1541 except TraitError:
1545 self.element_error(obj, v, self._trait)
1542 self.element_error(obj, v, self._trait)
1546 else:
1543 else:
1547 validated.append(v)
1544 validated.append(v)
1548 return self.klass(validated)
1545 return self.klass(validated)
1549
1546
1550 def instance_init(self):
1547 def instance_init(self):
1551 if isinstance(self._trait, TraitType):
1548 if isinstance(self._trait, TraitType):
1552 self._trait.this_class = self.this_class
1549 self._trait.this_class = self.this_class
1553 self._trait.instance_init()
1550 self._trait.instance_init()
1554 super(Container, self).instance_init()
1551 super(Container, self).instance_init()
1555
1552
1556
1553
1557 class List(Container):
1554 class List(Container):
1558 """An instance of a Python list."""
1555 """An instance of a Python list."""
1559 klass = list
1556 klass = list
1560 _cast_types = (tuple,)
1557 _cast_types = (tuple,)
1561
1558
1562 def __init__(self, trait=None, default_value=None, minlen=0, maxlen=sys.maxsize, **metadata):
1559 def __init__(self, trait=None, default_value=None, minlen=0, maxlen=sys.maxsize, **metadata):
1563 """Create a List trait type from a list, set, or tuple.
1560 """Create a List trait type from a list, set, or tuple.
1564
1561
1565 The default value is created by doing ``List(default_value)``,
1562 The default value is created by doing ``List(default_value)``,
1566 which creates a copy of the ``default_value``.
1563 which creates a copy of the ``default_value``.
1567
1564
1568 ``trait`` can be specified, which restricts the type of elements
1565 ``trait`` can be specified, which restricts the type of elements
1569 in the container to that TraitType.
1566 in the container to that TraitType.
1570
1567
1571 If only one arg is given and it is not a Trait, it is taken as
1568 If only one arg is given and it is not a Trait, it is taken as
1572 ``default_value``:
1569 ``default_value``:
1573
1570
1574 ``c = List([1,2,3])``
1571 ``c = List([1,2,3])``
1575
1572
1576 Parameters
1573 Parameters
1577 ----------
1574 ----------
1578
1575
1579 trait : TraitType [ optional ]
1576 trait : TraitType [ optional ]
1580 the type for restricting the contents of the Container. If unspecified,
1577 the type for restricting the contents of the Container. If unspecified,
1581 types are not checked.
1578 types are not checked.
1582
1579
1583 default_value : SequenceType [ optional ]
1580 default_value : SequenceType [ optional ]
1584 The default value for the Trait. Must be list/tuple/set, and
1581 The default value for the Trait. Must be list/tuple/set, and
1585 will be cast to the container type.
1582 will be cast to the container type.
1586
1583
1587 minlen : Int [ default 0 ]
1584 minlen : Int [ default 0 ]
1588 The minimum length of the input list
1585 The minimum length of the input list
1589
1586
1590 maxlen : Int [ default sys.maxsize ]
1587 maxlen : Int [ default sys.maxsize ]
1591 The maximum length of the input list
1588 The maximum length of the input list
1592
1589
1593 allow_none : bool [ default False ]
1590 allow_none : bool [ default False ]
1594 Whether to allow the value to be None
1591 Whether to allow the value to be None
1595
1592
1596 **metadata : any
1593 **metadata : any
1597 further keys for extensions to the Trait (e.g. config)
1594 further keys for extensions to the Trait (e.g. config)
1598
1595
1599 """
1596 """
1600 self._minlen = minlen
1597 self._minlen = minlen
1601 self._maxlen = maxlen
1598 self._maxlen = maxlen
1602 super(List, self).__init__(trait=trait, default_value=default_value,
1599 super(List, self).__init__(trait=trait, default_value=default_value,
1603 **metadata)
1600 **metadata)
1604
1601
1605 def length_error(self, obj, value):
1602 def length_error(self, obj, value):
1606 e = "The '%s' trait of %s instance must be of length %i <= L <= %i, but a value of %s was specified." \
1603 e = "The '%s' trait of %s instance must be of length %i <= L <= %i, but a value of %s was specified." \
1607 % (self.name, class_of(obj), self._minlen, self._maxlen, value)
1604 % (self.name, class_of(obj), self._minlen, self._maxlen, value)
1608 raise TraitError(e)
1605 raise TraitError(e)
1609
1606
1610 def validate_elements(self, obj, value):
1607 def validate_elements(self, obj, value):
1611 length = len(value)
1608 length = len(value)
1612 if length < self._minlen or length > self._maxlen:
1609 if length < self._minlen or length > self._maxlen:
1613 self.length_error(obj, value)
1610 self.length_error(obj, value)
1614
1611
1615 return super(List, self).validate_elements(obj, value)
1612 return super(List, self).validate_elements(obj, value)
1616
1613
1617 def validate(self, obj, value):
1614 def validate(self, obj, value):
1618 value = super(List, self).validate(obj, value)
1615 value = super(List, self).validate(obj, value)
1619 value = self.validate_elements(obj, value)
1616 value = self.validate_elements(obj, value)
1620 return value
1617 return value
1621
1618
1622
1619
1623 class Set(List):
1620 class Set(List):
1624 """An instance of a Python set."""
1621 """An instance of a Python set."""
1625 klass = set
1622 klass = set
1626 _cast_types = (tuple, list)
1623 _cast_types = (tuple, list)
1627
1624
1628
1625
1629 class Tuple(Container):
1626 class Tuple(Container):
1630 """An instance of a Python tuple."""
1627 """An instance of a Python tuple."""
1631 klass = tuple
1628 klass = tuple
1632 _cast_types = (list,)
1629 _cast_types = (list,)
1633
1630
1634 def __init__(self, *traits, **metadata):
1631 def __init__(self, *traits, **metadata):
1635 """Tuple(*traits, default_value=None, **medatata)
1632 """Tuple(*traits, default_value=None, **medatata)
1636
1633
1637 Create a tuple from a list, set, or tuple.
1634 Create a tuple from a list, set, or tuple.
1638
1635
1639 Create a fixed-type tuple with Traits:
1636 Create a fixed-type tuple with Traits:
1640
1637
1641 ``t = Tuple(Int, Str, CStr)``
1638 ``t = Tuple(Int, Str, CStr)``
1642
1639
1643 would be length 3, with Int,Str,CStr for each element.
1640 would be length 3, with Int,Str,CStr for each element.
1644
1641
1645 If only one arg is given and it is not a Trait, it is taken as
1642 If only one arg is given and it is not a Trait, it is taken as
1646 default_value:
1643 default_value:
1647
1644
1648 ``t = Tuple((1,2,3))``
1645 ``t = Tuple((1,2,3))``
1649
1646
1650 Otherwise, ``default_value`` *must* be specified by keyword.
1647 Otherwise, ``default_value`` *must* be specified by keyword.
1651
1648
1652 Parameters
1649 Parameters
1653 ----------
1650 ----------
1654
1651
1655 *traits : TraitTypes [ optional ]
1652 *traits : TraitTypes [ optional ]
1656 the types for restricting the contents of the Tuple. If unspecified,
1653 the types for restricting the contents of the Tuple. If unspecified,
1657 types are not checked. If specified, then each positional argument
1654 types are not checked. If specified, then each positional argument
1658 corresponds to an element of the tuple. Tuples defined with traits
1655 corresponds to an element of the tuple. Tuples defined with traits
1659 are of fixed length.
1656 are of fixed length.
1660
1657
1661 default_value : SequenceType [ optional ]
1658 default_value : SequenceType [ optional ]
1662 The default value for the Tuple. Must be list/tuple/set, and
1659 The default value for the Tuple. Must be list/tuple/set, and
1663 will be cast to a tuple. If `traits` are specified, the
1660 will be cast to a tuple. If `traits` are specified, the
1664 `default_value` must conform to the shape and type they specify.
1661 `default_value` must conform to the shape and type they specify.
1665
1662
1666 allow_none : bool [ default False ]
1663 allow_none : bool [ default False ]
1667 Whether to allow the value to be None
1664 Whether to allow the value to be None
1668
1665
1669 **metadata : any
1666 **metadata : any
1670 further keys for extensions to the Trait (e.g. config)
1667 further keys for extensions to the Trait (e.g. config)
1671
1668
1672 """
1669 """
1673 default_value = metadata.pop('default_value', None)
1670 default_value = metadata.pop('default_value', None)
1674 allow_none = metadata.pop('allow_none', True)
1671 allow_none = metadata.pop('allow_none', True)
1675
1672
1676 # allow Tuple((values,)):
1673 # allow Tuple((values,)):
1677 if len(traits) == 1 and default_value is None and not is_trait(traits[0]):
1674 if len(traits) == 1 and default_value is None and not is_trait(traits[0]):
1678 default_value = traits[0]
1675 default_value = traits[0]
1679 traits = ()
1676 traits = ()
1680
1677
1681 if default_value is None:
1678 if default_value is None:
1682 args = ()
1679 args = ()
1683 elif isinstance(default_value, self._valid_defaults):
1680 elif isinstance(default_value, self._valid_defaults):
1684 args = (default_value,)
1681 args = (default_value,)
1685 else:
1682 else:
1686 raise TypeError('default value of %s was %s' %(self.__class__.__name__, default_value))
1683 raise TypeError('default value of %s was %s' %(self.__class__.__name__, default_value))
1687
1684
1688 self._traits = []
1685 self._traits = []
1689 for trait in traits:
1686 for trait in traits:
1690 t = trait() if isinstance(trait, type) else trait
1687 t = trait() if isinstance(trait, type) else trait
1691 t.name = 'element'
1688 t.name = 'element'
1692 self._traits.append(t)
1689 self._traits.append(t)
1693
1690
1694 if self._traits and default_value is None:
1691 if self._traits and default_value is None:
1695 # don't allow default to be an empty container if length is specified
1692 # don't allow default to be an empty container if length is specified
1696 args = None
1693 args = None
1697 super(Container,self).__init__(klass=self.klass, args=args, allow_none=allow_none, **metadata)
1694 super(Container,self).__init__(klass=self.klass, args=args, allow_none=allow_none, **metadata)
1698
1695
1699 def validate_elements(self, obj, value):
1696 def validate_elements(self, obj, value):
1700 if not self._traits:
1697 if not self._traits:
1701 # nothing to validate
1698 # nothing to validate
1702 return value
1699 return value
1703 if len(value) != len(self._traits):
1700 if len(value) != len(self._traits):
1704 e = "The '%s' trait of %s instance requires %i elements, but a value of %s was specified." \
1701 e = "The '%s' trait of %s instance requires %i elements, but a value of %s was specified." \
1705 % (self.name, class_of(obj), len(self._traits), repr_type(value))
1702 % (self.name, class_of(obj), len(self._traits), repr_type(value))
1706 raise TraitError(e)
1703 raise TraitError(e)
1707
1704
1708 validated = []
1705 validated = []
1709 for t, v in zip(self._traits, value):
1706 for t, v in zip(self._traits, value):
1710 try:
1707 try:
1711 v = t._validate(obj, v)
1708 v = t._validate(obj, v)
1712 except TraitError:
1709 except TraitError:
1713 self.element_error(obj, v, t)
1710 self.element_error(obj, v, t)
1714 else:
1711 else:
1715 validated.append(v)
1712 validated.append(v)
1716 return tuple(validated)
1713 return tuple(validated)
1717
1714
1718 def instance_init(self):
1715 def instance_init(self):
1719 for trait in self._traits:
1716 for trait in self._traits:
1720 if isinstance(trait, TraitType):
1717 if isinstance(trait, TraitType):
1721 trait.this_class = self.this_class
1718 trait.this_class = self.this_class
1722 trait.instance_init()
1719 trait.instance_init()
1723 super(Container, self).instance_init()
1720 super(Container, self).instance_init()
1724
1721
1725
1722
1726 class Dict(Instance):
1723 class Dict(Instance):
1727 """An instance of a Python dict."""
1724 """An instance of a Python dict."""
1728 _trait = None
1725 _trait = None
1729
1726
1730 def __init__(self, trait=None, default_value=NoDefaultSpecified, allow_none=False, **metadata):
1727 def __init__(self, trait=None, default_value=NoDefaultSpecified, allow_none=False, **metadata):
1731 """Create a dict trait type from a dict.
1728 """Create a dict trait type from a dict.
1732
1729
1733 The default value is created by doing ``dict(default_value)``,
1730 The default value is created by doing ``dict(default_value)``,
1734 which creates a copy of the ``default_value``.
1731 which creates a copy of the ``default_value``.
1735
1732
1736 trait : TraitType [ optional ]
1733 trait : TraitType [ optional ]
1737 the type for restricting the contents of the Container. If unspecified,
1734 the type for restricting the contents of the Container. If unspecified,
1738 types are not checked.
1735 types are not checked.
1739
1736
1740 default_value : SequenceType [ optional ]
1737 default_value : SequenceType [ optional ]
1741 The default value for the Dict. Must be dict, tuple, or None, and
1738 The default value for the Dict. Must be dict, tuple, or None, and
1742 will be cast to a dict if not None. If `trait` is specified, the
1739 will be cast to a dict if not None. If `trait` is specified, the
1743 `default_value` must conform to the constraints it specifies.
1740 `default_value` must conform to the constraints it specifies.
1744
1741
1745 allow_none : bool [ default False ]
1742 allow_none : bool [ default False ]
1746 Whether to allow the value to be None
1743 Whether to allow the value to be None
1747
1744
1748 """
1745 """
1749 if default_value is NoDefaultSpecified and trait is not None:
1746 if default_value is NoDefaultSpecified and trait is not None:
1750 if not is_trait(trait):
1747 if not is_trait(trait):
1751 default_value = trait
1748 default_value = trait
1752 trait = None
1749 trait = None
1753 if default_value is NoDefaultSpecified:
1750 if default_value is NoDefaultSpecified:
1754 default_value = {}
1751 default_value = {}
1755 if default_value is None:
1752 if default_value is None:
1756 args = None
1753 args = None
1757 elif isinstance(default_value, dict):
1754 elif isinstance(default_value, dict):
1758 args = (default_value,)
1755 args = (default_value,)
1759 elif isinstance(default_value, SequenceTypes):
1756 elif isinstance(default_value, SequenceTypes):
1760 args = (default_value,)
1757 args = (default_value,)
1761 else:
1758 else:
1762 raise TypeError('default value of Dict was %s' % default_value)
1759 raise TypeError('default value of Dict was %s' % default_value)
1763
1760
1764 if is_trait(trait):
1761 if is_trait(trait):
1765 self._trait = trait() if isinstance(trait, type) else trait
1762 self._trait = trait() if isinstance(trait, type) else trait
1766 self._trait.name = 'element'
1763 self._trait.name = 'element'
1767 elif trait is not None:
1764 elif trait is not None:
1768 raise TypeError("`trait` must be a Trait or None, got %s"%repr_type(trait))
1765 raise TypeError("`trait` must be a Trait or None, got %s"%repr_type(trait))
1769
1766
1770 super(Dict,self).__init__(klass=dict, args=args,
1767 super(Dict,self).__init__(klass=dict, args=args,
1771 allow_none=allow_none, **metadata)
1768 allow_none=allow_none, **metadata)
1772
1769
1773 def element_error(self, obj, element, validator):
1770 def element_error(self, obj, element, validator):
1774 e = "Element of the '%s' trait of %s instance must be %s, but a value of %s was specified." \
1771 e = "Element of the '%s' trait of %s instance must be %s, but a value of %s was specified." \
1775 % (self.name, class_of(obj), validator.info(), repr_type(element))
1772 % (self.name, class_of(obj), validator.info(), repr_type(element))
1776 raise TraitError(e)
1773 raise TraitError(e)
1777
1774
1778 def validate(self, obj, value):
1775 def validate(self, obj, value):
1779 value = super(Dict, self).validate(obj, value)
1776 value = super(Dict, self).validate(obj, value)
1780 if value is None:
1777 if value is None:
1781 return value
1778 return value
1782 value = self.validate_elements(obj, value)
1779 value = self.validate_elements(obj, value)
1783 return value
1780 return value
1784
1781
1785 def validate_elements(self, obj, value):
1782 def validate_elements(self, obj, value):
1786 if self._trait is None or isinstance(self._trait, Any):
1783 if self._trait is None or isinstance(self._trait, Any):
1787 return value
1784 return value
1788 validated = {}
1785 validated = {}
1789 for key in value:
1786 for key in value:
1790 v = value[key]
1787 v = value[key]
1791 try:
1788 try:
1792 v = self._trait._validate(obj, v)
1789 v = self._trait._validate(obj, v)
1793 except TraitError:
1790 except TraitError:
1794 self.element_error(obj, v, self._trait)
1791 self.element_error(obj, v, self._trait)
1795 else:
1792 else:
1796 validated[key] = v
1793 validated[key] = v
1797 return self.klass(validated)
1794 return self.klass(validated)
1798
1795
1799 def instance_init(self):
1796 def instance_init(self):
1800 if isinstance(self._trait, TraitType):
1797 if isinstance(self._trait, TraitType):
1801 self._trait.this_class = self.this_class
1798 self._trait.this_class = self.this_class
1802 self._trait.instance_init()
1799 self._trait.instance_init()
1803 super(Dict, self).instance_init()
1800 super(Dict, self).instance_init()
1804
1801
1805
1802
1806 class EventfulDict(Instance):
1803 class EventfulDict(Instance):
1807 """An instance of an EventfulDict."""
1804 """An instance of an EventfulDict."""
1808
1805
1809 def __init__(self, default_value={}, allow_none=False, **metadata):
1806 def __init__(self, default_value={}, allow_none=False, **metadata):
1810 """Create a EventfulDict trait type from a dict.
1807 """Create a EventfulDict trait type from a dict.
1811
1808
1812 The default value is created by doing
1809 The default value is created by doing
1813 ``eventful.EvenfulDict(default_value)``, which creates a copy of the
1810 ``eventful.EvenfulDict(default_value)``, which creates a copy of the
1814 ``default_value``.
1811 ``default_value``.
1815 """
1812 """
1816 if default_value is None:
1813 if default_value is None:
1817 args = None
1814 args = None
1818 elif isinstance(default_value, dict):
1815 elif isinstance(default_value, dict):
1819 args = (default_value,)
1816 args = (default_value,)
1820 elif isinstance(default_value, SequenceTypes):
1817 elif isinstance(default_value, SequenceTypes):
1821 args = (default_value,)
1818 args = (default_value,)
1822 else:
1819 else:
1823 raise TypeError('default value of EventfulDict was %s' % default_value)
1820 raise TypeError('default value of EventfulDict was %s' % default_value)
1824
1821
1825 super(EventfulDict, self).__init__(klass=eventful.EventfulDict, args=args,
1822 super(EventfulDict, self).__init__(klass=eventful.EventfulDict, args=args,
1826 allow_none=allow_none, **metadata)
1823 allow_none=allow_none, **metadata)
1827
1824
1828
1825
1829 class EventfulList(Instance):
1826 class EventfulList(Instance):
1830 """An instance of an EventfulList."""
1827 """An instance of an EventfulList."""
1831
1828
1832 def __init__(self, default_value=None, allow_none=False, **metadata):
1829 def __init__(self, default_value=None, allow_none=False, **metadata):
1833 """Create a EventfulList trait type from a dict.
1830 """Create a EventfulList trait type from a dict.
1834
1831
1835 The default value is created by doing
1832 The default value is created by doing
1836 ``eventful.EvenfulList(default_value)``, which creates a copy of the
1833 ``eventful.EvenfulList(default_value)``, which creates a copy of the
1837 ``default_value``.
1834 ``default_value``.
1838 """
1835 """
1839 if default_value is None:
1836 if default_value is None:
1840 args = ((),)
1837 args = ((),)
1841 else:
1838 else:
1842 args = (default_value,)
1839 args = (default_value,)
1843
1840
1844 super(EventfulList, self).__init__(klass=eventful.EventfulList, args=args,
1841 super(EventfulList, self).__init__(klass=eventful.EventfulList, args=args,
1845 allow_none=allow_none, **metadata)
1842 allow_none=allow_none, **metadata)
1846
1843
1847
1844
1848 class TCPAddress(TraitType):
1845 class TCPAddress(TraitType):
1849 """A trait for an (ip, port) tuple.
1846 """A trait for an (ip, port) tuple.
1850
1847
1851 This allows for both IPv4 IP addresses as well as hostnames.
1848 This allows for both IPv4 IP addresses as well as hostnames.
1852 """
1849 """
1853
1850
1854 default_value = ('127.0.0.1', 0)
1851 default_value = ('127.0.0.1', 0)
1855 info_text = 'an (ip, port) tuple'
1852 info_text = 'an (ip, port) tuple'
1856
1853
1857 def validate(self, obj, value):
1854 def validate(self, obj, value):
1858 if isinstance(value, tuple):
1855 if isinstance(value, tuple):
1859 if len(value) == 2:
1856 if len(value) == 2:
1860 if isinstance(value[0], py3compat.string_types) and isinstance(value[1], int):
1857 if isinstance(value[0], py3compat.string_types) and isinstance(value[1], int):
1861 port = value[1]
1858 port = value[1]
1862 if port >= 0 and port <= 65535:
1859 if port >= 0 and port <= 65535:
1863 return value
1860 return value
1864 self.error(obj, value)
1861 self.error(obj, value)
1865
1862
1866 class CRegExp(TraitType):
1863 class CRegExp(TraitType):
1867 """A casting compiled regular expression trait.
1864 """A casting compiled regular expression trait.
1868
1865
1869 Accepts both strings and compiled regular expressions. The resulting
1866 Accepts both strings and compiled regular expressions. The resulting
1870 attribute will be a compiled regular expression."""
1867 attribute will be a compiled regular expression."""
1871
1868
1872 info_text = 'a regular expression'
1869 info_text = 'a regular expression'
1873
1870
1874 def validate(self, obj, value):
1871 def validate(self, obj, value):
1875 try:
1872 try:
1876 return re.compile(value)
1873 return re.compile(value)
1877 except:
1874 except:
1878 self.error(obj, value)
1875 self.error(obj, value)
General Comments 0
You need to be logged in to leave comments. Login now