##// END OF EJS Templates
Adding decorator forms of interact. Yeah!
Brian E. Granger -
Show More
@@ -1,258 +1,276
1 1 """Interact with functions using widgets."""
2 2
3 3 #-----------------------------------------------------------------------------
4 4 # Copyright (c) 2013, the IPython Development Team.
5 5 #
6 6 # Distributed under the terms of the Modified BSD License.
7 7 #
8 8 # The full license is in the file COPYING.txt, distributed with this software.
9 9 #-----------------------------------------------------------------------------
10 10
11 11 #-----------------------------------------------------------------------------
12 12 # Imports
13 13 #-----------------------------------------------------------------------------
14 14
15 15 from __future__ import print_function
16 16
17 17 try: # Python >= 3.3
18 18 from inspect import signature, Parameter
19 19 except ImportError:
20 20 from IPython.utils.signatures import signature, Parameter
21 21 from inspect import getcallargs
22 22
23 23 from IPython.html.widgets import (Widget, TextWidget,
24 24 FloatSliderWidget, IntSliderWidget, CheckboxWidget, DropdownWidget,
25 25 ContainerWidget)
26 26 from IPython.display import display, clear_output
27 27 from IPython.utils.py3compat import string_types, unicode_type
28 28
29 29 #-----------------------------------------------------------------------------
30 30 # Classes and Functions
31 31 #-----------------------------------------------------------------------------
32 32
33 33
34 34 def _matches(o, pattern):
35 35 """Match a pattern of types in a sequence."""
36 36 if not len(o) == len(pattern):
37 37 return False
38 38 comps = zip(o,pattern)
39 39 return all(isinstance(obj,kind) for obj,kind in comps)
40 40
41 41
42 42 def _get_min_max_value(min, max, value):
43 43 """Return min, max, value given input values with possible None."""
44 44 if value is None:
45 45 if not max > min:
46 46 raise ValueError('max must be greater than min: (min={0}, max={1})'.format(min, max))
47 47 value = min + abs(min-max)/2
48 48 value = type(min)(value)
49 49 elif min is None and max is None:
50 50 if value == 0.0:
51 51 min, max, value = 0.0, 1.0, 0.5
52 52 elif value == 0:
53 53 min, max, value = 0, 1, 0
54 54 elif isinstance(value, float):
55 55 min, max = (-value, 3.0*value) if value > 0 else (3.0*value, -value)
56 56 elif isinstance(value, int):
57 57 min, max = (-value, 3*value) if value > 0 else (3*value, -value)
58 58 else:
59 59 raise TypeError('expected a number, got: %r' % value)
60 60 else:
61 61 raise ValueError('unable to infer range, value from: ({0}, {1}, {2})'.format(min, max, value))
62 62 return min, max, value
63 63
64 64 def _widget_abbrev_single_value(o):
65 65 """Make widgets from single values, which can be used written as parameter defaults."""
66 66 if isinstance(o, string_types):
67 67 return TextWidget(value=unicode_type(o))
68 68 elif isinstance(o, dict):
69 69 labels = [unicode_type(k) for k in o]
70 70 values = o.values()
71 71 w = DropdownWidget(value=values[0], values=values, labels=labels)
72 72 return w
73 73 elif isinstance(o, bool):
74 74 return CheckboxWidget(value=o)
75 75 elif isinstance(o, float):
76 76 min, max, value = _get_min_max_value(None, None, o)
77 77 return FloatSliderWidget(value=o, min=min, max=max)
78 78 elif isinstance(o, int):
79 79 min, max, value = _get_min_max_value(None, None, o)
80 80 return IntSliderWidget(value=o, min=min, max=max)
81 81 else:
82 82 return None
83 83
84 84 def _widget_abbrev(o):
85 85 """Make widgets from abbreviations: single values, lists or tuples."""
86 86 if isinstance(o, (list, tuple)):
87 87 if _matches(o, (int, int)):
88 88 min, max, value = _get_min_max_value(o[0], o[1], None)
89 89 return IntSliderWidget(value=value, min=min, max=max)
90 90 elif _matches(o, (int, int, int)):
91 91 min, max, value = _get_min_max_value(o[0], o[1], None)
92 92 return IntSliderWidget(value=value, min=min, max=max, step=o[2])
93 93 elif _matches(o, (float, float)):
94 94 min, max, value = _get_min_max_value(o[0], o[1], None)
95 95 return FloatSliderWidget(value=value, min=min, max=max)
96 96 elif _matches(o, (float, float, float)):
97 97 min, max, value = _get_min_max_value(o[0], o[1], None)
98 98 return FloatSliderWidget(value=value, min=min, max=max, step=o[2])
99 99 elif _matches(o, (float, float, int)):
100 100 min, max, value = _get_min_max_value(o[0], o[1], None)
101 101 return FloatSliderWidget(value=value, min=min, max=max, step=float(o[2]))
102 102 elif all(isinstance(x, string_types) for x in o):
103 103 return DropdownWidget(value=unicode_type(o[0]),
104 104 values=[unicode_type(k) for k in o])
105 105 else:
106 106 return _widget_abbrev_single_value(o)
107 107
108 108 def _widget_from_abbrev(abbrev):
109 109 """Build a Widget intstance given an abbreviation or Widget."""
110 110 if isinstance(abbrev, Widget):
111 111 return abbrev
112 112
113 113 widget = _widget_abbrev(abbrev)
114 114 if widget is None:
115 115 raise ValueError("%r cannot be transformed to a Widget" % abbrev)
116 116 return widget
117 117
118 118 def _yield_abbreviations_for_parameter(param, args, kwargs):
119 119 """Get an abbreviation for a function parameter."""
120 120 # print(param, args, kwargs)
121 121 name = param.name
122 122 kind = param.kind
123 123 ann = param.annotation
124 124 default = param.default
125 125 empty = Parameter.empty
126 126 if kind == Parameter.POSITIONAL_ONLY:
127 127 if args:
128 128 yield name, args.pop(0), False
129 129 elif ann is not empty:
130 130 yield name, ann, False
131 131 else:
132 132 yield None, None, None
133 133 elif kind == Parameter.POSITIONAL_OR_KEYWORD:
134 134 if name in kwargs:
135 135 yield name, kwargs.pop(name), True
136 136 elif args:
137 137 yield name, args.pop(0), False
138 138 elif ann is not empty:
139 139 if default is empty:
140 140 yield name, ann, False
141 141 else:
142 142 yield name, ann, True
143 143 elif default is not empty:
144 144 yield name, default, True
145 145 else:
146 146 yield None, None, None
147 147 elif kind == Parameter.VAR_POSITIONAL:
148 148 # In this case name=args or something and we don't actually know the names.
149 149 for item in args[::]:
150 150 args.pop(0)
151 151 yield '', item, False
152 152 elif kind == Parameter.KEYWORD_ONLY:
153 153 if name in kwargs:
154 154 yield name, kwargs.pop(name), True
155 155 elif ann is not empty:
156 156 yield name, ann, True
157 157 elif default is not empty:
158 158 yield name, default, True
159 159 else:
160 160 yield None, None, None
161 161 elif kind == Parameter.VAR_KEYWORD:
162 162 # In this case name=kwargs and we yield the items in kwargs with their keys.
163 163 for k, v in kwargs.copy().items():
164 164 kwargs.pop(k)
165 165 yield k, v, True
166 166
167 167 def _find_abbreviations(f, args, kwargs):
168 168 """Find the abbreviations for a function and args/kwargs passed to interact."""
169 169 new_args = []
170 170 new_kwargs = []
171 171 for param in signature(f).parameters.values():
172 172 for name, value, kw in _yield_abbreviations_for_parameter(param, args, kwargs):
173 173 if value is None:
174 174 raise ValueError('cannot find widget or abbreviation for argument: {!r}'.format(name))
175 175 if kw:
176 176 new_kwargs.append((name, value))
177 177 else:
178 178 new_args.append((name, value))
179 179 return new_args, new_kwargs
180 180
181 181 def _widgets_from_abbreviations(seq):
182 182 """Given a sequence of (name, abbrev) tuples, return a sequence of Widgets."""
183 183 result = []
184 184 for name, abbrev in seq:
185 185 widget = _widget_from_abbrev(abbrev)
186 186 widget.description = name
187 187 result.append(widget)
188 188 return result
189 189
190 190 def interactive(f, *args, **kwargs):
191 191 """Build a group of widgets to interact with a function."""
192 192 co = kwargs.pop('clear_output', True)
193 193 args_widgets = []
194 194 kwargs_widgets = []
195 195 container = ContainerWidget()
196 196 container.result = None
197 197 container.args = []
198 198 container.kwargs = dict()
199 199 # We need this to be a list as we iteratively pop elements off it
200 200 args = list(args)
201 201 kwargs = kwargs.copy()
202 202
203 203 new_args, new_kwargs = _find_abbreviations(f, args, kwargs)
204 204 # Before we proceed, let's make sure that the user has passed a set of args+kwargs
205 205 # that will lead to a valid call of the function. This protects against unspecified
206 206 # and doubly-specified arguments.
207 207 getcallargs(f, *[v for n,v in new_args], **{n:v for n,v in new_kwargs})
208 208 # Now build the widgets from the abbreviations.
209 209 args_widgets.extend(_widgets_from_abbreviations(new_args))
210 210 kwargs_widgets.extend(_widgets_from_abbreviations(new_kwargs))
211 211 kwargs_widgets.extend(_widgets_from_abbreviations(sorted(kwargs.items(), key = lambda x: x[0])))
212 212
213 213 # This has to be done as an assignment, not using container.children.append,
214 214 # so that traitlets notices the update.
215 215 container.children = args_widgets + kwargs_widgets
216 216
217 217 # Build the callback
218 218 def call_f(name, old, new):
219 219 container.args = []
220 220 for widget in args_widgets:
221 221 value = widget.value
222 222 container.args.append(value)
223 223 for widget in kwargs_widgets:
224 224 value = widget.value
225 225 container.kwargs[widget.description] = value
226 226 if co:
227 227 clear_output(wait=True)
228 228 container.result = f(*container.args, **container.kwargs)
229 229
230 230 # Wire up the widgets
231 231 for widget in args_widgets:
232 232 widget.on_trait_change(call_f, 'value')
233 233 for widget in kwargs_widgets:
234 234 widget.on_trait_change(call_f, 'value')
235 235
236 236 container.on_displayed(lambda _: call_f(None, None, None))
237 237
238 238 return container
239 239
240 def interact(f, *args, **kwargs):
240 def interact(*args, **kwargs):
241 241 """Interact with a function using widgets."""
242 w = interactive(f, *args, **kwargs)
243 f.widget = w
244 display(w)
242 if args and callable(args[0]):
243 # This branch handles the cases:
244 # 1. interact(f, *args, **kwargs)
245 # 2. @interact
246 # def f(*args, **kwargs):
247 # ...
248 f = args[0]
249 w = interactive(f, *args[1:], **kwargs)
250 f.widget = w
251 display(w)
252 else:
253 # This branch handles the case:
254 # @interact(10, 20, a=30, b=40)
255 # def f(*args, **kwargs):
256 # ...
257 def dec(f):
258 w = interactive(f, *args, **kwargs)
259 f.widget = w
260 display(w)
261 return f
262 return dec
245 263
246 264 def annotate(**kwargs):
247 265 """Python 3 compatible function annotation for Python 2."""
248 266 if not kwargs:
249 267 raise ValueError('annotations must be provided as keyword arguments')
250 268 def dec(f):
251 269 if hasattr(f, '__annotations__'):
252 270 for k, v in kwargs.items():
253 271 f.__annotations__[k] = v
254 272 else:
255 273 f.__annotations__ = kwargs
256 274 return f
257 275 return dec
258 276
General Comments 0
You need to be logged in to leave comments. Login now